Skip to content

Commit 6a55521

Browse files
committed
Remove recursion from BaseNestedSets.get_tree method #39
1 parent 7b5b5df commit 6a55521

File tree

3 files changed

+152
-30
lines changed

3 files changed

+152
-30
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
all: test
22

33
test:
4-
nosetests --with-coverage --nocapture --cover-package=sqlalchemy_mptt --cover-erase --with-doctest
4+
nosetests --with-coverage --cover-package=sqlalchemy_mptt --cover-erase --with-doctest
5+
6+
nocapture:
7+
nosetests --with-coverage --cover-package=sqlalchemy_mptt --cover-erase --with-doctest --nocapture

sqlalchemy_mptt/mixins.py

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,13 @@ def parent_id(cls):
8888
@declared_attr
8989
def parent(cls):
9090
pk = getattr(cls, cls.get_pk())
91-
return relationship(cls, primaryjoin=lambda: pk == cls.parent_id,
92-
order_by=lambda: cls.left,
93-
backref=backref('children', cascade="all,delete",
94-
order_by=lambda: cls.left),
95-
remote_side=cls.get_class_pk(), # for show in sacrud relation
96-
)
91+
return relationship(
92+
cls, primaryjoin=lambda: pk == cls.parent_id,
93+
order_by=lambda: cls.left,
94+
backref=backref('children', cascade="all,delete",
95+
order_by=lambda: cls.left),
96+
remote_side=cls.get_class_pk(), # for show in sacrud relation
97+
)
9798

9899
@declared_attr
99100
def left(cls):
@@ -133,7 +134,7 @@ def move_inside(self, parent_id):
133134
134135
* :mod:`sqlalchemy_mptt.tests.TestTree.test_move_inside_function`
135136
* :mod:`sqlalchemy_mptt.tests.TestTree.test_move_inside_to_the_same_parent_function`
136-
"""
137+
""" # noqa
137138
session = Session.object_session(self)
138139
self.parent_id = parent_id
139140
self.mptt_move_inside = parent_id
@@ -143,7 +144,7 @@ def move_after(self, node_id):
143144
""" Moving one node of tree after another
144145
145146
For example see :mod:`sqlalchemy_mptt.tests.TestTree.test_move_after_function`
146-
"""
147+
""" # noqa
147148
session = Session.object_session(self)
148149
self.parent_id = self.parent_id
149150
self.mptt_move_after = node_id
@@ -170,7 +171,7 @@ def leftsibling_in_level(self):
170171
""" Node to the left of the current node at the same level
171172
172173
For example see :mod:`sqlalchemy_mptt.tests.TestTree.test_leftsibling_in_level`
173-
"""
174+
""" # noqa
174175
table = _get_tree_table(self.__mapper__)
175176
session = Session.object_session(self)
176177
current_lvl_nodes = session.query(table)\
@@ -181,8 +182,23 @@ def leftsibling_in_level(self):
181182
return None
182183

183184
@classmethod
184-
def get_tree(cls, session, json=False, json_fields=None):
185-
""" This function generate tree of current node in dict or json format.
185+
def _get_tree_node(cls, node, json, json_fields):
186+
""" Helper method for ``get_tree`` and ``get_tree_reqursively``.
187+
"""
188+
if json:
189+
pk = getattr(node, node.get_pk())
190+
# jqTree or jsTree format
191+
result = {'id': pk, 'label': node.__repr__()}
192+
if json_fields:
193+
result.update(json_fields(node))
194+
else:
195+
result = {'node': node}
196+
return result
197+
198+
@classmethod
199+
def get_tree_reqursively(cls, session, json=False, json_fields=None):
200+
""" This function recursively generate tree of current node in dict or
201+
json format.
186202
187203
Args:
188204
session (:mod:`sqlalchemy.orm.session.Session`): SQLAlchemy session
@@ -196,15 +212,9 @@ def get_tree(cls, session, json=False, json_fields=None):
196212
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_tree`
197213
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_json_tree`
198214
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_json_tree_with_custom_field`
199-
"""
215+
""" # noqa
200216
def recursive_node_to_dict(node):
201-
result = {'node': node}
202-
pk = getattr(node, node.get_pk())
203-
if json:
204-
# jqTree or jsTree format
205-
result = {'id': pk, 'label': node.__repr__()}
206-
if json_fields:
207-
result.update(json_fields(node))
217+
result = cls._get_tree_node(node, json, json_fields)
208218
children = [recursive_node_to_dict(c) for c in node.children]
209219
if children:
210220
result['children'] = children
@@ -218,6 +228,48 @@ def recursive_node_to_dict(node):
218228

219229
return tree
220230

231+
@classmethod
232+
def get_tree(cls, session, json=False, json_fields=None):
233+
""" This function generate tree of current node in dict or json format.
234+
235+
Args:
236+
session (:mod:`sqlalchemy.orm.session.Session`): SQLAlchemy session
237+
238+
Kwargs:
239+
json (bool): if True return JSON jqTree format
240+
json_fields (function): append custom fields in JSON
241+
242+
Example:
243+
244+
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_tree`
245+
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_json_tree`
246+
* :mod:`sqlalchemy_mptt.tests.TestTree.test_get_json_tree_with_custom_field`
247+
""" # noqa
248+
nodes = session.query(cls).order_by(cls.level).all()
249+
tree = []
250+
nodes_of_level = {}
251+
252+
def get_node_id(node):
253+
return getattr(node, node.get_pk())
254+
255+
for node in nodes:
256+
result = cls._get_tree_node(node, json, json_fields)
257+
parent_id = node.parent_id
258+
# Parent detect!
259+
if parent_id:
260+
# Find parent in tree list!
261+
if parent_id in nodes_of_level.keys():
262+
if 'children' not in nodes_of_level[parent_id]:
263+
nodes_of_level[parent_id]['children'] = []
264+
# Append to parent!
265+
nl = nodes_of_level[parent_id]['children']
266+
nl.append(result)
267+
nodes_of_level[get_node_id(node)] = nl[-1]
268+
else:
269+
tree.append(result)
270+
nodes_of_level[get_node_id(node)] = tree[-1]
271+
return tree
272+
221273
@classmethod
222274
def rebuild_tree(cls, session, tree_id):
223275
""" This function rebuid tree.

sqlalchemy_mptt/tests/tree_testing_base.py

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# Distributed under terms of the MIT license.
88

99

10-
from sqlalchemy import create_engine
10+
from sqlalchemy import create_engine, event
1111
from sqlalchemy.exc import IntegrityError
1212
from sqlalchemy.orm import sessionmaker
1313

@@ -103,6 +103,18 @@ class TreeTestingMixin(object):
103103
base = None
104104
model = None
105105

106+
def catch_queries(self, conn, cursor, statement, *args):
107+
self.stmts.append(statement)
108+
109+
def start_query_counter(self):
110+
self.stmts = []
111+
event.listen(self.session.bind.engine, "before_cursor_execute",
112+
self.catch_queries)
113+
114+
def stop_query_counter(self):
115+
event.remove(self.session.bind.engine, "before_cursor_execute",
116+
self.catch_queries)
117+
106118
def setUp(self):
107119
self.engine = create_engine('sqlite:///:memory:')
108120
Session = mptt_sessionmaker(sessionmaker(bind=self.engine))
@@ -754,7 +766,7 @@ def test_move_tree_to_another_tree(self):
754766
| |
755767
6 26(20)27 30(22)31
756768
757-
"""
769+
""" # noqa
758770
node = self.session.query(self.model).\
759771
filter(self.model.ppk == 12).one()
760772
node.parent_id = 7
@@ -928,7 +940,8 @@ def test_move_to_toplevel_where_much_trees_from_right_side(self):
928940
4 8(20)9 12(22)13
929941
930942
"""
931-
node = self.session.query(self.model).filter(self.model.ppk == 15).one()
943+
node = self.session.query(self.model)\
944+
.filter(self.model.ppk == 15).one()
932945
node.move_after("1")
933946
# id lft rgt lvl parent tree
934947
self.assertEqual([(1, 1, 22, 1, None, 1),
@@ -957,7 +970,8 @@ def test_move_to_toplevel_where_much_trees_from_right_side(self):
957970
(21, 11, 14, 3, 18, 3),
958971
(22, 12, 13, 4, 21, 3)], self.result.all())
959972

960-
node = self.session.query(self.model).filter(self.model.ppk == 20).one()
973+
node = self.session.query(self.model)\
974+
.filter(self.model.ppk == 20).one()
961975
node.move_after("1")
962976
""" level tree_id = 1
963977
1 1(1)22
@@ -1254,11 +1268,55 @@ def test_get_tree(self):
12541268
tree = Tree.get_tree(self.session)
12551269
"""
12561270
tree = self.model.get_tree(self.session)
1271+
tree_reqursively = self.model.get_tree_reqursively(self.session)
12571272

12581273
def go(id):
12591274
return get_obj(self.session, self.model, id)
1260-
self.assertEqual(tree,
1261-
[{'node': go(1), 'children': [{'node': go(2), 'children': [{'node': go(3)}]}, {'node': go(4), 'children': [{'node': go(5)}, {'node': go(6)}]}, {'node': go(7), 'children': [{'node': go(8), 'children': [{'node': go(9)}]}, {'node': go(10), 'children': [{'node': go(11)}]}]}]}, {'node': go(12), 'children': [{'node': go(13), 'children': [{'node': go(14)}]}, {'node': go(15), 'children': [{'node': go(16)}, {'node': go(17)}]}, {'node': go(18), 'children': [{'node': go(19), 'children': [{'node': go(20)}]}, {'node': go(21), 'children': [{'node': go(22)}]}]}]}])
1275+
1276+
reference_tree = [{'node': go(1), 'children': [{'node': go(2), 'children': [{'node': go(3)}]}, {'node': go(4), 'children': [{'node': go(5)}, {'node': go(6)}]}, {'node': go(7), 'children': [{'node': go(8), 'children': [{'node': go(9)}]}, {'node': go(10), 'children': [{'node': go(11)}]}]}]}, {'node': go(12), 'children': [{'node': go(13), 'children': [{'node': go(14)}]}, {'node': go(15), 'children': [{'node': go(16)}, {'node': go(17)}]}, {'node': go(18), 'children': [{'node': go(19), 'children': [{'node': go(20)}]}, {'node': go(21), 'children': [{'node': go(22)}]}]}]}] # noqa
1277+
1278+
self.assertEqual(tree, reference_tree)
1279+
self.assertEqual(tree_reqursively, reference_tree)
1280+
1281+
def test_get_tree_count_query(self):
1282+
"""
1283+
Count num of queries to the database.
1284+
See https://github.com/ITCase/sqlalchemy_mptt/issues/39
1285+
1286+
1287+
Use ``--nocapture`` option for show run time:
1288+
1289+
::
1290+
1291+
nosetests sqlalchemy_mptt.tests.test_events:TestTree.test_get_tree_count_query --nocapture
1292+
Get tree: 0:00:00.001817
1293+
Get tree reqursively: 0:00:00.020615
1294+
.
1295+
----------------------------------------------------------------------
1296+
Ran 1 test in 0.064s
1297+
1298+
OK
1299+
""" # noqa
1300+
from datetime import datetime
1301+
self.session.commit()
1302+
1303+
# Get tree by for cycle
1304+
self.start_query_counter()
1305+
self.assertEqual(0, len(self.stmts))
1306+
startTime = datetime.now()
1307+
self.model.get_tree(self.session)
1308+
print("Get tree: {!s:>26}".format(datetime.now() - startTime))
1309+
self.assertEqual(1, len(self.stmts))
1310+
self.stop_query_counter()
1311+
1312+
# Get tree by recursion
1313+
self.start_query_counter()
1314+
self.assertEqual(0, len(self.stmts))
1315+
startTime = datetime.now()
1316+
self.model.get_tree_reqursively(self.session)
1317+
print("Get tree reqursively: {}".format(datetime.now() - startTime))
1318+
self.assertEqual(23, len(self.stmts))
1319+
self.stop_query_counter()
12621320

12631321
def test_get_json_tree(self):
12641322
""".. note::
@@ -1271,9 +1329,13 @@ def test_get_json_tree(self):
12711329
12721330
tree = Tree.get_tree(self.session, json=True)
12731331
"""
1332+
reference_tree = [{'children': [{'children': [{'id': 3, 'label': '<Node (3)>'}], 'id': 2, 'label': '<Node (2)>'}, {'children': [{'id': 5, 'label': '<Node (5)>'}, {'id': 6, 'label': '<Node (6)>'}], 'id': 4, 'label': '<Node (4)>'}, {'children': [{'children': [{'id': 9, 'label': '<Node (9)>'}], 'id': 8, 'label': '<Node (8)>'}, {'children': [{'id': 11, 'label': '<Node (11)>'}], 'id': 10, 'label': '<Node (10)>'}], 'id': 7, 'label': '<Node (7)>'}], 'id': 1, 'label': '<Node (1)>'}, {'children': [{'children': [{'id': 14, 'label': '<Node (14)>'}], 'id': 13, 'label': '<Node (13)>'}, {'children': [{'id': 16, 'label': '<Node (16)>'}, {'id': 17, 'label': '<Node (17)>'}], 'id': 15, 'label': '<Node (15)>'}, {'children': [{'children': [{'id': 20, 'label': '<Node (20)>'}], 'id': 19, 'label': '<Node (19)>'}, {'children': [{'id': 22, 'label': '<Node (22)>'}], 'id': 21, 'label': '<Node (21)>'}], 'id': 18, 'label': '<Node (18)>'}], 'id': 12, 'label': '<Node (12)>'}] # noqa
1333+
12741334
tree = self.model.get_tree(self.session, json=True)
1275-
self.assertEqual(tree, [{'children': [{'children': [{'id': 3, 'label': '<Node (3)>'}], 'id': 2, 'label': '<Node (2)>'}, {'children': [{'id': 5, 'label': '<Node (5)>'}, {'id': 6, 'label': '<Node (6)>'}], 'id': 4, 'label': '<Node (4)>'}, {'children': [{'children': [{'id': 9, 'label': '<Node (9)>'}], 'id': 8, 'label': '<Node (8)>'}, {'children': [{'id': 11, 'label': '<Node (11)>'}], 'id': 10, 'label': '<Node (10)>'}], 'id': 7, 'label': '<Node (7)>'}], 'id': 1, 'label': '<Node (1)>'}, {
1276-
'children': [{'children': [{'id': 14, 'label': '<Node (14)>'}], 'id': 13, 'label': '<Node (13)>'}, {'children': [{'id': 16, 'label': '<Node (16)>'}, {'id': 17, 'label': '<Node (17)>'}], 'id': 15, 'label': '<Node (15)>'}, {'children': [{'children': [{'id': 20, 'label': '<Node (20)>'}], 'id': 19, 'label': '<Node (19)>'}, {'children': [{'id': 22, 'label': '<Node (22)>'}], 'id': 21, 'label': '<Node (21)>'}], 'id': 18, 'label': '<Node (18)>'}], 'id': 12, 'label': '<Node (12)>'}])
1335+
tree_reqursively = self.model.get_tree_reqursively(self.session,
1336+
json=True)
1337+
self.assertEqual(tree, reference_tree)
1338+
self.assertEqual(tree_reqursively, reference_tree)
12771339

12781340
def test_get_json_tree_with_custom_field(self):
12791341
""".. note::
@@ -1292,9 +1354,14 @@ def fields(node):
12921354
"""
12931355
def fields(node):
12941356
return {'visible': node.visible}
1357+
1358+
reference_tree = [{'visible': None, 'children': [{'visible': True, 'children': [{'visible': True, 'id': 3, 'label': '<Node (3)>'}], 'id': 2, 'label': '<Node (2)>'}, {'visible': True, 'children': [{'visible': True, 'id': 5, 'label': '<Node (5)>'}, {'visible': True, 'id': 6, 'label': '<Node (6)>'}], 'id': 4, 'label': '<Node (4)>'}, {'visible': True, 'children': [{'visible': True, 'children': [{'visible': None, 'id': 9, 'label': '<Node (9)>'}], 'id': 8, 'label': '<Node (8)>'}, {'visible': None, 'children': [{'visible': None, 'id': 11, 'label': '<Node (11)>'}], 'id': 10, 'label': '<Node (10)>'}], 'id': 7, 'label': '<Node (7)>'}], 'id': 1, 'label': '<Node (1)>'}, {'visible': None, 'children': [{'visible': None, 'children': [{'visible': None, 'id': 14, 'label': '<Node (14)>'}], 'id': 13, 'label': '<Node (13)>'}, {'visible': None, 'children': [{'visible': None, 'id': 16, 'label': '<Node (16)>'}, {'visible': None, 'id': 17, 'label': '<Node (17)>'}], 'id': 15, 'label': '<Node (15)>'}, {'visible': None, 'children': [{'visible': None, 'children': [{'visible': None, 'id': 20, 'label': '<Node (20)>'}], 'id': 19, 'label': '<Node (19)>'}, {'visible': None, 'children': [{'visible': None, 'id': 22, 'label': '<Node (22)>'}], 'id': 21, 'label': '<Node (21)>'}], 'id': 18, 'label': '<Node (18)>'}], 'id': 12, 'label': '<Node (12)>'}] # noqa
1359+
12951360
tree = self.model.get_tree(self.session, json=True, json_fields=fields)
1296-
self.assertEqual(tree, [{'visible': None, 'children': [{'visible': True, 'children': [{'visible': True, 'id': 3, 'label': '<Node (3)>'}], 'id': 2, 'label': '<Node (2)>'}, {'visible': True, 'children': [{'visible': True, 'id': 5, 'label': '<Node (5)>'}, {'visible': True, 'id': 6, 'label': '<Node (6)>'}], 'id': 4, 'label': '<Node (4)>'}, {'visible': True, 'children': [{'visible': True, 'children': [{'visible': None, 'id': 9, 'label': '<Node (9)>'}], 'id': 8, 'label': '<Node (8)>'}, {'visible': None, 'children': [{'visible': None, 'id': 11, 'label': '<Node (11)>'}], 'id': 10, 'label': '<Node (10)>'}], 'id': 7, 'label': '<Node (7)>'}], 'id': 1, 'label': '<Node (1)>'}, {
1297-
'visible': None, 'children': [{'visible': None, 'children': [{'visible': None, 'id': 14, 'label': '<Node (14)>'}], 'id': 13, 'label': '<Node (13)>'}, {'visible': None, 'children': [{'visible': None, 'id': 16, 'label': '<Node (16)>'}, {'visible': None, 'id': 17, 'label': '<Node (17)>'}], 'id': 15, 'label': '<Node (15)>'}, {'visible': None, 'children': [{'visible': None, 'children': [{'visible': None, 'id': 20, 'label': '<Node (20)>'}], 'id': 19, 'label': '<Node (19)>'}, {'visible': None, 'children': [{'visible': None, 'id': 22, 'label': '<Node (22)>'}], 'id': 21, 'label': '<Node (21)>'}], 'id': 18, 'label': '<Node (18)>'}], 'id': 12, 'label': '<Node (12)>'}])
1361+
tree_reqursively = self.model.get_tree(self.session, json=True,
1362+
json_fields=fields)
1363+
self.assertEqual(tree, reference_tree)
1364+
self.assertEqual(tree_reqursively, reference_tree)
12981365

12991366
def test_rebuild(self):
13001367
""" Rebuild tree with tree_id==1

0 commit comments

Comments
 (0)