Skip to content

Commit 7bf3e9e

Browse files
committed
Merge branch 'develop' into path-33
Conflicts: sqlalchemy_mptt/events.py
2 parents 55201c9 + 9f8b1ce commit 7bf3e9e

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

sqlalchemy_mptt/events.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import weakref
1313

1414
from sqlalchemy import and_, case, select, event, inspection
15+
from sqlalchemy.orm import object_session
1516
from sqlalchemy.orm.base import NO_VALUE
1617
from sqlalchemy.sql import func
1718

@@ -325,14 +326,24 @@ def pop(self):
325326
return self.popitem()[0]
326327

327328

329+
class _WeakDefaultDict(weakref.WeakKeyDictionary, object):
330+
331+
def __getitem__(self, key):
332+
try:
333+
return super(_WeakDefaultDict, self).__getitem__(key)
334+
except KeyError:
335+
self[key] = value = _WeakDictBasedSet()
336+
return value
337+
338+
328339
class TreesManager(object):
329340
"""
330341
Manages events dispatching for all subclasses of a given class.
331342
"""
332343
def __init__(self, base_class):
333344
self.base_class = base_class
334345
self.classes = set()
335-
self.instances = _WeakDictBasedSet()
346+
self.instances = _WeakDefaultDict()
336347

337348
def register_mapper(self, mapper):
338349
for e, h in (
@@ -385,29 +396,37 @@ def register_factory(self, sessionmaker):
385396
return sessionmaker
386397

387398
def before_insert(self, mapper, connection, instance):
388-
self.instances.add(instance)
399+
session = object_session(instance)
400+
self.instances[session].add(instance)
389401
mptt_before_insert(mapper, connection, instance)
390402

391403
def before_update(self, mapper, connection, instance):
392-
self.instances.add(instance)
404+
session = object_session(instance)
405+
self.instances[session].add(instance)
393406
mptt_before_update(mapper, connection, instance)
394407

395408
def before_delete(self, mapper, connection, instance):
396-
self.instances.discard(instance)
409+
session = object_session(instance)
410+
self.instances[session].discard(instance)
397411
mptt_before_delete(mapper, connection, instance)
398412

399413
def after_flush_postexec(self, session, context):
400414
"""
401415
Event listener to recursively expire `left` and `right` attributes the
402416
parents of all modified instances part of this flush.
403417
"""
404-
while self.instances:
405-
instance = self.instances.pop()
418+
instances = self.instances[session]
419+
while instances:
420+
instance = instances.pop()
421+
if instance not in session:
422+
continue
406423
parent = self.get_parent_value(instance)
407424
while parent != NO_VALUE and parent is not None:
408-
self.instances.discard(parent)
409-
session.expire(parent, ['left', 'right', 'tree_id'])
425+
instances.discard(parent)
426+
session.expire(parent, ['left', 'right'])
410427
parent = self.get_parent_value(parent)
428+
else:
429+
session.expire(instance, ['tree_id', 'level'])
411430

412431
@staticmethod
413432
def get_parent_value(instance):

sqlalchemy_mptt/tests/tree_testing_base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
from sqlalchemy import create_engine
11+
from sqlalchemy.exc import IntegrityError
1112
from sqlalchemy.orm import sessionmaker
1213

1314
from sqlalchemy_mptt import mptt_sessionmaker
@@ -169,6 +170,17 @@ def test_tree_orm_initialize(self):
169170
self.assertEqual(t5.left, 6)
170171
self.assertEqual(t5.right, 7)
171172

173+
def test_flush_with_transient_nodes_present(self):
174+
transient_node = self.model(ppk=1, parent=None)
175+
self.session.add(transient_node)
176+
try:
177+
self.session.flush()
178+
except IntegrityError:
179+
pass
180+
self.session.rollback()
181+
self.session.add(self.model(ppk=46, parent=None))
182+
self.session.flush()
183+
172184
def test_tree_initialize(self):
173185
""" Initial state of the trees
174186
@@ -1726,6 +1738,7 @@ def test_session_expire_for_move_after_to_new_tree(self):
17261738
self.session.flush()
17271739

17281740
self.assertEqual(node.tree_id, 2)
1741+
self.assertEqual(node.level, 1)
17291742
self.assertEqual(node.parent_id, None)
17301743
self.assertEqual(children[0].tree_id, 2)
17311744
self.assertEqual(children[1].tree_id, 2)

0 commit comments

Comments
 (0)