Skip to content

Commit 38bbe05

Browse files
author
matmoncon
committed
chore: restructure tests dir
1 parent 81e2d06 commit 38bbe05

File tree

71 files changed

+4088
-4590
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+4088
-4590
lines changed

pyneo4j_ogm/core/base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@
3535
RelationshipPropertyCardinality,
3636
RelationshipPropertyDirection,
3737
)
38-
from pyneo4j_ogm.fields.settings import BaseModelSettings
38+
from pyneo4j_ogm.fields.settings import (
39+
BaseModelSettings,
40+
NodeModelSettings,
41+
RelationshipModelSettings,
42+
)
3943
from pyneo4j_ogm.logger import logger
4044
from pyneo4j_ogm.pydantic_utils import (
4145
IS_PYDANTIC_V2,
@@ -236,7 +240,7 @@ class ModelBase(BaseModel, Generic[V]):
236240
directly and is only used as a base class for `NodeModel` and `RelationshipModel`.
237241
"""
238242

239-
_settings: BaseModelSettings = PrivateAttr()
243+
_settings: Union[NodeModelSettings, RelationshipModelSettings] = PrivateAttr()
240244
_client: Pyneo4jClient = PrivateAttr()
241245
_query_builder: QueryBuilder = PrivateAttr()
242246
_db_properties: Dict[str, Any] = PrivateAttr(default={})

pyneo4j_ogm/core/client.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,22 +153,22 @@ async def register_models(self, models: List[Type[Union[NodeModel, RelationshipM
153153
logger.info("Registering models %s with client %s", models, self)
154154

155155
for model in models:
156-
for registered_model in self.models:
157-
registered_model_settings = registered_model.model_settings()
158-
model_settings = model.model_settings()
159-
160-
if (
161-
isinstance(registered_model_settings, RelationshipModelSettings)
162-
and isinstance(model_settings, RelationshipModelSettings)
163-
and registered_model_settings.type == model_settings.type
164-
):
165-
raise AlreadyRegistered(cast(str, model_settings.type))
166-
elif (
167-
isinstance(registered_model_settings, NodeModelSettings)
168-
and isinstance(model_settings, NodeModelSettings)
169-
and set(registered_model_settings.labels) == set(model_settings.labels)
170-
):
171-
raise AlreadyRegistered(cast(str, model_settings.labels))
156+
# for registered_model in self.models:
157+
# registered_model_settings = registered_model.model_settings()
158+
# model_settings = model.model_settings()
159+
160+
# if (
161+
# isinstance(registered_model_settings, RelationshipModelSettings)
162+
# and isinstance(model_settings, RelationshipModelSettings)
163+
# and registered_model_settings.type == model_settings.type
164+
# ):
165+
# raise AlreadyRegistered(cast(str, model_settings.type))
166+
# elif (
167+
# isinstance(registered_model_settings, NodeModelSettings)
168+
# and isinstance(model_settings, NodeModelSettings)
169+
# and set(registered_model_settings.labels) == set(model_settings.labels)
170+
# ):
171+
# raise AlreadyRegistered(cast(str, model_settings.labels))
172172

173173
if issubclass(model, (NodeModel, RelationshipModel)):
174174
logger.debug("Found valid mode %s, registering with client", model.__name__)

pyneo4j_ogm/core/node.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ class NodeModel(ModelBase[NodeModelSettings]):
100100
model and are defined in `pyneo4j_ogm.fields.settings.NodeModelSettings`.
101101
"""
102102

103-
_settings: NodeModelSettings
103+
_settings: NodeModelSettings = PrivateAttr()
104104
_relationship_properties: Set[str] = PrivateAttr()
105105
Settings: ClassVar[Type[NodeModelSettings]]
106106

@@ -196,6 +196,8 @@ async def create(self: T) -> T:
196196
if len(results) == 0 or len(results[0]) == 0 or results[0][0] is None:
197197
raise UnexpectedEmptyResult()
198198

199+
# Since the instance is now hydrated, we can set the element id and id and reset the modified properties
200+
# to the current instance values
199201
logger.debug("Hydrating instance values")
200202
setattr(self, "_element_id", getattr(cast(T, results[0][0]), "_element_id"))
201203
setattr(self, "_id", getattr(cast(T, results[0][0]), "_id"))
@@ -230,6 +232,8 @@ async def update(self) -> None:
230232
]
231233
)
232234

235+
# We return the updated node to check if the query was successful
236+
# since Neo4j does not raise any exceptions if the node does not exist
233237
results, _ = await self._client.cypher(
234238
query=f"""
235239
MATCH {self._query_builder.node_match(list(self._settings.labels))}
@@ -269,6 +273,7 @@ async def delete(self) -> None:
269273
parameters={"element_id": self._element_id},
270274
)
271275

276+
# If the returned value is empty, the node does not exist and the query failed
272277
logger.debug("Checking if query returned a result")
273278
if len(results) == 0 or len(results[0]) == 0 or results[0][0] is None:
274279
raise UnexpectedEmptyResult()
@@ -296,6 +301,8 @@ async def refresh(self) -> None:
296301
parameters={"element_id": self._element_id},
297302
)
298303

304+
# If the returned value is empty, we can not refresh the instance
305+
# since the node does not exist anymore
299306
logger.debug("Checking if query returned a result")
300307
if len(results) == 0 or len(results[0]) == 0 or results[0][0] is None:
301308
raise UnexpectedEmptyResult()
@@ -520,6 +527,7 @@ async def find_one(
520527
)
521528

522529
if do_auto_fetch:
530+
# If auto-fetch is enabled, we need to build the auto-fetch queries in addition to the normal query
523531
logger.debug("Querying database with auto-fetch enabled")
524532
projection_query = (
525533
"RETURN n" if cls._query_builder.query["projections"] == "" else cls._query_builder.query["projections"]
@@ -562,10 +570,12 @@ async def find_one(
562570
or results[0][0] is None
563571
or (isinstance(results[0][0], dict) and len(results[0][0]) == 0)
564572
):
573+
# If no results are found, we return None or raise an exception
565574
if raise_on_empty:
566575
raise NoResultFound(filters)
567576
return None
568577

578+
# Normalize results to a single instance
569579
if isinstance(results[0][0], Node):
570580
instance = cls._inflate(graph_entity=results[0][0])
571581
elif isinstance(results[0][0], list):

pyneo4j_ogm/core/relationship.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
relationships. It provides base functionality like de-/inflation and validation and methods for interacting with
44
the database for CRUD operations on relationships.
55
"""
6+
67
import json
78
import re
89
from functools import wraps
@@ -83,7 +84,7 @@ class RelationshipModel(ModelBase[RelationshipModelSettings]):
8384
`pyneo4j_ogm.fields.settings.RelationshipModelSettings`.
8485
"""
8586

86-
_settings: RelationshipModelSettings
87+
_settings: RelationshipModelSettings = PrivateAttr()
8788
_start_node_element_id: Optional[str] = PrivateAttr(default=None)
8889
_start_node_id: Optional[int] = PrivateAttr(default=None)
8990
_end_node_element_id: Optional[str] = PrivateAttr(default=None)
@@ -427,6 +428,7 @@ async def find_many(
427428
parameters=cls._query_builder.parameters,
428429
)
429430

431+
# Normalize results to instance classes
430432
logger.debug("Building instances from query results")
431433
for result_list in results:
432434
for result in result_list:

pyneo4j_ogm/migrations/actions/down.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ async def down(down_count: RunMigrationCount = "all", config_path: Optional[str]
3535

3636
logger.debug("Filtering migration files for applied migrations")
3737
applied_migration_identifiers = migration_node.get_applied_migration_identifiers
38+
# Remove all migration files that have not been applied
3839
for identifier in deepcopy(migration_files).keys():
3940
if identifier not in applied_migration_identifiers:
4041
migration_files.pop(identifier, None)
@@ -43,6 +44,8 @@ async def down(down_count: RunMigrationCount = "all", config_path: Optional[str]
4344
if down_count != "all" and count >= down_count:
4445
break
4546

47+
# We can get the current migration by getting the max identifier, which is a
48+
# UNIX timestamp meaning the highest value is the most recent migration
4649
current_migration_identifier = max(migration_files.keys())
4750
current_migration = migration_files[current_migration_identifier]
4851

pyneo4j_ogm/migrations/actions/up.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ async def up(up_count: RunMigrationCount = "all", config_path: Optional[str] = N
4141
if up_count != "all" and count >= up_count:
4242
break
4343

44+
# Since the migration files are sorted by identifier, we can get the current migration
45+
# by getting the min identifier, which is a UNIX timestamp meaning the lowest value is the oldest migration
4446
current_migration_identifier = min(migration_files.keys())
4547
current_migration = migration_files[current_migration_identifier]
4648

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,168 @@
22
# pyright: reportGeneralTypeIssues=false
33

44
import json
5-
from platform import node
5+
from typing import Union, cast
6+
from unittest.mock import AsyncMock, MagicMock
67

78
import pytest
89

9-
from pyneo4j_ogm.core.base import ModelBase
10+
from pyneo4j_ogm.core.base import ModelBase, hooks
1011
from pyneo4j_ogm.core.node import NodeModel
1112
from pyneo4j_ogm.core.relationship import RelationshipModel
1213
from pyneo4j_ogm.exceptions import ListItemNotEncodable, UnregisteredModel
14+
from pyneo4j_ogm.fields.settings import BaseModelSettings
1315
from pyneo4j_ogm.pydantic_utils import get_model_dump, get_model_dump_json
16+
from tests.fixtures.db_setup import Developer
17+
18+
19+
def hook_func():
20+
pass
21+
22+
23+
def test_pre_hooks():
24+
Developer.register_pre_hooks("test_hook", lambda: None)
25+
assert len(Developer._settings.pre_hooks["test_hook"]) == 1
26+
assert all(callable(func) for func in Developer._settings.pre_hooks["test_hook"])
27+
Developer._settings.pre_hooks["test_hook"] = []
28+
29+
Developer.register_pre_hooks("test_hook", [lambda: None, lambda: None])
30+
assert len(Developer._settings.pre_hooks["test_hook"]) == 2
31+
assert all(callable(func) for func in Developer._settings.pre_hooks["test_hook"])
32+
Developer._settings.pre_hooks["test_hook"] = []
33+
34+
Developer.register_pre_hooks("test_hook", [lambda: None, "invalid"]) # type: ignore
35+
assert len(Developer._settings.pre_hooks["test_hook"]) == 1
36+
assert all(callable(func) for func in Developer._settings.pre_hooks["test_hook"])
37+
Developer._settings.pre_hooks["test_hook"] = []
38+
39+
Developer.register_pre_hooks("test_hook", lambda: None)
40+
Developer.register_pre_hooks("test_hook", lambda: None, overwrite=True)
41+
assert len(Developer._settings.pre_hooks["test_hook"]) == 1
42+
assert all(callable(func) for func in Developer._settings.pre_hooks["test_hook"])
43+
Developer._settings.pre_hooks["test_hook"] = []
44+
45+
Developer.register_pre_hooks("test_hook", lambda: None)
46+
Developer.register_pre_hooks("test_hook", lambda: None)
47+
assert len(Developer._settings.pre_hooks["test_hook"]) == 2
48+
assert all(callable(func) for func in Developer._settings.pre_hooks["test_hook"])
49+
Developer._settings.pre_hooks["test_hook"] = []
50+
51+
52+
def test_post_hooks():
53+
Developer.register_post_hooks("test_hook", lambda: None)
54+
assert len(Developer._settings.post_hooks["test_hook"]) == 1
55+
assert all(callable(func) for func in Developer._settings.post_hooks["test_hook"])
56+
Developer._settings.post_hooks["test_hook"] = []
57+
58+
Developer.register_post_hooks("test_hook", [lambda: None, lambda: None])
59+
assert len(Developer._settings.post_hooks["test_hook"]) == 2
60+
assert all(callable(func) for func in Developer._settings.post_hooks["test_hook"])
61+
Developer._settings.post_hooks["test_hook"] = []
62+
63+
Developer.register_post_hooks("test_hook", [lambda: None, "invalid"]) # type: ignore
64+
assert len(Developer._settings.post_hooks["test_hook"]) == 1
65+
assert all(callable(func) for func in Developer._settings.post_hooks["test_hook"])
66+
Developer._settings.post_hooks["test_hook"] = []
67+
68+
Developer.register_post_hooks("test_hook", lambda: None)
69+
Developer.register_post_hooks("test_hook", lambda: None, overwrite=True)
70+
assert len(Developer._settings.post_hooks["test_hook"]) == 1
71+
assert all(callable(func) for func in Developer._settings.post_hooks["test_hook"])
72+
Developer._settings.post_hooks["test_hook"] = []
73+
74+
Developer.register_post_hooks("test_hook", lambda: None)
75+
Developer.register_post_hooks("test_hook", lambda: None)
76+
assert len(Developer._settings.post_hooks["test_hook"]) == 2
77+
assert all(callable(func) for func in Developer._settings.post_hooks["test_hook"])
78+
Developer._settings.post_hooks["test_hook"] = []
79+
80+
81+
def test_model_settings():
82+
class ModelSettingsTest(NodeModel):
83+
pass
84+
85+
class Settings:
86+
pre_hooks = {"test_hook": [hook_func]}
87+
post_hooks = {"test_hook": [hook_func, hook_func]}
88+
89+
assert ModelSettingsTest.model_settings().pre_hooks == {"test_hook": [hook_func]}
90+
assert ModelSettingsTest.model_settings().post_hooks == {"test_hook": [hook_func, hook_func]}
91+
92+
93+
def test_node_model_modified_properties():
94+
class ModifiedPropertiesTest(NodeModel):
95+
a: str = "a"
96+
b: int = 1
97+
c: bool = True
98+
99+
setattr(ModifiedPropertiesTest, "_client", None)
100+
101+
model = ModifiedPropertiesTest()
102+
model.a = "modified"
103+
assert model.modified_properties == {"a"}
104+
105+
model.b = 2
106+
assert model.modified_properties == {"a", "b"}
107+
108+
109+
def test_relationship_model_modified_properties():
110+
class ModifiedPropertiesTest(RelationshipModel):
111+
a: str = "a"
112+
b: int = 1
113+
c: bool = True
114+
115+
setattr(ModifiedPropertiesTest, "_client", None)
116+
117+
model = ModifiedPropertiesTest()
118+
model.a = "modified"
119+
assert model.modified_properties == {"a"}
120+
121+
model.b = 2
122+
assert model.modified_properties == {"a", "b"}
123+
124+
125+
async def test_hooks_decorator():
126+
class TestClass:
127+
def __init__(self):
128+
self._client = None # type: ignore
129+
self._settings = BaseModelSettings()
130+
self._settings.pre_hooks["async_test_func"] = [MagicMock(__name__="MagicMock"), AsyncMock(), AsyncMock()]
131+
self._settings.post_hooks["async_test_func"] = [MagicMock(__name__="MagicMock"), AsyncMock(), AsyncMock()]
132+
self._settings.pre_hooks["sync_test_func"] = [
133+
MagicMock(__name__="MagicMock"),
134+
MagicMock(__name__="MagicMock"),
135+
AsyncMock(),
136+
]
137+
self._settings.post_hooks["sync_test_func"] = [
138+
MagicMock(__name__="MagicMock"),
139+
MagicMock(__name__="MagicMock"),
140+
AsyncMock(),
141+
]
142+
143+
@hooks
144+
async def async_test_func(self):
145+
pass
146+
147+
@hooks
148+
def sync_test_func(self):
149+
pass
150+
151+
test_instance = TestClass()
152+
await test_instance.async_test_func()
153+
154+
for hook_function in test_instance._settings.pre_hooks["async_test_func"]:
155+
cast(Union[MagicMock, AsyncMock], hook_function).assert_called_once_with(test_instance)
156+
157+
for hook_function in test_instance._settings.post_hooks["async_test_func"]:
158+
cast(Union[MagicMock, AsyncMock], hook_function).assert_called_once_with(test_instance, None)
159+
160+
test_instance.sync_test_func()
161+
162+
for hook_function in test_instance._settings.pre_hooks["sync_test_func"]:
163+
cast(Union[MagicMock, AsyncMock], hook_function).assert_called_once_with(test_instance)
164+
165+
for hook_function in test_instance._settings.post_hooks["sync_test_func"]:
166+
cast(Union[MagicMock, AsyncMock], hook_function).assert_called_once_with(test_instance, None)
14167

15168

16169
def test_unregistered_model_exc():

0 commit comments

Comments
 (0)