|
53 | 53 |
|
54 | 54 | if IS_PYDANTIC_V2: |
55 | 55 | from pydantic import SerializationInfo, model_serializer, model_validator |
56 | | - from pydantic.config import JsonDict |
57 | 56 | from pydantic.json_schema import GenerateJsonSchema |
58 | 57 | else: |
59 | 58 | from pydantic.class_validators import root_validator |
@@ -155,6 +154,35 @@ def wrapper(self, *args, **kwargs): |
155 | 154 | if IS_PYDANTIC_V2: |
156 | 155 |
|
157 | 156 | class CustomGenerateJsonSchema(GenerateJsonSchema): |
| 157 | + """ |
| 158 | + Custom JSON schema generator which adds support for generating JSON schemas for `RelationshipProperty` fields |
| 159 | + and adds index and uniqueness constraint information to the generated schema. |
| 160 | + """ |
| 161 | + |
| 162 | + def generate(self, *args, **kwargs): |
| 163 | + model_cls = cast(Type[BaseModel], args[0]["schema"]["cls"]) |
| 164 | + generated_schema = super().generate(*args, **kwargs) |
| 165 | + |
| 166 | + for field_name, field in get_model_fields(model_cls).items(): |
| 167 | + point_index = getattr(get_field_type(field), "_point_index", False) |
| 168 | + range_index = getattr(get_field_type(field), "_range_index", False) |
| 169 | + text_index = getattr(get_field_type(field), "_text_index", False) |
| 170 | + unique = getattr(get_field_type(field), "_unique", False) |
| 171 | + |
| 172 | + if field_name not in generated_schema["properties"]: |
| 173 | + continue |
| 174 | + |
| 175 | + if point_index: |
| 176 | + generated_schema["properties"][field_name]["point_index"] = True |
| 177 | + if range_index: |
| 178 | + generated_schema["properties"][field_name]["range_index"] = True |
| 179 | + if text_index: |
| 180 | + generated_schema["properties"][field_name]["text_index"] = True |
| 181 | + if unique: |
| 182 | + generated_schema["properties"][field_name]["uniqueness_constraint"] = True |
| 183 | + |
| 184 | + return generated_schema |
| 185 | + |
158 | 186 | def encode_default(self, dft: Any) -> Any: |
159 | 187 | if isinstance(dft, RelationshipProperty): |
160 | 188 | dft = str(dft) |
@@ -222,30 +250,6 @@ def model_json_schema(cls, *args, **kwargs) -> Dict[str, Any]: |
222 | 250 | kwargs.setdefault("schema_generator", CustomGenerateJsonSchema) |
223 | 251 | return super().model_json_schema(*args, **kwargs) |
224 | 252 |
|
225 | | - # Pydantic does not initialize either `__fields__` or `model_fields` in the __init_subclass__ |
226 | | - # method anymore in V2, thus we have to call this logic here as well |
227 | | - @classmethod |
228 | | - def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: |
229 | | - super().__pydantic_init_subclass__(**kwargs) |
230 | | - |
231 | | - for _, field in get_model_fields(cls).items(): |
232 | | - point_index = getattr(get_field_type(field), "_point_index", False) |
233 | | - range_index = getattr(get_field_type(field), "_range_index", False) |
234 | | - text_index = getattr(get_field_type(field), "_text_index", False) |
235 | | - unique = getattr(get_field_type(field), "_unique", False) |
236 | | - |
237 | | - if field.json_schema_extra is None: |
238 | | - field.json_schema_extra = {} |
239 | | - |
240 | | - if point_index: |
241 | | - cast(JsonDict, field.json_schema_extra)["point_index"] = True |
242 | | - if range_index: |
243 | | - cast(JsonDict, field.json_schema_extra)["range_index"] = True |
244 | | - if text_index: |
245 | | - cast(JsonDict, field.json_schema_extra)["text_index"] = True |
246 | | - if unique: |
247 | | - cast(JsonDict, field.json_schema_extra)["uniqueness_constraint"] = True |
248 | | - |
249 | 253 | else: |
250 | 254 |
|
251 | 255 | @root_validator |
@@ -394,13 +398,13 @@ def __init_subclass__(cls, *args, **kwargs) -> None: |
394 | 398 | unique = getattr(get_field_type(field), "_unique", False) |
395 | 399 |
|
396 | 400 | if point_index: |
397 | | - field.field_info.extra["point_index"] = True |
| 401 | + field.field_info.extra["point_index"] = True # type: ignore |
398 | 402 | if range_index: |
399 | | - field.field_info.extra["range_index"] = True |
| 403 | + field.field_info.extra["range_index"] = True # type: ignore |
400 | 404 | if text_index: |
401 | | - field.field_info.extra["text_index"] = True |
| 405 | + field.field_info.extra["text_index"] = True # type: ignore |
402 | 406 | if unique: |
403 | | - field.field_info.extra["uniqueness_constraint"] = True |
| 407 | + field.field_info.extra["uniqueness_constraint"] = True # type: ignore |
404 | 408 |
|
405 | 409 | super().__init_subclass__(*args, **kwargs) |
406 | 410 |
|
|
0 commit comments