11"""
2- Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic for both
3- model types like registering hooks and exporting/importing models to/from dictionaries .
2+ Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic between the two model types
3+ and defines common serializers/validators used for Pydantic models.
44"""
55
6- # pyright: reportUnboundVariable=false
7-
86import asyncio
97import json
108from asyncio import iscoroutinefunction
@@ -78,8 +76,9 @@
 
 def hooks(func):
     """
-    Decorator which runs defined pre- and post hooks for the decorated method. The decorator expects the
-    hooks to have the name of the decorated method. Both synchronous and asynchronous hooks are supported.
+    Decorator that runs all defined hooks for the decorated method. Pre-hooks are called before the method is
+    executed and receive the same arguments as the decorated method. Post-hooks are called after the method is
+    executed and receive the same arguments as a pre-hook, but with the result as the second argument.
 
     Args:
         func (Callable): The method to decorate.
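For reference, hook functions with the signatures described above might look like the following sketch; the function names are made up for illustration and are not part of this commit:

    # Hypothetical hook functions matching the signatures described in the docstring above
    def my_pre_hook(self, *args, **kwargs):
        # Pre-hooks receive the model instance and the decorated method's arguments
        print(f"About to run on {self}")

    async def my_post_hook(self, result, *args, **kwargs):
        # Post-hooks additionally receive the decorated method's result as the second argument
        print(f"Finished with result {result!r}")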
@@ -95,11 +94,15 @@ async def async_wrapper(self, *args, **kwargs):
         settings: BaseModelSettings = getattr(self, "_settings")
 
         if func.__name__ in settings.pre_hooks:
-            logger.debug(
+            logger.info(
                 "Found %s pre-hook functions for method %s", len(settings.pre_hooks[func.__name__]), func.__name__
             )
+
+            # Run each pre-hook function with the same arguments as the decorated method
             for hook_function in settings.pre_hooks[func.__name__]:
                 logger.debug("Running pre-hook function %s", hook_function.__name__)
+
+                # Check if the hook function is asynchronous, if so await it to prevent unawaited coroutine warnings
                 if iscoroutinefunction(hook_function):
                     await hook_function(self, *args, **kwargs)
                 else:
@@ -108,11 +111,16 @@ async def async_wrapper(self, *args, **kwargs):
         result = await func(self, *args, **kwargs)
 
         if func.__name__ in settings.post_hooks:
-            logger.debug(
+            logger.info(
                 "Found %s post-hook functions for method %s", len(settings.post_hooks[func.__name__]), func.__name__
             )
+
+            # Run any post-hook functions with the same arguments as the decorated method and the result
+            # as the second argument
             for hook_function in settings.post_hooks[func.__name__]:
                 logger.debug("Running post-hook function %s", hook_function.__name__)
+
+                # Check again if the hook function is asynchronous and await it if necessary
                 if iscoroutinefunction(hook_function):
                     await hook_function(self, result, *args, **kwargs)
                 else:
@@ -129,9 +137,12 @@ def sync_wrapper(self, *args, **kwargs):
         settings: BaseModelSettings = getattr(self, "_settings")
 
         if func.__name__ in settings.pre_hooks:
+            # Run each pre-hook function with the same arguments as the decorated method
             logger.debug(
                 "Found %s pre-hook functions for method %s", len(settings.pre_hooks[func.__name__]), func.__name__
             )
+
+            # Check if the hook function is asynchronous, if so create a new task to run it
             for hook_function in settings.pre_hooks[func.__name__]:
                 logger.debug("Running pre-hook function %s", hook_function.__name__)
                 if iscoroutinefunction(hook_function):
@@ -142,9 +153,12 @@ def sync_wrapper(self, *args, **kwargs):
         result = func(self, *args, **kwargs)
 
         if func.__name__ in settings.post_hooks:
+            # Run any post-hook functions with the same arguments as the decorated method and the result
             logger.debug(
                 "Found %s post-hook functions for method %s", len(settings.post_hooks[func.__name__]), func.__name__
             )
+
+            # Check again if the hook function is asynchronous and create a new task to run it if necessary
             for hook_function in settings.post_hooks[func.__name__]:
                 logger.debug("Running post-hook function %s", hook_function.__name__)
                 if iscoroutinefunction(hook_function):
@@ -169,13 +183,16 @@ def generate(self, *args, **kwargs):
         model_cls: Optional[Type[BaseModel]] = None
 
         if "definitions" in args[0]:
+            # If a `definitions` key is present, the JSON schema contains multiple schemas for multiple models
+            # We need to get the schema ref for the model we want to add the additional information to
             schema_ref = args[0]["schema"]["schema_ref"]
 
             for definition in args[0]["definitions"]:
                 if definition["ref"] == schema_ref and "cls" in definition:
                     model_cls = cast(Type[BaseModel], definition["cls"])
                     break
         elif "cls" in args[0]:
+            # If a `cls` key is present, the JSON schema only contains our current model
             model_cls = cast(Type[BaseModel], args[0]["cls"])
 
         if model_cls is None:
@@ -184,6 +201,8 @@ def generate(self, *args, **kwargs):
         generated_schema = super().generate(*args, **kwargs)
 
         for field_name, field in get_model_fields(model_cls).items():
+            # Check all fields defined on the model, if we find an index or constraint field, add the information
+            # to the generated schema
             point_index = getattr(get_field_type(field), "_point_index", False)
             range_index = getattr(get_field_type(field), "_range_index", False)
             text_index = getattr(get_field_type(field), "_text_index", False)
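As a rough illustration of where these private `_point_index`/`_range_index`/`_text_index` markers come from: a property type can carry them as class attributes, which the `getattr(get_field_type(field), ...)` calls above then pick up. The class below is a made-up stand-in, not the library's actual helper for declaring indexes:

    # Hypothetical property type carrying the marker attributes the schema generator reads via getattr()
    class UniqueText(str):
        _point_index = False
        _range_index = False
        _text_index = True
        _unique = True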
@@ -211,10 +230,11 @@ def encode_default(self, dft: Any) -> Any:
 
 class ModelBase(BaseModel, Generic[V]):
     """
-    Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic for both
-    model types like registering hooks and exporting/importing models to/from dictionaries.
+    Base class for both `NodeModel` and `RelationshipModel`. This class handles shared logic between the two model types
+    and defines common serializers/validators used for Pydantic models.
 
-    Should not be used directly.
+    If you come across this class and want to use it in your own models then DON'T. This class is not meant to be used
+    directly and is only used as a base class for `NodeModel` and `RelationshipModel`.
     """
 
     _settings: BaseModelSettings = PrivateAttr()
@@ -230,11 +250,16 @@ class ModelBase(BaseModel, Generic[V]):
 
     @model_serializer(mode="wrap")
     def _model_serializer(self, serializer: Any, info: SerializationInfo) -> Any:
+        """
+        Custom model serializer which adds support for serializing `RelationshipProperty` fields, `element_id` and
+        `id` fields and any additional fields defined on the model.
+        """
         if isinstance(self, RelationshipProperty):
             return self.nodes
 
         serialized = serializer(self)
 
+        # If the `id` field is not excluded and not `None`, add it to the serialized dictionary
         if not (self.id is None and info.exclude_none) and not (info.exclude is not None and "id" in info.exclude):
             serialized["id"] = self.id
 
@@ -273,16 +298,23 @@ def _model_serializer(self, serializer: Any, info: SerializationInfo) -> Any:
 
         if hasattr(self, "_relationship_properties"):
             for field_name in getattr(self, "_relationship_properties"):
+                # If a relationship property has been found and it has fetched nodes, we add the nodes to
+                # the serialized dictionary as well
                 if field_name in serialized:
                     serialized[field_name] = cast(RelationshipProperty, getattr(self, field_name)).nodes
 
         return serialized
 
     @model_validator(mode="before")  # type: ignore
     def _model_validator(cls, values: Any) -> Any:
+        """
+        Custom validator for the nodes fetched from a relationship property.
+        """
         relationship_properties = getattr(cls, "_relationship_properties", set())
 
         for field_name, field in get_model_fields(cls).items():
+            # Go over each relationship property field and try to build the relationship property for that
+            # field.
             if (
                 field_name in relationship_properties
                 and field_name in values
@@ -303,6 +335,8 @@ def _model_validator(cls, values: Any) -> Any:
                 if target_model is not None:
                     nodes: List[NodeModel] = []
 
+                    # Check if the nodes are of the correct model type and if they are alive
+                    # If so add them to the parsed instance
                     for node in values[field_name]:
                         if isinstance(node, target_model):
                             nodes.append(node)
@@ -334,6 +368,8 @@ def _parse_dict_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         relationship_properties = getattr(cls, "_relationship_properties", set())
 
         for field_name, field in get_model_fields(cls).items():
+            # Go over each relationship property field and try to build the relationship property for that
+            # field.
             if (
                 field_name in relationship_properties
                 and field_name in values
@@ -354,6 +390,8 @@ def _parse_dict_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
                 if target_model is not None:
                     nodes: List[NodeModel] = []
 
+                    # Check if the nodes are of the correct model type and if they are alive
+                    # If so add them to the parsed instance
                     for node in values[field_name]:
                         if isinstance(node, target_model):
                             nodes.append(node)
@@ -387,6 +425,8 @@ def dict( # type: ignore
         excluded_fields = set()
         excluded_fields.update(exclude or set())
 
+        # Add all relationship properties to the list of excluded fields since we will handle them
+        # later on ourselves
         if hasattr(self, "_relationship_properties"):
             excluded_fields.update(cast(Set[str], getattr(self, "_relationship_properties")))
 
@@ -401,6 +441,7 @@ def dict( # type: ignore
             exclude_none=exclude_none,
         )
 
+        # Add the `element_id` and `id` fields if they have not specifically been excluded
         if not (self._id is None and exclude_none) and not (exclude is not None and "id" in exclude):
             base_dict["id"] = self._id
 
@@ -438,6 +479,8 @@ def dict( # type: ignore
 
         if hasattr(self, "_relationship_properties"):
             for field_name in getattr(self, "_relationship_properties"):
+                # Each relationship property gets the fetched nodes when serializing it rather than its indexes
+                # or constraints
                 field = cast(Union[RelationshipProperty, List], getattr(self, field_name))
 
                 if not isinstance(field, list) and not (exclude is not None and field_name in exclude):
@@ -482,6 +525,8 @@ def json( # type: ignore
 
         modified_json = json.loads(base_json)
 
+        # Add the `element_id` and `id` fields if they have not
+        # specifically been excluded
         if not (self._id is None and exclude_none) and not (exclude is not None and "id" in exclude):
             modified_json["id"] = self._id
 
@@ -519,16 +564,21 @@ def json( # type: ignore
 
         if hasattr(self, "_relationship_properties"):
             for field_name in getattr(self, "_relationship_properties"):
+                # Each relationship property gets the fetched nodes when serializing it rather than its indexes
+                # or constraints
                 field = cast(Union[RelationshipProperty, List], getattr(self, field_name))
 
                 if not isinstance(field, list) and not (exclude is not None and field_name in exclude):
+                    # Serialize and then deserialize each model to prevent double serialization when doing the whole
+                    # model at the end
                     modified_json[field_name] = [
                         json.loads(cast(Union[RelationshipModel, NodeModel], node).json()) for node in field.nodes
                     ]
 
         return json.dumps(modified_json)
 
     def __init__(self, *args, **kwargs) -> None:
+        # Check if the model has been registered with a client
         if not hasattr(self, "_client"):
             raise UnregisteredModel(model=self.__class__.__name__)
 
@@ -556,6 +606,8 @@ def __init_subclass__(cls, *args, **kwargs) -> None:
             text_index = getattr(get_field_type(field), "_text_index", False)
             unique = getattr(get_field_type(field), "_unique", False)
 
+            # In Pydantic 2.x.x we need to add the index and constraint information to the field's
+            # `field_info.extra` attribute to make it available in the JSON schema
             if point_index:
                 field.field_info.extra["point_index"] = True  # type: ignore
                 field.field_info.extra["range_index"] = True  # type: ignore
@@ -607,9 +659,9 @@ def register_pre_hooks(
             overwrite (bool, optional): Whether to overwrite all defined hook functions if a new hook function for
                 the same hook is registered. Defaults to `False`.
         """
+        logger.info("Registering pre-hook for %s", hook_name)
         valid_hook_functions: List[Callable] = []
 
-        logger.info("Registering pre-hook for %s", hook_name)
         # Normalize hooks to a list of functions
         if isinstance(hook_functions, list):
             for hook_function in hook_functions:
@@ -622,6 +674,7 @@ def register_pre_hooks(
         cls._settings.pre_hooks[hook_name] = []
 
         if overwrite:
+            # If `overwrite` is set to `True`, we overwrite all existing hook functions for the given hook
            logger.debug("Overwriting %s existing pre-hook functions", len(cls._settings.pre_hooks[hook_name]))
             cls._settings.pre_hooks[hook_name] = valid_hook_functions
         else:
@@ -646,9 +699,9 @@ def register_post_hooks(
             overwrite (bool, optional): Whether to overwrite all defined hook functions if a new hook function for
                 the same hook is registered. Defaults to `False`.
         """
+        logger.info("Registering post-hook for %s", hook_name)
         valid_hook_functions: List[Callable] = []
 
-        logger.info("Registering post-hook for %s", hook_name)
         # Normalize hooks to a list of functions
         if isinstance(hook_functions, list):
             for hook_function in hook_functions:
@@ -661,6 +714,7 @@ def register_post_hooks(
         cls._settings.post_hooks[hook_name] = []
 
         if overwrite:
+            # If `overwrite` is set to `True`, we overwrite all existing hook functions for the given hook
             logger.debug("Overwriting %s existing post-hook functions", len(cls._settings.post_hooks[hook_name]))
             cls._settings.post_hooks[hook_name] = valid_hook_functions
         else:
@@ -732,11 +786,13 @@ def _deflate(self, deflated: Dict[str, Any]) -> Dict[str, Any]:
         """
         logger.debug("Deflating model %s to storable dictionary", self)
 
-        # Serialize nested BaseModel or dict instances to JSON strings
         for field_name, field in deepcopy(deflated).items():
             if isinstance(field, (dict, BaseModel)):
+                # If the field is a dictionary or a Pydantic model, we deflate it by serializing it to a JSON string
                 deflated[field_name] = json.dumps(field)
             elif isinstance(field, list):
+                # If the field is a list, we deflate it by serializing each item to a JSON string
+                # This adds the constraint that all items in the list must be encodable to a JSON string
                 for index, item in enumerate(field):
                     if not isinstance(item, (int, float, str, bool)):
                         try:
@@ -770,7 +826,11 @@ def try_property_parsing(property_value: str) -> Union[str, Dict[str, Any], Base
 
         logger.debug("Inflating node %s to model instance", graph_entity)
         for node_property in graph_entity.items():
+            # Inflate each property of the node
+            # If the property is a JSON string, we try to parse it to a dictionary. If the parsing fails, we know
+            # that the property is a string and we can use it as is
             property_name, property_value = node_property
+            logger.debug("Inflating property %s with value %s", property_name, property_value)
 
             if isinstance(property_value, str):
                 inflated[property_name] = try_property_parsing(property_value)
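To make the deflate/inflate behaviour above concrete, here is a minimal round-trip sketch. The values are invented, and `parse_property` is a simplified stand-in for the `try_property_parsing` helper named in the hunk header (the real helper also handles nested Pydantic models):

    import json

    # Nested properties are deflated to JSON strings so they can be stored on the graph entity
    deflated = {"name": "Alice", "address": json.dumps({"city": "Berlin"})}

    def parse_property(property_value: str):
        # Try to parse a stored JSON string back into a dictionary; fall back to the raw string
        try:
            return json.loads(property_value)
        except json.JSONDecodeError:
            return property_value

    assert parse_property(deflated["address"]) == {"city": "Berlin"}
    assert parse_property(deflated["name"]) == "Alice"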