|
34 | 34 | from pyneo4j_ogm.fields.relationship_property import RelationshipProperty |
35 | 35 | from pyneo4j_ogm.fields.settings import BaseModelSettings |
36 | 36 | from pyneo4j_ogm.logger import logger |
37 | | -from pyneo4j_ogm.pydantic_utils import IS_PYDANTIC_V2, get_model_dump |
| 37 | +from pyneo4j_ogm.pydantic_utils import ( |
| 38 | + IS_PYDANTIC_V2, |
| 39 | + get_field_type, |
| 40 | + get_model_dump, |
| 41 | + get_model_fields, |
| 42 | +) |
38 | 43 | from pyneo4j_ogm.queries.query_builder import QueryBuilder |
39 | 44 |
|
40 | 45 | if TYPE_CHECKING: |
|
48 | 53 |
|
49 | 54 | if IS_PYDANTIC_V2: |
50 | 55 | from pydantic import SerializationInfo, model_serializer, model_validator |
| 56 | + from pydantic.config import JsonDict |
51 | 57 | from pydantic.json_schema import GenerateJsonSchema |
52 | 58 | else: |
53 | 59 | from pydantic.class_validators import root_validator |
@@ -216,6 +222,30 @@ def model_json_schema(cls, *args, **kwargs) -> Dict[str, Any]: |
216 | 222 | kwargs.setdefault("schema_generator", CustomGenerateJsonSchema) |
217 | 223 | return super().model_json_schema(*args, **kwargs) |
218 | 224 |
|
| 225 | + # Pydantic does not initialize either `__fields__` or `model_fields` in the __init_subclass__ |
| 226 | + # method anymore in V2, thus we have to call this logic here as well |
| 227 | + @classmethod |
| 228 | + def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: |
| 229 | + super().__pydantic_init_subclass__(**kwargs) |
| 230 | + |
| 231 | + for _, field in get_model_fields(cls).items(): |
| 232 | + point_index = getattr(get_field_type(field), "_point_index", False) |
| 233 | + range_index = getattr(get_field_type(field), "_range_index", False) |
| 234 | + text_index = getattr(get_field_type(field), "_text_index", False) |
| 235 | + unique = getattr(get_field_type(field), "_unique", False) |
| 236 | + |
| 237 | + if field.json_schema_extra is None: |
| 238 | + field.json_schema_extra = {} |
| 239 | + |
| 240 | + if point_index: |
| 241 | + cast(JsonDict, field.json_schema_extra)["point_index"] = True |
| 242 | + if range_index: |
| 243 | + cast(JsonDict, field.json_schema_extra)["range_index"] = True |
| 244 | + if text_index: |
| 245 | + cast(JsonDict, field.json_schema_extra)["text_index"] = True |
| 246 | + if unique: |
| 247 | + cast(JsonDict, field.json_schema_extra)["uniqueness_constraint"] = True |
| 248 | + |
219 | 249 | else: |
220 | 250 |
|
221 | 251 | @root_validator |
@@ -356,6 +386,22 @@ def __init_subclass__(cls, *args, **kwargs) -> None: |
356 | 386 | else: |
357 | 387 | setattr(cls._settings, setting, value) |
358 | 388 |
|
| 389 | + if not IS_PYDANTIC_V2: |
| 390 | + for _, field in get_model_fields(cls).items(): |
| 391 | + point_index = getattr(get_field_type(field), "_point_index", False) |
| 392 | + range_index = getattr(get_field_type(field), "_range_index", False) |
| 393 | + text_index = getattr(get_field_type(field), "_text_index", False) |
| 394 | + unique = getattr(get_field_type(field), "_unique", False) |
| 395 | + |
| 396 | + if point_index: |
| 397 | + field.field_info.extra["point_index"] = True |
| 398 | + if range_index: |
| 399 | + field.field_info.extra["range_index"] = True |
| 400 | + if text_index: |
| 401 | + field.field_info.extra["text_index"] = True |
| 402 | + if unique: |
| 403 | + field.field_info.extra["uniqueness_constraint"] = True |
| 404 | + |
359 | 405 | super().__init_subclass__(*args, **kwargs) |
360 | 406 |
|
361 | 407 | def __eq__(self, other: Any) -> bool: |
|
0 commit comments