Skip to content

Commit 0041286

Browse files
committed
0.2.5 version! close #64
1 parent 2971c9f commit 0041286

File tree

5 files changed

+237
-67
lines changed

5 files changed

+237
-67
lines changed

CHANGES.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
0.2.5 (2019-07-23)
2+
==================
3+
4+
see issue #64
5+
6+
- Added similar `django_mptt` methods `get_siblings` and `get_children`
7+
8+
19
0.2.4 (2018-12-14)
210
==================
311

sqlalchemy_mptt/mixins.py

Lines changed: 129 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -50,31 +50,34 @@ def __repr__(self):
5050
@declared_attr
5151
def __table_args__(cls):
5252
return (
53-
Index('%s_lft_idx' % cls.__tablename__, cls.left.name),
54-
Index('%s_rgt_idx' % cls.__tablename__, cls.right.name),
55-
Index('%s_level_idx' % cls.__tablename__, cls.level.name),
53+
Index("%s_lft_idx" % cls.__tablename__, cls.left.name),
54+
Index("%s_rgt_idx" % cls.__tablename__, cls.right.name),
55+
Index("%s_level_idx" % cls.__tablename__, cls.level.name),
5656
)
5757

5858
@classmethod
5959
def __declare_first__(cls):
6060
cls.__mapper__.batch = False
6161

62-
@classmethod
63-
def get_pk_name(cls):
64-
return getattr(cls, 'sqlalchemy_mptt_pk_name', 'id')
65-
6662
@classmethod
6763
def get_default_level(cls):
68-
'''
64+
"""
6965
Compatibility with Django MPTT: level value for root node.
7066
See https://github.com/uralbash/sqlalchemy_mptt/issues/56
71-
'''
72-
return getattr(cls, 'sqlalchemy_mptt_default_level', 1)
67+
"""
68+
return getattr(cls, "sqlalchemy_mptt_default_level", 1)
69+
70+
@classmethod
71+
def get_pk_name(cls):
72+
return getattr(cls, "sqlalchemy_mptt_pk_name", "id")
7373

7474
@classmethod
7575
def get_pk_column(cls):
7676
return getattr(cls, cls.get_pk_name())
7777

78+
def get_pk_value(self):
79+
return getattr(self, self.get_pk_name())
80+
7881
@declared_attr
7982
def tree_id(cls):
8083
return Column("tree_id", Integer)
@@ -88,10 +91,7 @@ def parent_id(cls):
8891
return Column(
8992
"parent_id",
9093
pk.type,
91-
ForeignKey(
92-
'{}.{}'.format(cls.__tablename__, pk.name),
93-
ondelete='CASCADE'
94-
)
94+
ForeignKey("{}.{}".format(cls.__tablename__, pk.name), ondelete="CASCADE"),
9595
)
9696

9797
@declared_attr
@@ -100,12 +100,12 @@ def parent(self):
100100
self,
101101
order_by=lambda: self.left,
102102
foreign_keys=[self.parent_id],
103-
remote_side='{}.{}'.format(self.__name__, self.get_pk_name()),
103+
remote_side="{}.{}".format(self.__name__, self.get_pk_name()),
104104
backref=backref(
105-
'children',
105+
"children",
106106
cascade="all,delete",
107-
order_by=lambda: (self.tree_id, self.left)
108-
)
107+
order_by=lambda: (self.tree_id, self.left),
108+
),
109109
)
110110

111111
@declared_attr
@@ -131,12 +131,16 @@ def is_ancestor_of(self, other, inclusive=False):
131131
* :mod:`sqlalchemy_mptt.tests.cases.integrity.test_hierarchy_structure`
132132
"""
133133
if inclusive:
134-
return (self.tree_id == other.tree_id) \
135-
& (self.left <= other.left) \
134+
return (
135+
(self.tree_id == other.tree_id)
136+
& (self.left <= other.left)
136137
& (other.right <= self.right)
137-
return (self.tree_id == other.tree_id) \
138-
& (self.left < other.left) \
138+
)
139+
return (
140+
(self.tree_id == other.tree_id)
141+
& (self.left < other.left)
139142
& (other.right < self.right)
143+
)
140144

141145
@hybrid_method
142146
def is_descendant_of(self, other, inclusive=False):
@@ -198,9 +202,14 @@ def leftsibling_in_level(self):
198202
""" # noqa
199203
table = _get_tree_table(self.__mapper__)
200204
session = Session.object_session(self)
201-
current_lvl_nodes = session.query(table) \
202-
.filter_by(level=self.level).filter_by(tree_id=self.tree_id) \
203-
.filter(table.c.lft < self.left).order_by(table.c.lft).all()
205+
current_lvl_nodes = (
206+
session.query(table)
207+
.filter_by(level=self.level)
208+
.filter_by(tree_id=self.tree_id)
209+
.filter(table.c.lft < self.left)
210+
.order_by(table.c.lft)
211+
.all()
212+
)
204213
if current_lvl_nodes:
205214
return current_lvl_nodes[-1]
206215
return None
@@ -212,11 +221,11 @@ def _node_to_dict(cls, node, json, json_fields):
212221
if json:
213222
pk_name = node.get_pk_name()
214223
# jqTree or jsTree format
215-
result = {'id': getattr(node, pk_name), 'label': node.__repr__()}
224+
result = {"id": getattr(node, pk_name), "label": node.__repr__()}
216225
if json_fields:
217226
result.update(json_fields(node))
218227
else:
219-
result = {'node': node}
228+
result = {"node": node}
220229
return result
221230

222231
@classmethod
@@ -230,9 +239,11 @@ def _base_query_obj(self, session=None):
230239

231240
@classmethod
232241
def _base_order(cls, query, order=asc):
233-
return query.order_by(order(cls.tree_id))\
234-
.order_by(order(cls.level))\
242+
return (
243+
query.order_by(order(cls.tree_id))
244+
.order_by(order(cls.level))
235245
.order_by(order(cls.left))
246+
)
236247

237248
@classmethod
238249
def get_tree(cls, session=None, json=False, json_fields=None, query=None):
@@ -284,10 +295,10 @@ def get_node_id(node):
284295
# Find parent in the tree
285296
if parent_id not in nodes_of_level.keys():
286297
continue
287-
if 'children' not in nodes_of_level[parent_id]:
288-
nodes_of_level[parent_id]['children'] = []
298+
if "children" not in nodes_of_level[parent_id]:
299+
nodes_of_level[parent_id]["children"] = []
289300
# Append node to parent
290-
nl = nodes_of_level[parent_id]['children']
301+
nl = nodes_of_level[parent_id]["children"]
291302
nl.append(result)
292303
nodes_of_level[get_node_id(node)] = nl[-1]
293304
else: # for top level nodes
@@ -330,10 +341,7 @@ def drilldown_tree(self, session=None, json=False, json_fields=None):
330341
if not session:
331342
session = object_session(self)
332343
return self.get_tree(
333-
session,
334-
json=json,
335-
json_fields=json_fields,
336-
query=self._drilldown_query
344+
session, json=json, json_fields=json_fields, query=self._drilldown_query
337345
)
338346

339347
def path_to_root(self, session=None, order=desc):
@@ -366,6 +374,82 @@ def path_to_root(self, session=None, order=desc):
366374
query = query.filter(table.is_ancestor_of(self, inclusive=True))
367375
return self._base_order(query, order=order)
368376

377+
def get_siblings(self, include_self=False, session=None):
378+
"""
379+
https://github.com/uralbash/sqlalchemy_mptt/issues/64
380+
https://django-mptt.readthedocs.io/en/latest/models.html#get-siblings-include-self-false
381+
382+
Creates a query containing siblings of this model
383+
instance. Root nodes are considered to be siblings of other root
384+
nodes.
385+
386+
For example:
387+
388+
node10.get_siblings() -> [Node(8)]
389+
390+
Only one node is sibling of node10
391+
392+
.. code::
393+
394+
level Nested sets example
395+
396+
1 1(1)22
397+
______________|____________________
398+
| | |
399+
| | |
400+
2 2(2)5 6(4)11 12(7)21
401+
| ^ / \
402+
3 3(3)4 7(5)8 9(6)10 / \
403+
13(8)16 17(10)20
404+
| |
405+
4 14(9)15 18(11)19
406+
407+
408+
"""
409+
table = self.__class__
410+
query = self._base_query_obj(session=session)
411+
if self.parent_id:
412+
query = query.filter(table.parent_id == self.parent_id)
413+
else:
414+
query = query.filter(table.parent_id == None)
415+
if not include_self:
416+
query = query.filter(self.get_pk_column() != self.get_pk_value())
417+
return query
418+
419+
def get_children(self, session=None):
420+
"""
421+
https://github.com/uralbash/sqlalchemy_mptt/issues/64
422+
https://github.com/django-mptt/django-mptt/blob/fd76a816e05feb5fb0fc23126d33e514460a0ead/mptt/models.py#L563
423+
424+
Returns a query containing the immediate children of this
425+
model instance, in tree order.
426+
427+
For example:
428+
429+
node7.get_children() -> [Node(8), Node(10)]
430+
431+
.. code::
432+
433+
level Nested sets example
434+
435+
1 1(1)22
436+
______________|____________________
437+
| | |
438+
| | |
439+
2 2(2)5 6(4)11 12(7)21
440+
| ^ / \
441+
3 3(3)4 7(5)8 9(6)10 / \
442+
13(8)16 17(10)20
443+
| |
444+
4 14(9)15 18(11)19
445+
446+
447+
"""
448+
table = self.__class__
449+
query = self._base_query_obj(session=session)
450+
query = query.filter(table.parent_id == self.get_pk_value())
451+
return query
452+
369453
@classmethod
370454
def rebuild_tree(cls, session, tree_id):
371455
""" This method rebuid tree.
@@ -378,10 +462,15 @@ def rebuild_tree(cls, session, tree_id):
378462
379463
* :mod:`sqlalchemy_mptt.tests.cases.get_tree.test_rebuild`
380464
"""
381-
session.query(cls).filter_by(tree_id=tree_id)\
382-
.update({cls.left: 0, cls.right: 0, cls.level: 0})
383-
top = session.query(cls).filter_by(parent_id=None)\
384-
.filter_by(tree_id=tree_id).one()
465+
session.query(cls).filter_by(tree_id=tree_id).update(
466+
{cls.left: 0, cls.right: 0, cls.level: 0}
467+
)
468+
top = (
469+
session.query(cls)
470+
.filter_by(parent_id=None)
471+
.filter_by(tree_id=tree_id)
472+
.one()
473+
)
385474
top.left = left = 1
386475
top.right = right = 2
387476
top.level = level = cls.get_default_level()

sqlalchemy_mptt/tests/__init__.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@
4343

4444
# local
4545
from .cases.get_tree import Tree
46+
from .cases.get_node import GetNodes
4647
from .cases.edit_node import Changes
4748
from .cases.integrity import DataIntegrity
4849
from .cases.move_node import MoveAfter, MoveBefore, MoveInside
4950
from .cases.initialize import Initialize
5051

5152

5253
class Fixtures(object):
53-
5454
def __init__(self, session):
5555
self.session = session
5656

@@ -59,22 +59,22 @@ def add(self, model, fixtures):
5959
file = open(os.path.join(here, fixtures))
6060
fixtures = json.loads(file.read())
6161
for fixture in fixtures:
62-
if hasattr(model, 'sqlalchemy_mptt_pk_name'):
63-
fixture[model.sqlalchemy_mptt_pk_name] = fixture.pop('id')
62+
if hasattr(model, "sqlalchemy_mptt_pk_name"):
63+
fixture[model.sqlalchemy_mptt_pk_name] = fixture.pop("id")
6464
self.session.add(model(**fixture))
6565
self.session.flush()
6666

6767

6868
class TreeTestingMixin(
69-
Initialize,
70-
Changes,
71-
MoveAfter,
72-
DataIntegrity,
73-
MoveBefore,
74-
MoveInside,
75-
Tree
69+
Initialize,
70+
Changes,
71+
MoveAfter,
72+
DataIntegrity,
73+
MoveBefore,
74+
MoveInside,
75+
Tree,
76+
GetNodes,
7677
):
77-
7878
base = None
7979
model = None
8080

@@ -84,27 +84,22 @@ def catch_queries(self, conn, cursor, statement, *args):
8484
def start_query_counter(self):
8585
self.stmts = []
8686
event.listen(
87-
self.session.bind.engine,
88-
"before_cursor_execute",
89-
self.catch_queries
87+
self.session.bind.engine, "before_cursor_execute", self.catch_queries
9088
)
9189

9290
def stop_query_counter(self):
9391
event.remove(
94-
self.session.bind.engine,
95-
"before_cursor_execute",
96-
self.catch_queries
92+
self.session.bind.engine, "before_cursor_execute", self.catch_queries
9793
)
9894

9995
def setUp(self):
100-
self.engine = create_engine('sqlite:///:memory:')
96+
self.engine = create_engine("sqlite:///:memory:")
10197
Session = mptt_sessionmaker(sessionmaker(bind=self.engine))
10298
self.session = Session()
10399
self.base.metadata.create_all(self.engine)
104100
self.fixture = Fixtures(self.session)
105101
self.fixture.add(
106-
self.model,
107-
os.path.join('fixtures', getattr(self, 'fixtures', 'tree.json'))
102+
self.model, os.path.join("fixtures", getattr(self, "fixtures", "tree.json"))
108103
)
109104

110105
self.result = self.session.query(
@@ -113,7 +108,7 @@ def setUp(self):
113108
self.model.right,
114109
self.model.level,
115110
self.model.parent_id,
116-
self.model.tree_id
111+
self.model.tree_id,
117112
)
118113

119114
def tearDown(self):
@@ -123,11 +118,15 @@ def test_session_expire_for_move_after_to_new_tree(self):
123118
"""
124119
https://github.com/uralbash/sqlalchemy_mptt/issues/33
125120
"""
126-
node = self.session.query(self.model) \
127-
.filter(self.model.get_pk_column() == 4).one()
128-
children = self.session.query(self.model) \
129-
.filter(self.model.get_pk_column().in_((5, 6))).all()
130-
node.move_after('1')
121+
node = (
122+
self.session.query(self.model).filter(self.model.get_pk_column() == 4).one()
123+
)
124+
children = (
125+
self.session.query(self.model)
126+
.filter(self.model.get_pk_column().in_((5, 6)))
127+
.all()
128+
)
129+
node.move_after("1")
131130
self.session.flush()
132131

133132
_level = node.get_default_level()

0 commit comments

Comments
 (0)