|
12 | 12 | import weakref |
13 | 13 |
|
14 | 14 | from sqlalchemy import and_, case, select, event, inspection |
| 15 | +from sqlalchemy.orm import object_session |
15 | 16 | from sqlalchemy.orm.base import NO_VALUE |
16 | 17 | from sqlalchemy.sql import func |
17 | 18 |
|
@@ -325,14 +326,24 @@ def pop(self): |
325 | 326 | return self.popitem()[0] |
326 | 327 |
|
327 | 328 |
|
| 329 | +class _WeakDefaultDict(weakref.WeakKeyDictionary, object): |
| 330 | + |
| 331 | + def __getitem__(self, key): |
| 332 | + try: |
| 333 | + return super(_WeakDefaultDict, self).__getitem__(key) |
| 334 | + except KeyError: |
| 335 | + self[key] = value = _WeakDictBasedSet() |
| 336 | + return value |
| 337 | + |
| 338 | + |
328 | 339 | class TreesManager(object): |
329 | 340 | """ |
330 | 341 | Manages events dispatching for all subclasses of a given class. |
331 | 342 | """ |
332 | 343 | def __init__(self, base_class): |
333 | 344 | self.base_class = base_class |
334 | 345 | self.classes = set() |
335 | | - self.instances = _WeakDictBasedSet() |
| 346 | + self.instances = _WeakDefaultDict() |
336 | 347 |
|
337 | 348 | def register_mapper(self, mapper): |
338 | 349 | for e, h in ( |
@@ -385,29 +396,37 @@ def register_factory(self, sessionmaker): |
385 | 396 | return sessionmaker |
386 | 397 |
|
387 | 398 | def before_insert(self, mapper, connection, instance): |
388 | | - self.instances.add(instance) |
| 399 | + session = object_session(instance) |
| 400 | + self.instances[session].add(instance) |
389 | 401 | mptt_before_insert(mapper, connection, instance) |
390 | 402 |
|
391 | 403 | def before_update(self, mapper, connection, instance): |
392 | | - self.instances.add(instance) |
| 404 | + session = object_session(instance) |
| 405 | + self.instances[session].add(instance) |
393 | 406 | mptt_before_update(mapper, connection, instance) |
394 | 407 |
|
395 | 408 | def before_delete(self, mapper, connection, instance): |
396 | | - self.instances.discard(instance) |
| 409 | + session = object_session(instance) |
| 410 | + self.instances[session].discard(instance) |
397 | 411 | mptt_before_delete(mapper, connection, instance) |
398 | 412 |
|
399 | 413 | def after_flush_postexec(self, session, context): |
400 | 414 | """ |
401 | 415 | Event listener to recursively expire `left` and `right` attributes the |
402 | 416 | parents of all modified instances part of this flush. |
403 | 417 | """ |
404 | | - while self.instances: |
405 | | - instance = self.instances.pop() |
| 418 | + instances = self.instances[session] |
| 419 | + while instances: |
| 420 | + instance = instances.pop() |
| 421 | + if instance not in session: |
| 422 | + continue |
406 | 423 | parent = self.get_parent_value(instance) |
407 | 424 | while parent != NO_VALUE and parent is not None: |
408 | | - self.instances.discard(parent) |
409 | | - session.expire(parent, ['left', 'right', 'tree_id']) |
| 425 | + instances.discard(parent) |
| 426 | + session.expire(parent, ['left', 'right']) |
410 | 427 | parent = self.get_parent_value(parent) |
| 428 | + else: |
| 429 | + session.expire(instance, ['tree_id', 'level']) |
411 | 430 |
|
412 | 431 | @staticmethod |
413 | 432 | def get_parent_value(instance): |
|
0 commit comments