Skip to content

Commit 2a64464

Browse files
author
matmoncon
committed
fix: update CustomGenerateJsonSchema class to add index/constraint info
1 parent c7aa6b8 commit 2a64464

File tree

1 file changed

+33
-29
lines changed

1 file changed

+33
-29
lines changed

pyneo4j_ogm/core/base.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353

5454
if IS_PYDANTIC_V2:
5555
from pydantic import SerializationInfo, model_serializer, model_validator
56-
from pydantic.config import JsonDict
5756
from pydantic.json_schema import GenerateJsonSchema
5857
else:
5958
from pydantic.class_validators import root_validator
@@ -155,6 +154,35 @@ def wrapper(self, *args, **kwargs):
155154
if IS_PYDANTIC_V2:
156155

157156
class CustomGenerateJsonSchema(GenerateJsonSchema):
157+
"""
158+
Custom JSON schema generator which adds support for generating JSON schemas for `RelationshipProperty` fields
159+
and adds index and uniqueness constraint information to the generated schema.
160+
"""
161+
162+
def generate(self, *args, **kwargs):
163+
model_cls = cast(Type[BaseModel], args[0]["schema"]["cls"])
164+
generated_schema = super().generate(*args, **kwargs)
165+
166+
for field_name, field in get_model_fields(model_cls).items():
167+
point_index = getattr(get_field_type(field), "_point_index", False)
168+
range_index = getattr(get_field_type(field), "_range_index", False)
169+
text_index = getattr(get_field_type(field), "_text_index", False)
170+
unique = getattr(get_field_type(field), "_unique", False)
171+
172+
if field_name not in generated_schema["properties"]:
173+
continue
174+
175+
if point_index:
176+
generated_schema["properties"][field_name]["point_index"] = True
177+
if range_index:
178+
generated_schema["properties"][field_name]["range_index"] = True
179+
if text_index:
180+
generated_schema["properties"][field_name]["text_index"] = True
181+
if unique:
182+
generated_schema["properties"][field_name]["uniqueness_constraint"] = True
183+
184+
return generated_schema
185+
158186
def encode_default(self, dft: Any) -> Any:
159187
if isinstance(dft, RelationshipProperty):
160188
dft = str(dft)
@@ -222,30 +250,6 @@ def model_json_schema(cls, *args, **kwargs) -> Dict[str, Any]:
222250
kwargs.setdefault("schema_generator", CustomGenerateJsonSchema)
223251
return super().model_json_schema(*args, **kwargs)
224252

225-
# Pydantic does not initialize either `__fields__` or `model_fields` in the __init_subclass__
226-
# method anymore in V2, thus we have to call this logic here as well
227-
@classmethod
228-
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
229-
super().__pydantic_init_subclass__(**kwargs)
230-
231-
for _, field in get_model_fields(cls).items():
232-
point_index = getattr(get_field_type(field), "_point_index", False)
233-
range_index = getattr(get_field_type(field), "_range_index", False)
234-
text_index = getattr(get_field_type(field), "_text_index", False)
235-
unique = getattr(get_field_type(field), "_unique", False)
236-
237-
if field.json_schema_extra is None:
238-
field.json_schema_extra = {}
239-
240-
if point_index:
241-
cast(JsonDict, field.json_schema_extra)["point_index"] = True
242-
if range_index:
243-
cast(JsonDict, field.json_schema_extra)["range_index"] = True
244-
if text_index:
245-
cast(JsonDict, field.json_schema_extra)["text_index"] = True
246-
if unique:
247-
cast(JsonDict, field.json_schema_extra)["uniqueness_constraint"] = True
248-
249253
else:
250254

251255
@root_validator
@@ -394,13 +398,13 @@ def __init_subclass__(cls, *args, **kwargs) -> None:
394398
unique = getattr(get_field_type(field), "_unique", False)
395399

396400
if point_index:
397-
field.field_info.extra["point_index"] = True
401+
field.field_info.extra["point_index"] = True # type: ignore
398402
if range_index:
399-
field.field_info.extra["range_index"] = True
403+
field.field_info.extra["range_index"] = True # type: ignore
400404
if text_index:
401-
field.field_info.extra["text_index"] = True
405+
field.field_info.extra["text_index"] = True # type: ignore
402406
if unique:
403-
field.field_info.extra["uniqueness_constraint"] = True
407+
field.field_info.extra["uniqueness_constraint"] = True # type: ignore
404408

405409
super().__init_subclass__(*args, **kwargs)
406410

0 commit comments

Comments
 (0)