Skip to content

Commit b8c479d

Browse files
author
matmoncon
committed
docs: add comments and update docstrings
1 parent 27c6a6e commit b8c479d

File tree

1 file changed

+74
-14
lines changed

1 file changed

+74
-14
lines changed

pyneo4j_ogm/core/base.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
"""
2-
Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic for both
3-
model types like registering hooks and exporting/importing models to/from dictionaries.
2+
Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic between the two model types
3+
and defines common serializers/validators used for Pydantic models.
44
"""
55

6-
# pyright: reportUnboundVariable=false
7-
86
import asyncio
97
import json
108
from asyncio import iscoroutinefunction
@@ -78,8 +76,9 @@
7876

7977
def hooks(func):
8078
"""
81-
Decorator which runs defined pre- and post hooks for the decorated method. The decorator expects the
82-
hooks to have the name of the decorated method. Both synchronous and asynchronous hooks are supported.
79+
Calls all defined hooks for the decorated method. Pre-hooks are called before the method is executed and they
80+
receive the same arguments as the decorated method. Post-hooks are called after the method is executed and they
81+
receive the same arguments as a pre-hook, but with the result as the second argument.
8382
8483
Args:
8584
func (Callable): The method to decorate.
@@ -95,11 +94,15 @@ async def async_wrapper(self, *args, **kwargs):
9594
settings: BaseModelSettings = getattr(self, "_settings")
9695

9796
if func.__name__ in settings.pre_hooks:
98-
logger.debug(
97+
logger.info(
9998
"Found %s pre-hook functions for method %s", len(settings.pre_hooks[func.__name__]), func.__name__
10099
)
100+
101+
# Run each pre-hook function with the same arguments as the decorated method
101102
for hook_function in settings.pre_hooks[func.__name__]:
102103
logger.debug("Running pre-hook function %s", hook_function.__name__)
104+
105+
# Check if the hook function is asynchronous, if so await it to prevent unawaited coroutine warnings
103106
if iscoroutinefunction(hook_function):
104107
await hook_function(self, *args, **kwargs)
105108
else:
@@ -108,11 +111,16 @@ async def async_wrapper(self, *args, **kwargs):
108111
result = await func(self, *args, **kwargs)
109112

110113
if func.__name__ in settings.post_hooks:
111-
logger.debug(
114+
logger.info(
112115
"Found %s post-hook functions for method %s", len(settings.post_hooks[func.__name__]), func.__name__
113116
)
117+
118+
# Run any post-hook functions with the same arguments as the decorated method and the result
119+
# as the second argument
114120
for hook_function in settings.post_hooks[func.__name__]:
115121
logger.debug("Running post-hook function %s", hook_function.__name__)
122+
123+
# Check again if the hook function is asynchronous and await it if necessary
116124
if iscoroutinefunction(hook_function):
117125
await hook_function(self, result, *args, **kwargs)
118126
else:
@@ -129,9 +137,12 @@ def sync_wrapper(self, *args, **kwargs):
129137
settings: BaseModelSettings = getattr(self, "_settings")
130138

131139
if func.__name__ in settings.pre_hooks:
140+
# Run each pre-hook function with the same arguments as the decorated method
132141
logger.debug(
133142
"Found %s pre-hook functions for method %s", len(settings.pre_hooks[func.__name__]), func.__name__
134143
)
144+
145+
# Check if the hook function is asynchronous, if so create a new task to run it
135146
for hook_function in settings.pre_hooks[func.__name__]:
136147
logger.debug("Running pre-hook function %s", hook_function.__name__)
137148
if iscoroutinefunction(hook_function):
@@ -142,9 +153,12 @@ def sync_wrapper(self, *args, **kwargs):
142153
result = func(self, *args, **kwargs)
143154

144155
if func.__name__ in settings.post_hooks:
156+
# Run any post-hook functions with the same arguments as the decorated method and the result
145157
logger.debug(
146158
"Found %s post-hook functions for method %s", len(settings.post_hooks[func.__name__]), func.__name__
147159
)
160+
161+
# Check again if the hook function is asynchronous and create a new task to run it if necessary
148162
for hook_function in settings.post_hooks[func.__name__]:
149163
logger.debug("Running post-hook function %s", hook_function.__name__)
150164
if iscoroutinefunction(hook_function):
@@ -169,13 +183,16 @@ def generate(self, *args, **kwargs):
169183
model_cls: Optional[Type[BaseModel]] = None
170184

171185
if "definitions" in args[0]:
186+
# If a `definitions` key is present, the JSON schema contains multiple schemas for multiple models
187+
# We need to get the schema ref for the model we want to add the additional information to
172188
schema_ref = args[0]["schema"]["schema_ref"]
173189

174190
for definition in args[0]["definitions"]:
175191
if definition["ref"] == schema_ref and "cls" in definition:
176192
model_cls = cast(Type[BaseModel], definition["cls"])
177193
break
178194
elif "cls" in args[0]:
195+
# If a `cls` key is present, the JSON schema only contains our current model
179196
model_cls = cast(Type[BaseModel], args[0]["cls"])
180197

181198
if model_cls is None:
@@ -184,6 +201,8 @@ def generate(self, *args, **kwargs):
184201
generated_schema = super().generate(*args, **kwargs)
185202

186203
for field_name, field in get_model_fields(model_cls).items():
204+
# Check all fields defined on the model, if we find a index or constraint field, add the information
205+
# to the generated schema
187206
point_index = getattr(get_field_type(field), "_point_index", False)
188207
range_index = getattr(get_field_type(field), "_range_index", False)
189208
text_index = getattr(get_field_type(field), "_text_index", False)
@@ -211,10 +230,11 @@ def encode_default(self, dft: Any) -> Any:
211230

212231
class ModelBase(BaseModel, Generic[V]):
213232
"""
214-
Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic for both
215-
model types like registering hooks and exporting/importing models to/from dictionaries.
233+
Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic between the two model types
234+
and defines common serializers/validators used for Pydantic models.
216235
217-
Should not be used directly.
236+
If you come across this class and want to use it in your own models then DON'T. This class is not meant to be used
237+
directly and is only used as a base class for `NodeModel` and `RelationshipModel`.
218238
"""
219239

220240
_settings: BaseModelSettings = PrivateAttr()
@@ -230,11 +250,16 @@ class ModelBase(BaseModel, Generic[V]):
230250

231251
@model_serializer(mode="wrap")
232252
def _model_serializer(self, serializer: Any, info: SerializationInfo) -> Any:
253+
"""
254+
Custom model serializer which adds support for serializing `RelationshipProperty` fields, `element_id` and
255+
`id` fields and any additional fields defined on the model.
256+
"""
233257
if isinstance(self, RelationshipProperty):
234258
return self.nodes
235259

236260
serialized = serializer(self)
237261

262+
# If the field is not excluded and not `None`, add it to the serialized dictionary
238263
if not (self.id is None and info.exclude_none) and not (info.exclude is not None and "id" in info.exclude):
239264
serialized["id"] = self.id
240265

@@ -273,16 +298,23 @@ def _model_serializer(self, serializer: Any, info: SerializationInfo) -> Any:
273298

274299
if hasattr(self, "_relationship_properties"):
275300
for field_name in getattr(self, "_relationship_properties"):
301+
# If a relationship property has been found and it has fetched nodes, we add the nodes to
302+
# the serialized dictionary as well
276303
if field_name in serialized:
277304
serialized[field_name] = cast(RelationshipProperty, getattr(self, field_name)).nodes
278305

279306
return serialized
280307

281308
@model_validator(mode="before") # type: ignore
282309
def _model_validator(cls, values: Any) -> Any:
310+
"""
311+
Custom validation for validating the fetched nodes from a relationship property.
312+
"""
283313
relationship_properties = getattr(cls, "_relationship_properties", set())
284314

285315
for field_name, field in get_model_fields(cls).items():
316+
# Go over each relationship property field and try to build the relationship property for that
317+
# field.
286318
if (
287319
field_name in relationship_properties
288320
and field_name in values
@@ -303,6 +335,8 @@ def _model_validator(cls, values: Any) -> Any:
303335
if target_model is not None:
304336
nodes: List[NodeModel] = []
305337

338+
# Check if the nodes are of the correct model type and if they are alive
339+
# If so add them to the parsed instance
306340
for node in values[field_name]:
307341
if isinstance(node, target_model):
308342
nodes.append(node)
@@ -334,6 +368,8 @@ def _parse_dict_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
334368
relationship_properties = getattr(cls, "_relationship_properties", set())
335369

336370
for field_name, field in get_model_fields(cls).items():
371+
# Go over each relationship property field and try to build the relationship property for that
372+
# field.
337373
if (
338374
field_name in relationship_properties
339375
and field_name in values
@@ -354,6 +390,8 @@ def _parse_dict_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
354390
if target_model is not None:
355391
nodes: List[NodeModel] = []
356392

393+
# Check if the nodes are of the correct model type and if they are alive
394+
# If so add them to the parsed instance
357395
for node in values[field_name]:
358396
if isinstance(node, target_model):
359397
nodes.append(node)
@@ -387,6 +425,8 @@ def dict( # type: ignore
387425
excluded_fields = set()
388426
excluded_fields.update(exclude or set())
389427

428+
# Add all relationship properties to the list of excluded fields since we will handle them
429+
# later on ourself
390430
if hasattr(self, "_relationship_properties"):
391431
excluded_fields.update(cast(Set[str], getattr(self, "_relationship_properties")))
392432

@@ -401,6 +441,7 @@ def dict( # type: ignore
401441
exclude_none=exclude_none,
402442
)
403443

444+
# Add all `element_id` and `id` fields it they have not specifically been excluded
404445
if not (self._id is None and exclude_none) and not (exclude is not None and "id" in exclude):
405446
base_dict["id"] = self._id
406447

@@ -438,6 +479,8 @@ def dict( # type: ignore
438479

439480
if hasattr(self, "_relationship_properties"):
440481
for field_name in getattr(self, "_relationship_properties"):
482+
# Each relationship property gets the fetched nodes when serializing it rather than it's indexes
483+
# or constraints
441484
field = cast(Union[RelationshipProperty, List], getattr(self, field_name))
442485

443486
if not isinstance(field, list) and not (exclude is not None and field_name in exclude):
@@ -482,6 +525,8 @@ def json( # type: ignore
482525

483526
modified_json = json.loads(base_json)
484527

528+
# Add all relationship properties to the list of excluded fields since we will handle them
529+
# later on ourself
485530
if not (self._id is None and exclude_none) and not (exclude is not None and "id" in exclude):
486531
modified_json["id"] = self._id
487532

@@ -519,16 +564,21 @@ def json( # type: ignore
519564

520565
if hasattr(self, "_relationship_properties"):
521566
for field_name in getattr(self, "_relationship_properties"):
567+
# Each relationship property gets the fetched nodes when serializing it rather than it's indexes
568+
# or constraints
522569
field = cast(Union[RelationshipProperty, List], getattr(self, field_name))
523570

524571
if not isinstance(field, list) and not (exclude is not None and field_name in exclude):
572+
# Serialize and then deserialize each model to prevent double serialization when doing the whole
573+
# model at the end
525574
modified_json[field_name] = [
526575
json.loads(cast(Union[RelationshipModel, NodeModel], node).json()) for node in field.nodes
527576
]
528577

529578
return json.dumps(modified_json)
530579

531580
def __init__(self, *args, **kwargs) -> None:
581+
# Check if the models has been registered with a client
532582
if not hasattr(self, "_client"):
533583
raise UnregisteredModel(model=self.__class__.__name__)
534584

@@ -556,6 +606,8 @@ def __init_subclass__(cls, *args, **kwargs) -> None:
556606
text_index = getattr(get_field_type(field), "_text_index", False)
557607
unique = getattr(get_field_type(field), "_unique", False)
558608

609+
# In Pydantic 2.x.x we need to add the index and constraint information to the field's
610+
# `field_info.extra` attribute to make it available in the JSON schema
559611
if point_index:
560612
field.field_info.extra["point_index"] = True # type: ignore
561613
if range_index:
@@ -607,9 +659,9 @@ def register_pre_hooks(
607659
overwrite (bool, optional): Whether to overwrite all defined hook functions if a new hooks function for
608660
the same hook is registered. Defaults to `False`.
609661
"""
662+
logger.info("Registering pre-hook for %s", hook_name)
610663
valid_hook_functions: List[Callable] = []
611664

612-
logger.info("Registering pre-hook for %s", hook_name)
613665
# Normalize hooks to a list of functions
614666
if isinstance(hook_functions, list):
615667
for hook_function in hook_functions:
@@ -622,6 +674,7 @@ def register_pre_hooks(
622674
cls._settings.pre_hooks[hook_name] = []
623675

624676
if overwrite:
677+
# If `overwrite` is set to `True`, we overwrite all existing hook functions for the given hook
625678
logger.debug("Overwriting %s existing pre-hook functions", len(cls._settings.pre_hooks[hook_name]))
626679
cls._settings.pre_hooks[hook_name] = valid_hook_functions
627680
else:
@@ -646,9 +699,9 @@ def register_post_hooks(
646699
overwrite (bool, optional): Whether to overwrite all defined hook functions if a new hooks function for
647700
the same hook is registered. Defaults to `False`.
648701
"""
702+
logger.info("Registering post-hook for %s", hook_name)
649703
valid_hook_functions: List[Callable] = []
650704

651-
logger.info("Registering post-hook for %s", hook_name)
652705
# Normalize hooks to a list of functions
653706
if isinstance(hook_functions, list):
654707
for hook_function in hook_functions:
@@ -661,6 +714,7 @@ def register_post_hooks(
661714
cls._settings.post_hooks[hook_name] = []
662715

663716
if overwrite:
717+
# If `overwrite` is set to `True`, we overwrite all existing hook functions for the given hook
664718
logger.debug("Overwriting %s existing post-hook functions", len(cls._settings.post_hooks[hook_name]))
665719
cls._settings.post_hooks[hook_name] = valid_hook_functions
666720
else:
@@ -732,11 +786,13 @@ def _deflate(self, deflated: Dict[str, Any]) -> Dict[str, Any]:
732786
"""
733787
logger.debug("Deflating model %s to storable dictionary", self)
734788

735-
# Serialize nested BaseModel or dict instances to JSON strings
736789
for field_name, field in deepcopy(deflated).items():
737790
if isinstance(field, (dict, BaseModel)):
791+
# If the field is a dictionary or a Pydantic model, we deflate it by serializing it to a JSON string
738792
deflated[field_name] = json.dumps(field)
739793
elif isinstance(field, list):
794+
# If the field is a list, we deflate it by serializing each item to a JSON string
795+
# This adds the constraint that all items in the list must be encodable to a JSON string
740796
for index, item in enumerate(field):
741797
if not isinstance(item, (int, float, str, bool)):
742798
try:
@@ -770,7 +826,11 @@ def try_property_parsing(property_value: str) -> Union[str, Dict[str, Any], Base
770826

771827
logger.debug("Inflating node %s to model instance", graph_entity)
772828
for node_property in graph_entity.items():
829+
# Inflate each property of the node
830+
# If the property is a JSON string, we try to parse it to a dictionary. If the parsing fails, we know
831+
# that the property is a string and we can use it as is
773832
property_name, property_value = node_property
833+
logger.debug("Inflating property %s with value %s", property_name, property_value)
774834

775835
if isinstance(property_value, str):
776836
inflated[property_name] = try_property_parsing(property_value)

0 commit comments

Comments
 (0)