Skip to content

Commit 50c4c04

Browse files
author
matmoncon
committed
fix: prepare models correctly when using directory
1 parent 4c731c4 commit 50c4c04

File tree

1 file changed

+53
-42
lines changed

1 file changed

+53
-42
lines changed

pyneo4j_ogm/core/client.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ async def register_models_dir(self, dir_path: str) -> None:
173173
):
174174
self.models.add(member[1])
175175

176+
await self._prepare_registered_models()
177+
176178
@ensure_connection
177179
async def register_models(self, models: List[Type[Union[NodeModel, RelationshipModel]]]) -> None:
178180
"""
@@ -193,49 +195,8 @@ async def register_models(self, models: List[Type[Union[NodeModel, RelationshipM
193195

194196
# If the model is a valid model, add it to the set of models stored by the client
195197
self.models.add(model)
196-
setattr(model, "_client", self)
197-
198-
for property_name, property_definition in get_model_fields(model).items():
199-
entity_type = EntityType.NODE if issubclass(model, NodeModel) else EntityType.RELATIONSHIP
200-
labels_or_type = (
201-
list(getattr(model._settings, "labels"))
202-
if issubclass(model, NodeModel)
203-
else getattr(model._settings, "type")
204-
)
205198

206-
# Check if we need to create any constraints
207-
if not self._skip_constraints:
208-
if getattr(get_field_type(property_definition), "_unique", False):
209-
await self.create_uniqueness_constraint(
210-
name=model.__name__,
211-
entity_type=entity_type,
212-
properties=[property_name],
213-
labels_or_type=labels_or_type,
214-
)
215-
216-
# Check if we need to create any indexes
217-
if not self._skip_indexes:
218-
if getattr(get_field_type(property_definition), "_range_index", False):
219-
await self.create_range_index(
220-
name=model.__name__,
221-
entity_type=entity_type,
222-
properties=[property_name],
223-
labels_or_type=labels_or_type,
224-
)
225-
if getattr(get_field_type(property_definition), "_point_index", False):
226-
await self.create_point_index(
227-
name=model.__name__,
228-
entity_type=entity_type,
229-
properties=[property_name],
230-
labels_or_type=labels_or_type,
231-
)
232-
if getattr(get_field_type(property_definition), "_text_index", False):
233-
await self.create_text_index(
234-
name=model.__name__,
235-
entity_type=entity_type,
236-
properties=[property_name],
237-
labels_or_type=labels_or_type,
238-
)
199+
await self._prepare_registered_models()
239200

240201
@ensure_connection
241202
async def close(self) -> None:
@@ -782,6 +743,56 @@ def _resolve_database_model(self, query_result: Any) -> Optional[Any]:
782743
logger.debug("Query result %s is not a node, relationship, or path, skipping", type(query_result))
783744
return None
784745

746+
async def _prepare_registered_models(self) -> None:
747+
"""
748+
Prepares the registered models by setting the client and creating all indexes and constraints.
749+
"""
750+
751+
for model in self.models:
752+
setattr(model, "_client", self)
753+
754+
for property_name, property_definition in get_model_fields(model).items():
755+
entity_type = EntityType.NODE if issubclass(model, NodeModel) else EntityType.RELATIONSHIP
756+
labels_or_type = (
757+
list(getattr(model._settings, "labels"))
758+
if issubclass(model, NodeModel)
759+
else getattr(model._settings, "type")
760+
)
761+
762+
# Check if we need to create any constraints
763+
if not self._skip_constraints:
764+
if getattr(get_field_type(property_definition), "_unique", False):
765+
await self.create_uniqueness_constraint(
766+
name=model.__name__,
767+
entity_type=entity_type,
768+
properties=[property_name],
769+
labels_or_type=labels_or_type,
770+
)
771+
772+
# Check if we need to create any indexes
773+
if not self._skip_indexes:
774+
if getattr(get_field_type(property_definition), "_range_index", False):
775+
await self.create_range_index(
776+
name=model.__name__,
777+
entity_type=entity_type,
778+
properties=[property_name],
779+
labels_or_type=labels_or_type,
780+
)
781+
if getattr(get_field_type(property_definition), "_point_index", False):
782+
await self.create_point_index(
783+
name=model.__name__,
784+
entity_type=entity_type,
785+
properties=[property_name],
786+
labels_or_type=labels_or_type,
787+
)
788+
if getattr(get_field_type(property_definition), "_text_index", False):
789+
await self.create_text_index(
790+
name=model.__name__,
791+
entity_type=entity_type,
792+
properties=[property_name],
793+
labels_or_type=labels_or_type,
794+
)
795+
785796
@property
786797
def is_connected(self) -> bool:
787798
"""

0 commit comments

Comments
 (0)