Skip to content

Commit 8a4b93d

Browse files
author
matmoncon
committed
feat: allow model registration using directory path
1 parent 1636a32 commit 8a4b93d

File tree

10 files changed

+83
-44
lines changed

10 files changed

+83
-44
lines changed

.pylintrc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ unsafe-load-any-extension=no
77

88
[MESSAGES CONTROL]
99
disable=abstract-method,
10+
inconsistent-quotes,
1011
protected-access,
1112
broad-exception-caught,
1213
ungrouped-imports,

pyneo4j_ogm/core/client.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Pyneo4j database client class for running operations on the database.
33
"""
44

5+
import importlib.util
6+
import inspect
57
import os
68
from enum import Enum
79
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
@@ -14,7 +16,6 @@
1416
from pyneo4j_ogm.core.node import NodeModel
1517
from pyneo4j_ogm.core.relationship import RelationshipModel
1618
from pyneo4j_ogm.exceptions import (
17-
AlreadyRegistered,
1819
InvalidBookmark,
1920
InvalidEntityType,
2021
InvalidLabelOrType,
@@ -23,7 +24,6 @@
2324
TransactionInProgress,
2425
UnsupportedNeo4jVersion,
2526
)
26-
from pyneo4j_ogm.fields.settings import NodeModelSettings, RelationshipModelSettings
2727
from pyneo4j_ogm.logger import logger
2828
from pyneo4j_ogm.pydantic_utils import get_field_type, get_model_fields
2929
from pyneo4j_ogm.queries.query_builder import QueryBuilder
@@ -138,6 +138,41 @@ async def connect(
138138
logger.info("Connected to database")
139139
return self
140140

141+
@ensure_connection
142+
async def register_models_dir(self, dir_path: str) -> None:
143+
"""
144+
Registers all models in a directory and all subdirectories.
145+
"""
146+
logger.info("Registering models in directory %s", dir_path)
147+
for root, _, files in os.walk(dir_path):
148+
# Check all files for models
149+
logger.debug("Checking %s files for models", len(files))
150+
for file in files:
151+
if not file.endswith(".py"):
152+
continue
153+
154+
filepath = os.path.join(root, file)
155+
156+
# Load the module
157+
logger.debug("Found file %s, importing", filepath)
158+
module_name = os.path.splitext(os.path.basename(filepath))[0]
159+
spec = importlib.util.spec_from_file_location(module_name, filepath)
160+
161+
if spec is None or spec.loader is None:
162+
raise ImportError(f"Could not import migration file {filepath}")
163+
164+
module = importlib.util.module_from_spec(spec)
165+
spec.loader.exec_module(module)
166+
167+
for member in inspect.getmembers(
168+
module,
169+
lambda x: inspect.isclass(x)
170+
and issubclass(x, (NodeModel, RelationshipModel))
171+
and x is not NodeModel
172+
and x is not RelationshipModel,
173+
):
174+
self.models.add(member[1])
175+
141176
@ensure_connection
142177
async def register_models(self, models: List[Type[Union[NodeModel, RelationshipModel]]]) -> None:
143178
"""
@@ -153,23 +188,6 @@ async def register_models(self, models: List[Type[Union[NodeModel, RelationshipM
153188
logger.info("Registering models %s with client %s", models, self)
154189

155190
for model in models:
156-
for registered_model in self.models:
157-
registered_model_settings = registered_model.model_settings()
158-
model_settings = model.model_settings()
159-
160-
if (
161-
isinstance(registered_model_settings, RelationshipModelSettings)
162-
and isinstance(model_settings, RelationshipModelSettings)
163-
and registered_model_settings.type == model_settings.type
164-
):
165-
raise AlreadyRegistered(cast(str, model_settings.type))
166-
elif (
167-
isinstance(registered_model_settings, NodeModelSettings)
168-
and isinstance(model_settings, NodeModelSettings)
169-
and set(registered_model_settings.labels) == set(model_settings.labels)
170-
):
171-
raise AlreadyRegistered(cast(str, model_settings.labels))
172-
173191
if issubclass(model, (NodeModel, RelationshipModel)):
174192
logger.debug("Found valid mode %s, registering with client", model.__name__)
175193

pyneo4j_ogm/exceptions.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -215,14 +215,3 @@ class ListItemNotEncodable(Pyneo4jException):
215215

216216
def __init__(self, *args: object) -> None:
217217
super().__init__("List item is not JSON encodable and can not be stored inside the database", *args)
218-
219-
220-
class AlreadyRegistered(Pyneo4jException):
221-
"""
222-
Multiple models are using the same labels/type as a already registered model.
223-
"""
224-
225-
def __init__(self, labels_or_type: Union[Set[str], str], *args: object) -> None:
226-
received = f"[{', '.join(labels_or_type)}]" if isinstance(labels_or_type, set) else f"[{labels_or_type}]"
227-
label_or_type = "labels" if isinstance(labels_or_type, set) else "type"
228-
super().__init__(f"A model is using the same {label_or_type} as another model. Got {received}", *args)

tests/core/test_client.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from pyneo4j_ogm.core.node import NodeModel
1515
from pyneo4j_ogm.core.relationship import RelationshipModel
1616
from pyneo4j_ogm.exceptions import (
17-
AlreadyRegistered,
1817
InvalidEntityType,
1918
InvalidLabelOrType,
2019
MissingDatabaseURI,
@@ -25,6 +24,12 @@
2524
from pyneo4j_ogm.fields.property_options import WithOptions
2625
from pyneo4j_ogm.logger import logger
2726
from tests.fixtures.db_setup import client, session
27+
from tests.fixtures.models.models_top import ModelOne, ModelTwo
28+
from tests.fixtures.models.nested.deeply_nested.model_deeply_nested import (
29+
ModelFive,
30+
ModelSix,
31+
)
32+
from tests.fixtures.models.nested.model_nested import ModelFour, ModelThree
2833

2934

3035
class CypherResolvingNode(NodeModel):
@@ -41,19 +46,6 @@ class Settings:
4146
type = "TEST_RELATIONSHIP"
4247

4348

44-
async def test_duplicate_model_register(client: Pyneo4jClient):
45-
class TestNode(NodeModel):
46-
name: str
47-
48-
class Settings:
49-
labels = {"TestNode"}
50-
51-
await client.register_models([TestNode])
52-
53-
with pytest.raises(AlreadyRegistered):
54-
await client.register_models([TestNode])
55-
56-
5749
async def test_batch(client: Pyneo4jClient, session: AsyncSession):
5850
async with client.batch():
5951
await client.cypher("CREATE (n:Node) SET n.name = $name", parameters={"name": "TestName"})
@@ -621,3 +613,8 @@ async def test_drop_indexes(client: Pyneo4jClient, session: AsyncSession):
621613
await query_results.consume()
622614

623615
assert len(results) == 0
616+
617+
618+
async def test_register_models_dir(client: Pyneo4jClient):
619+
await client.register_models_dir("tests/fixtures/models")
620+
assert len(client.models) == 6
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from pyneo4j_ogm.core.node import NodeModel
2+
from pyneo4j_ogm.core.relationship import RelationshipModel
3+
4+
5+
class ModelOne(NodeModel):
6+
pass
7+
8+
9+
class ModelTwo(RelationshipModel):
10+
pass
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from pyneo4j_ogm.core.node import NodeModel
2+
from pyneo4j_ogm.core.relationship import RelationshipModel
3+
4+
5+
class ModelFive(NodeModel):
6+
pass
7+
8+
9+
class ModelSix(RelationshipModel):
10+
pass
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class NotMePleaseButDeeplyNested:
2+
pass
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from pyneo4j_ogm.core.node import NodeModel
2+
from pyneo4j_ogm.core.relationship import RelationshipModel
3+
4+
5+
class ModelThree(NodeModel):
6+
pass
7+
8+
9+
class ModelFour(RelationshipModel):
10+
pass

tests/fixtures/models/nested/no_models.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class NotMePlease:
2+
pass

0 commit comments

Comments
 (0)