@@ -50,31 +50,34 @@ def __repr__(self):
5050 @declared_attr
5151 def __table_args__ (cls ):
5252 return (
53- Index (' %s_lft_idx' % cls .__tablename__ , cls .left .name ),
54- Index (' %s_rgt_idx' % cls .__tablename__ , cls .right .name ),
55- Index (' %s_level_idx' % cls .__tablename__ , cls .level .name ),
53+ Index (" %s_lft_idx" % cls .__tablename__ , cls .left .name ),
54+ Index (" %s_rgt_idx" % cls .__tablename__ , cls .right .name ),
55+ Index (" %s_level_idx" % cls .__tablename__ , cls .level .name ),
5656 )
5757
5858 @classmethod
5959 def __declare_first__ (cls ):
6060 cls .__mapper__ .batch = False
6161
62- @classmethod
63- def get_pk_name (cls ):
64- return getattr (cls , 'sqlalchemy_mptt_pk_name' , 'id' )
65-
6662 @classmethod
6763 def get_default_level (cls ):
68- '''
64+ """
6965 Compatibility with Django MPTT: level value for root node.
7066 See https://github.com/uralbash/sqlalchemy_mptt/issues/56
71- '''
72- return getattr (cls , 'sqlalchemy_mptt_default_level' , 1 )
67+ """
68+ return getattr (cls , "sqlalchemy_mptt_default_level" , 1 )
69+
70+ @classmethod
71+ def get_pk_name (cls ):
72+ return getattr (cls , "sqlalchemy_mptt_pk_name" , "id" )
7373
7474 @classmethod
7575 def get_pk_column (cls ):
7676 return getattr (cls , cls .get_pk_name ())
7777
78+ def get_pk_value (self ):
79+ return getattr (self , self .get_pk_name ())
80+
7881 @declared_attr
7982 def tree_id (cls ):
8083 return Column ("tree_id" , Integer )
@@ -88,10 +91,7 @@ def parent_id(cls):
8891 return Column (
8992 "parent_id" ,
9093 pk .type ,
91- ForeignKey (
92- '{}.{}' .format (cls .__tablename__ , pk .name ),
93- ondelete = 'CASCADE'
94- )
94+ ForeignKey ("{}.{}" .format (cls .__tablename__ , pk .name ), ondelete = "CASCADE" ),
9595 )
9696
9797 @declared_attr
@@ -100,12 +100,12 @@ def parent(self):
100100 self ,
101101 order_by = lambda : self .left ,
102102 foreign_keys = [self .parent_id ],
103- remote_side = ' {}.{}' .format (self .__name__ , self .get_pk_name ()),
103+ remote_side = " {}.{}" .format (self .__name__ , self .get_pk_name ()),
104104 backref = backref (
105- ' children' ,
105+ " children" ,
106106 cascade = "all,delete" ,
107- order_by = lambda : (self .tree_id , self .left )
108- )
107+ order_by = lambda : (self .tree_id , self .left ),
108+ ),
109109 )
110110
111111 @declared_attr
@@ -131,12 +131,16 @@ def is_ancestor_of(self, other, inclusive=False):
131131 * :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
132132 """
133133 if inclusive :
134- return (self .tree_id == other .tree_id ) \
135- & (self .left <= other .left ) \
134+ return (
135+ (self .tree_id == other .tree_id )
136+ & (self .left <= other .left )
136137 & (other .right <= self .right )
137- return (self .tree_id == other .tree_id ) \
138- & (self .left < other .left ) \
138+ )
139+ return (
140+ (self .tree_id == other .tree_id )
141+ & (self .left < other .left )
139142 & (other .right < self .right )
143+ )
140144
141145 @hybrid_method
142146 def is_descendant_of (self , other , inclusive = False ):
@@ -198,9 +202,14 @@ def leftsibling_in_level(self):
198202 """ # noqa
199203 table = _get_tree_table (self .__mapper__ )
200204 session = Session .object_session (self )
201- current_lvl_nodes = session .query (table ) \
202- .filter_by (level = self .level ).filter_by (tree_id = self .tree_id ) \
203- .filter (table .c .lft < self .left ).order_by (table .c .lft ).all ()
205+ current_lvl_nodes = (
206+ session .query (table )
207+ .filter_by (level = self .level )
208+ .filter_by (tree_id = self .tree_id )
209+ .filter (table .c .lft < self .left )
210+ .order_by (table .c .lft )
211+ .all ()
212+ )
204213 if current_lvl_nodes :
205214 return current_lvl_nodes [- 1 ]
206215 return None
@@ -212,11 +221,11 @@ def _node_to_dict(cls, node, json, json_fields):
212221 if json :
213222 pk_name = node .get_pk_name ()
214223 # jqTree or jsTree format
215- result = {'id' : getattr (node , pk_name ), ' label' : node .__repr__ ()}
224+ result = {"id" : getattr (node , pk_name ), " label" : node .__repr__ ()}
216225 if json_fields :
217226 result .update (json_fields (node ))
218227 else :
219- result = {' node' : node }
228+ result = {" node" : node }
220229 return result
221230
222231 @classmethod
@@ -230,9 +239,11 @@ def _base_query_obj(self, session=None):
230239
231240 @classmethod
232241 def _base_order (cls , query , order = asc ):
233- return query .order_by (order (cls .tree_id ))\
234- .order_by (order (cls .level ))\
242+ return (
243+ query .order_by (order (cls .tree_id ))
244+ .order_by (order (cls .level ))
235245 .order_by (order (cls .left ))
246+ )
236247
237248 @classmethod
238249 def get_tree (cls , session = None , json = False , json_fields = None , query = None ):
@@ -284,10 +295,10 @@ def get_node_id(node):
284295 # Find parent in the tree
285296 if parent_id not in nodes_of_level .keys ():
286297 continue
287- if ' children' not in nodes_of_level [parent_id ]:
288- nodes_of_level [parent_id ][' children' ] = []
298+ if " children" not in nodes_of_level [parent_id ]:
299+ nodes_of_level [parent_id ][" children" ] = []
289300 # Append node to parent
290- nl = nodes_of_level [parent_id ][' children' ]
301+ nl = nodes_of_level [parent_id ][" children" ]
291302 nl .append (result )
292303 nodes_of_level [get_node_id (node )] = nl [- 1 ]
293304 else : # for top level nodes
@@ -330,10 +341,7 @@ def drilldown_tree(self, session=None, json=False, json_fields=None):
330341 if not session :
331342 session = object_session (self )
332343 return self .get_tree (
333- session ,
334- json = json ,
335- json_fields = json_fields ,
336- query = self ._drilldown_query
344+ session , json = json , json_fields = json_fields , query = self ._drilldown_query
337345 )
338346
339347 def path_to_root (self , session = None , order = desc ):
@@ -366,6 +374,82 @@ def path_to_root(self, session=None, order=desc):
366374 query = query .filter (table .is_ancestor_of (self , inclusive = True ))
367375 return self ._base_order (query , order = order )
368376
377+ def get_siblings (self , include_self = False , session = None ):
378+ """
379+ https://github.com/uralbash/sqlalchemy_mptt/issues/64
380+ https://django-mptt.readthedocs.io/en/latest/models.html#get-siblings-include-self-false
381+
382+ Creates a query containing siblings of this model
383+ instance. Root nodes are considered to be siblings of other root
384+ nodes.
385+
386+ For example:
387+
388+ node10.get_siblings() -> [Node(8)]
389+
390+ Only one node is sibling of node10
391+
392+ .. code::
393+
394+ level Nested sets example
395+
396+ 1 1(1)22
397+ ______________|____________________
398+ | | |
399+ | | |
400+ 2 2(2)5 6(4)11 12(7)21
401+ | ^ / \
402+ 3 3(3)4 7(5)8 9(6)10 / \
403+ 13(8)16 17(10)20
404+ | |
405+ 4 14(9)15 18(11)19
406+
407+
408+ """
409+ table = self .__class__
410+ query = self ._base_query_obj (session = session )
411+ if self .parent_id :
412+ query = query .filter (table .parent_id == self .parent_id )
413+ else :
414+ query = query .filter (table .parent_id == None )
415+ if not include_self :
416+ query = query .filter (self .get_pk_column () != self .get_pk_value ())
417+ return query
418+
419+ def get_children (self , session = None ):
420+ """
421+ https://github.com/uralbash/sqlalchemy_mptt/issues/64
422+ https://github.com/django-mptt/django-mptt/blob/fd76a816e05feb5fb0fc23126d33e514460a0ead/mptt/models.py#L563
423+
424+ Returns a query containing the immediate children of this
425+ model instance, in tree order.
426+
427+ For example:
428+
429+ node7.get_children() -> [Node(8), Node(10)]
430+
431+ .. code::
432+
433+ level Nested sets example
434+
435+ 1 1(1)22
436+ ______________|____________________
437+ | | |
438+ | | |
439+ 2 2(2)5 6(4)11 12(7)21
440+ | ^ / \
441+ 3 3(3)4 7(5)8 9(6)10 / \
442+ 13(8)16 17(10)20
443+ | |
444+ 4 14(9)15 18(11)19
445+
446+
447+ """
448+ table = self .__class__
449+ query = self ._base_query_obj (session = session )
450+ query = query .filter (table .parent_id == self .get_pk_value ())
451+ return query
452+
369453 @classmethod
370454 def rebuild_tree (cls , session , tree_id ):
371455 """ This method rebuid tree.
@@ -378,10 +462,15 @@ def rebuild_tree(cls, session, tree_id):
378462
379463 * :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_rebuild`
380464 """
381- session .query (cls ).filter_by (tree_id = tree_id )\
382- .update ({cls .left : 0 , cls .right : 0 , cls .level : 0 })
383- top = session .query (cls ).filter_by (parent_id = None )\
384- .filter_by (tree_id = tree_id ).one ()
465+ session .query (cls ).filter_by (tree_id = tree_id ).update (
466+ {cls .left : 0 , cls .right : 0 , cls .level : 0 }
467+ )
468+ top = (
469+ session .query (cls )
470+ .filter_by (parent_id = None )
471+ .filter_by (tree_id = tree_id )
472+ .one ()
473+ )
385474 top .left = left = 1
386475 top .right = right = 2
387476 top .level = level = cls .get_default_level ()
0 commit comments