@@ -173,6 +173,8 @@ async def register_models_dir(self, dir_path: str) -> None:
173173 ):
174174 self .models .add (member [1 ])
175175
176+ await self ._prepare_registered_models ()
177+
176178 @ensure_connection
177179 async def register_models (self , models : List [Type [Union [NodeModel , RelationshipModel ]]]) -> None :
178180 """
@@ -193,49 +195,8 @@ async def register_models(self, models: List[Type[Union[NodeModel, RelationshipM
193195
194196 # If the model is a valid model, add it to the set of models stored by the client
195197 self .models .add (model )
196- setattr (model , "_client" , self )
197-
198- for property_name , property_definition in get_model_fields (model ).items ():
199- entity_type = EntityType .NODE if issubclass (model , NodeModel ) else EntityType .RELATIONSHIP
200- labels_or_type = (
201- list (getattr (model ._settings , "labels" ))
202- if issubclass (model , NodeModel )
203- else getattr (model ._settings , "type" )
204- )
205198
206- # Check if we need to create any constraints
207- if not self ._skip_constraints :
208- if getattr (get_field_type (property_definition ), "_unique" , False ):
209- await self .create_uniqueness_constraint (
210- name = model .__name__ ,
211- entity_type = entity_type ,
212- properties = [property_name ],
213- labels_or_type = labels_or_type ,
214- )
215-
216- # Check if we need to create any indexes
217- if not self ._skip_indexes :
218- if getattr (get_field_type (property_definition ), "_range_index" , False ):
219- await self .create_range_index (
220- name = model .__name__ ,
221- entity_type = entity_type ,
222- properties = [property_name ],
223- labels_or_type = labels_or_type ,
224- )
225- if getattr (get_field_type (property_definition ), "_point_index" , False ):
226- await self .create_point_index (
227- name = model .__name__ ,
228- entity_type = entity_type ,
229- properties = [property_name ],
230- labels_or_type = labels_or_type ,
231- )
232- if getattr (get_field_type (property_definition ), "_text_index" , False ):
233- await self .create_text_index (
234- name = model .__name__ ,
235- entity_type = entity_type ,
236- properties = [property_name ],
237- labels_or_type = labels_or_type ,
238- )
199+ await self ._prepare_registered_models ()
239200
240201 @ensure_connection
241202 async def close (self ) -> None :
@@ -782,6 +743,56 @@ def _resolve_database_model(self, query_result: Any) -> Optional[Any]:
782743 logger .debug ("Query result %s is not a node, relationship, or path, skipping" , type (query_result ))
783744 return None
784745
746+ async def _prepare_registered_models (self ) -> None :
747+ """
748+ Prepares the registered models by setting the client and creating all indexes and constraints.
749+ """
750+
751+ for model in self .models :
752+ setattr (model , "_client" , self )
753+
754+ for property_name , property_definition in get_model_fields (model ).items ():
755+ entity_type = EntityType .NODE if issubclass (model , NodeModel ) else EntityType .RELATIONSHIP
756+ labels_or_type = (
757+ list (getattr (model ._settings , "labels" ))
758+ if issubclass (model , NodeModel )
759+ else getattr (model ._settings , "type" )
760+ )
761+
762+ # Check if we need to create any constraints
763+ if not self ._skip_constraints :
764+ if getattr (get_field_type (property_definition ), "_unique" , False ):
765+ await self .create_uniqueness_constraint (
766+ name = model .__name__ ,
767+ entity_type = entity_type ,
768+ properties = [property_name ],
769+ labels_or_type = labels_or_type ,
770+ )
771+
772+ # Check if we need to create any indexes
773+ if not self ._skip_indexes :
774+ if getattr (get_field_type (property_definition ), "_range_index" , False ):
775+ await self .create_range_index (
776+ name = model .__name__ ,
777+ entity_type = entity_type ,
778+ properties = [property_name ],
779+ labels_or_type = labels_or_type ,
780+ )
781+ if getattr (get_field_type (property_definition ), "_point_index" , False ):
782+ await self .create_point_index (
783+ name = model .__name__ ,
784+ entity_type = entity_type ,
785+ properties = [property_name ],
786+ labels_or_type = labels_or_type ,
787+ )
788+ if getattr (get_field_type (property_definition ), "_text_index" , False ):
789+ await self .create_text_index (
790+ name = model .__name__ ,
791+ entity_type = entity_type ,
792+ properties = [property_name ],
793+ labels_or_type = labels_or_type ,
794+ )
795+
785796 @property
786797 def is_connected (self ) -> bool :
787798 """
0 commit comments