@@ -169,13 +169,13 @@ def _node_to_dict(cls, node, json, json_fields):
169169
170170 @classmethod
171171 def _base_query (cls , session = None ):
172- # get orm session
173- if not session :
174- session = object_session (cls )
175-
176- # handle custom query
177172 return session .query (cls )
178173
174+ def _base_query_obj (self , session = None ):
175+ if not session :
176+ session = object_session (self )
177+ return self ._base_query (session )
178+
179179 @classmethod
180180 def _base_order (cls , query , order = asc ):
181181 return query .order_by (order (cls .tree_id ))\
@@ -243,8 +243,10 @@ def get_node_id(node):
243243 nodes_of_level [get_node_id (node )] = tree [- 1 ]
244244 return tree
245245
246- def _drilldown_query (self , nodes ):
246+ def _drilldown_query (self , nodes = None ):
247247 table = self .__class__
248+ if not nodes :
249+ nodes = self ._base_query_obj ()
248250 return nodes .filter (table .tree_id == self .tree_id )\
249251 .filter (table .left >= self .left )\
250252 .filter (table .right <= self .right )
@@ -275,6 +277,8 @@ def drilldown_tree(self, session=None, json=False, json_fields=None):
275277
276278 * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_drilldown_tree`
277279 """
280+ if not session :
281+ session = object_session (self )
278282 return self .get_tree (session , json = json , json_fields = json_fields ,
279283 query = self ._drilldown_query )
280284
@@ -304,7 +308,7 @@ def path_to_root(self, session=None):
304308 -------------
305309 """
306310 table = self .__class__
307- query = table . _base_query ( session )
311+ query = self . _base_query_obj ( session = session )
308312 query = query .filter (table .tree_id == self .tree_id )\
309313 .filter (table .left <= self .left )\
310314 .filter (table .right >= self .right )
0 commit comments