Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion ariadne_codegen/client_generators/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from ..exceptions import ParsingError
from ..plugins.manager import PluginManager
from ..settings import ClientSettings, CommentsStrategy
from ..utils import ast_to_str, process_name, str_to_pascal_case
from ..utils import (
add_extra_to_base_model,
ast_to_str,
process_name,
str_to_pascal_case,
)
from .arguments import ArgumentsGenerator
from .client import ClientGenerator
from .comments import get_comment
Expand Down Expand Up @@ -85,6 +90,7 @@ def __init__(
plugin_manager: Optional[PluginManager] = None,
enable_custom_operations: bool = False,
include_typename: bool = True,
ignore_extra_fields: bool = True,
) -> None:
self.package_path = Path(target_path) / package_name

Expand Down Expand Up @@ -135,6 +141,7 @@ def __init__(
self.custom_scalars = custom_scalars if custom_scalars else {}
self.plugin_manager = plugin_manager
self.include_typename = include_typename
self.ignore_extra_fields = ignore_extra_fields

self._result_types_files: Dict[str, ast.Module] = {}
self._generated_files: List[str] = []
Expand Down Expand Up @@ -355,6 +362,8 @@ def _copy_files(self):
]
for source_path in files_to_copy:
code = self._add_comments_to_code(source_path.read_text(encoding="utf-8"))
if not self.ignore_extra_fields and source_path.name == "base_model.py":
code = add_extra_to_base_model(code)
if self.plugin_manager:
code = self.plugin_manager.copy_code(code)
target_path = self.package_path / source_path.name
Expand Down Expand Up @@ -538,4 +547,5 @@ def get_package_generator(
plugin_manager=plugin_manager,
enable_custom_operations=settings.enable_custom_operations,
include_typename=settings.include_typename,
ignore_extra_fields=settings.ignore_extra_fields,
)
1 change: 1 addition & 0 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class ClientSettings(BaseSettings):
files_to_include: List[str] = field(default_factory=list)
scalars: Dict[str, ScalarData] = field(default_factory=dict)
include_typename: bool = True
ignore_extra_fields: bool = True

def __post_init__(self):
if not self.queries_path and not self.enable_custom_operations:
Expand Down
26 changes: 26 additions & 0 deletions ariadne_codegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,29 @@ def process_name(
if set(name) == {"_"} and not processed_name:
return "underscore_named_field_"
return processed_name


def add_extra_to_base_model(code: str) -> str:
"Adds `extra='forbid'` to the ConfigDict in BaseModel if not already present."
tree = ast.parse(code)
for node in tree.body:
if not isinstance(node, ast.ClassDef):
continue
if node.name != "BaseModel":
continue
for statement in node.body:
if not isinstance(statement, ast.Assign):
continue
call = statement.value
if not isinstance(call, ast.Call):
continue
if not isinstance(call.func, ast.Name):
continue
if call.func.id != "ConfigDict":
continue
if not any(kw.arg == "extra" for kw in call.keywords):
call.keywords.append(
ast.keyword(arg="extra", value=ast.Constant("forbid"))
)
ast.fix_missing_locations(tree)
return ast.unparse(tree)
40 changes: 40 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

from ariadne_codegen.utils import (
add_extra_to_base_model,
ast_to_str,
convert_to_multiline_string,
format_multiline_strings,
Expand Down Expand Up @@ -201,3 +202,42 @@ def test_process_name_returns_name_returned_from_plugin_for_name_with_only_under
)
== "name_from_plugin"
)


def test_adds_extra_to_base_model_if_missing():
code = dedent("""
class BaseModel:
Config = ConfigDict()
""")
expected = dedent("""
class BaseModel:
Config = ConfigDict(extra='forbid')
""")
result = add_extra_to_base_model(code)
assert dedent(result).strip() == expected.strip()


def test_adds_extra_to_base_model_does_not_overwrite_existing_extra():
code = dedent("""
class BaseModel:
Config = ConfigDict(extra='ignore')
""")
expected = dedent("""
class BaseModel:
Config = ConfigDict(extra='ignore')
""")
result = add_extra_to_base_model(code)
assert dedent(result).strip() == expected.strip()


def test_adds_extra_to_base_model_leaves_other_classes_untouched():
code = dedent("""
class NotBaseModel:
Config = ConfigDict()
""")
expected = dedent("""
class NotBaseModel:
Config = ConfigDict()
""")
result = add_extra_to_base_model(code)
assert dedent(result).strip() == expected.strip()