From c7525e667975edef2952207d23ab08f1054ad711 Mon Sep 17 00:00:00 2001 From: luci-bytes Date: Wed, 26 Nov 2025 09:02:26 +0100 Subject: [PATCH] Fix formatting --- ariadne_codegen/client_generators/package.py | 12 +++++- ariadne_codegen/settings.py | 1 + ariadne_codegen/utils.py | 26 +++++++++++++ tests/test_utils.py | 40 ++++++++++++++++++++ 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/ariadne_codegen/client_generators/package.py b/ariadne_codegen/client_generators/package.py index 55df35e7..234b7adc 100644 --- a/ariadne_codegen/client_generators/package.py +++ b/ariadne_codegen/client_generators/package.py @@ -13,7 +13,12 @@ from ..exceptions import ParsingError from ..plugins.manager import PluginManager from ..settings import ClientSettings, CommentsStrategy -from ..utils import ast_to_str, process_name, str_to_pascal_case +from ..utils import ( + add_extra_to_base_model, + ast_to_str, + process_name, + str_to_pascal_case, +) from .arguments import ArgumentsGenerator from .client import ClientGenerator from .comments import get_comment @@ -85,6 +90,7 @@ def __init__( plugin_manager: Optional[PluginManager] = None, enable_custom_operations: bool = False, include_typename: bool = True, + ignore_extra_fields: bool = True, ) -> None: self.package_path = Path(target_path) / package_name @@ -135,6 +141,7 @@ def __init__( self.custom_scalars = custom_scalars if custom_scalars else {} self.plugin_manager = plugin_manager self.include_typename = include_typename + self.ignore_extra_fields = ignore_extra_fields self._result_types_files: Dict[str, ast.Module] = {} self._generated_files: List[str] = [] @@ -355,6 +362,8 @@ def _copy_files(self): ] for source_path in files_to_copy: code = self._add_comments_to_code(source_path.read_text(encoding="utf-8")) + if not self.ignore_extra_fields and source_path.name == "base_model.py": + code = add_extra_to_base_model(code) if self.plugin_manager: code = self.plugin_manager.copy_code(code) target_path = self.package_path / source_path.name @@ -538,4 +547,5 @@ def get_package_generator( plugin_manager=plugin_manager, enable_custom_operations=settings.enable_custom_operations, include_typename=settings.include_typename, + ignore_extra_fields=settings.ignore_extra_fields, ) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 7d5ad570..574c9ce6 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -74,6 +74,7 @@ class ClientSettings(BaseSettings): files_to_include: List[str] = field(default_factory=list) scalars: Dict[str, ScalarData] = field(default_factory=dict) include_typename: bool = True + ignore_extra_fields: bool = True def __post_init__(self): if not self.queries_path and not self.enable_custom_operations: diff --git a/ariadne_codegen/utils.py b/ariadne_codegen/utils.py index c21485b2..ff5dd347 100644 --- a/ariadne_codegen/utils.py +++ b/ariadne_codegen/utils.py @@ -138,3 +138,29 @@ def process_name( if set(name) == {"_"} and not processed_name: return "underscore_named_field_" return processed_name + + +def add_extra_to_base_model(code: str) -> str: + "Adds `extra='forbid'` to the ConfigDict in BaseModel if not already present." + tree = ast.parse(code) + for node in tree.body: + if not isinstance(node, ast.ClassDef): + continue + if node.name != "BaseModel": + continue + for statement in node.body: + if not isinstance(statement, ast.Assign): + continue + call = statement.value + if not isinstance(call, ast.Call): + continue + if not isinstance(call.func, ast.Name): + continue + if call.func.id != "ConfigDict": + continue + if not any(kw.arg == "extra" for kw in call.keywords): + call.keywords.append( + ast.keyword(arg="extra", value=ast.Constant("forbid")) + ) + ast.fix_missing_locations(tree) + return ast.unparse(tree) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3349973a..42a82331 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,7 @@ import pytest from ariadne_codegen.utils import ( + add_extra_to_base_model, ast_to_str, convert_to_multiline_string, format_multiline_strings, @@ -201,3 +202,42 @@ def test_process_name_returns_name_returned_from_plugin_for_name_with_only_under ) == "name_from_plugin" ) + + +def test_adds_extra_to_base_model_if_missing(): + code = dedent(""" + class BaseModel: + Config = ConfigDict() + """) + expected = dedent(""" + class BaseModel: + Config = ConfigDict(extra='forbid') + """) + result = add_extra_to_base_model(code) + assert dedent(result).strip() == expected.strip() + + +def test_adds_extra_to_base_model_does_not_overwrite_existing_extra(): + code = dedent(""" + class BaseModel: + Config = ConfigDict(extra='ignore') + """) + expected = dedent(""" + class BaseModel: + Config = ConfigDict(extra='ignore') + """) + result = add_extra_to_base_model(code) + assert dedent(result).strip() == expected.strip() + + +def test_adds_extra_to_base_model_leaves_other_classes_untouched(): + code = dedent(""" + class NotBaseModel: + Config = ConfigDict() + """) + expected = dedent(""" + class NotBaseModel: + Config = ConfigDict() + """) + result = add_extra_to_base_model(code) + assert dedent(result).strip() == expected.strip()