Skip to content

Commit cdeda94

Browse files
committed
Flushing the session now expire the instance and it's children fixes #33
1 parent fa98828 commit cdeda94

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

sqlalchemy_mptt/events.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212
import weakref
1313

14-
from sqlalchemy import and_, case, select, event, inspection
14+
from sqlalchemy import and_, case, event, inspection, select
1515
from sqlalchemy.orm import object_session
1616
from sqlalchemy.orm.base import NO_VALUE
1717
from sqlalchemy.sql import func
@@ -423,13 +423,30 @@ def after_flush_postexec(self, session, context):
423423
if instance not in session:
424424
continue
425425
parent = self.get_parent_value(instance)
426+
426427
while parent != NO_VALUE and parent is not None:
427428
instances.discard(parent)
428-
session.expire(parent, ['left', 'right'])
429+
session.expire(parent, ['left', 'right', 'tree_id', 'level'])
429430
parent = self.get_parent_value(parent)
430431
else:
431-
session.expire(instance, ['tree_id', 'level'])
432+
session.expire(instance, ['left', 'right', 'tree_id', 'level'])
433+
self.expire_session_for_children(session, instance)
432434

433435
@staticmethod
434436
def get_parent_value(instance):
435437
return inspection.inspect(instance).attrs.parent.loaded_value
438+
439+
@staticmethod
440+
def expire_session_for_children(session, instance):
441+
children = instance.children
442+
443+
def expire_recursively(node):
444+
children = node.children
445+
for item in children:
446+
session.expire(item, ['left', 'right', 'tree_id', 'level'])
447+
expire_recursively(item)
448+
449+
if children != NO_VALUE and children is not None:
450+
for item in children:
451+
session.expire(item, ['left', 'right', 'tree_id', 'level'])
452+
expire_recursively(item)

sqlalchemy_mptt/tests/tree_testing_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def test_tree_orm_initialize(self):
171171
self.assertEqual(t5.right, 7)
172172

173173
def test_flush_with_transient_nodes_present(self):
174+
"""https://github.com/ITCase/sqlalchemy_mptt/issues/34"""
174175
transient_node = self.model(ppk=1, parent=None)
175176
self.session.add(transient_node)
176177
try:

0 commit comments

Comments
 (0)