Skip to content

Commit a13d74d

Browse files
Add include_typename configuration option for better GraphQL API compatibility (#379)
* Initial plan * Add include_typename configuration option to ClientSettings and ResultTypesGenerator Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Add comprehensive tests for include_typename functionality Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Fix discriminator field generation when include_typename=False Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Implement Optional typename approach for better API compatibility - Fix parsing error in test: change 'animals' to 'animal' to match schema - Always generate discriminator for union types (even when include_typename=False) - Make typename__ field Optional[Literal[...]] = None when include_typename=False - This allows models to work with APIs that don't return __typename while maintaining union discrimination capability Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Fix include_typename functionality - remove discriminator when __typename unavailable Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Fix discriminator generation to respect include_typename setting Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * Initial plan * Fix linting errors: line length and import sorting Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com> * chore: fix checks --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: jacksonpradolima <7774063+jacksonpradolima@users.noreply.github.com>
1 parent b48fa12 commit a13d74d

File tree

8 files changed

+440
-16
lines changed

8 files changed

+440
-16
lines changed

ariadne_codegen/client_generators/fragments.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def __init__(
2020
convert_to_snake_case: bool = True,
2121
custom_scalars: Optional[Dict[str, ScalarData]] = None,
2222
plugin_manager: Optional[PluginManager] = None,
23+
include_typename: bool = True,
2324
) -> None:
2425
self.schema = schema
2526
self.enums_module_name = enums_module_name
@@ -28,6 +29,7 @@ def __init__(
2829
self.convert_to_snake_case = convert_to_snake_case
2930
self.custom_scalars = custom_scalars
3031
self.plugin_manager = plugin_manager
32+
self.include_typename = include_typename
3133

3234
self._fragments_names = set(self.fragments_definitions.keys())
3335
self._generated_public_names: List[str] = []
@@ -52,6 +54,7 @@ def generate(self, exclude_names: Optional[Set[str]] = None) -> ast.Module:
5254
convert_to_snake_case=self.convert_to_snake_case,
5355
custom_scalars=self.custom_scalars,
5456
plugin_manager=self.plugin_manager,
57+
include_typename=self.include_typename,
5558
)
5659
imports.extend(generator.get_imports())
5760
class_defs = generator.get_classes()

ariadne_codegen/client_generators/package.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(
8484
custom_scalars: Optional[Dict[str, ScalarData]] = None,
8585
plugin_manager: Optional[PluginManager] = None,
8686
enable_custom_operations: bool = False,
87+
include_typename: bool = True,
8788
) -> None:
8889
self.package_path = Path(target_path) / package_name
8990

@@ -133,6 +134,7 @@ def __init__(
133134
)
134135
self.custom_scalars = custom_scalars if custom_scalars else {}
135136
self.plugin_manager = plugin_manager
137+
self.include_typename = include_typename
136138

137139
self._result_types_files: Dict[str, ast.Module] = {}
138140
self._generated_files: List[str] = []
@@ -199,6 +201,7 @@ def add_operation(self, definition: OperationDefinitionNode):
199201
convert_to_snake_case=self.convert_to_snake_case,
200202
custom_scalars=self.custom_scalars,
201203
plugin_manager=self.plugin_manager,
204+
include_typename=self.include_typename,
202205
)
203206
self._unpacked_fragments = self._unpacked_fragments.union(
204207
query_types_generator.get_unpacked_fragments()
@@ -454,6 +457,7 @@ def get_package_generator(
454457
convert_to_snake_case=settings.convert_to_snake_case,
455458
custom_scalars=settings.scalars,
456459
plugin_manager=plugin_manager,
460+
include_typename=settings.include_typename,
457461
)
458462
custom_fields_generator = CustomFieldsGenerator(
459463
schema=schema,
@@ -533,4 +537,5 @@ def get_package_generator(
533537
custom_scalars=settings.scalars,
534538
plugin_manager=plugin_manager,
535539
enable_custom_operations=settings.enable_custom_operations,
540+
include_typename=settings.include_typename,
536541
)

ariadne_codegen/client_generators/result_fields.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def parse_operation_field(
8282
typename_values: Optional[List[str]] = None,
8383
custom_scalars: Optional[Dict[str, ScalarData]] = None,
8484
fragments_definitions: Optional[Dict[str, FragmentDefinitionNode]] = None,
85+
include_typename: bool = True,
8586
) -> Tuple[Annotation, Optional[ast.Constant], FieldContext]:
8687
default_value: Optional[ast.Constant] = None
8788
context = FieldContext(
@@ -107,7 +108,7 @@ def parse_operation_field(
107108
)
108109
if isinstance(annotation, ast.Subscript):
109110
annotation.slice = annotate_nested_unions(
110-
cast(AnnotationSlice, annotation.slice)
111+
cast(AnnotationSlice, annotation.slice), include_typename
111112
)
112113
annotation, default_value = parse_directives(
113114
annotation=annotation, directives=directives if directives else tuple()
@@ -363,11 +364,13 @@ def get_fragments_on_subtype(
363364
return fragments
364365

365366

366-
def annotate_nested_unions(annotation: AnnotationSlice) -> AnnotationSlice:
367+
def annotate_nested_unions(
368+
annotation: AnnotationSlice, include_typename: bool = True
369+
) -> AnnotationSlice:
367370
if isinstance(annotation, ast.Tuple):
368371
return generate_tuple(
369372
[
370-
annotate_nested_unions(cast(AnnotationSlice, elt))
373+
annotate_nested_unions(cast(AnnotationSlice, elt), include_typename)
371374
for elt in annotation.elts
372375
]
373376
)
@@ -376,19 +379,25 @@ def annotate_nested_unions(annotation: AnnotationSlice) -> AnnotationSlice:
376379
return annotation
377380

378381
if isinstance(annotation.value, ast.Name) and annotation.value.id == UNION:
379-
return generate_subscript(
380-
value=generate_name(ANNOTATED),
381-
slice_=generate_tuple(
382-
[
383-
annotation,
384-
generate_pydantic_field(
385-
{DISCRIMINATOR_KEYWORD: generate_constant(TYPENAME_ALIAS)}
386-
),
387-
]
388-
),
389-
)
382+
if include_typename:
383+
return generate_subscript(
384+
value=generate_name(ANNOTATED),
385+
slice_=generate_tuple(
386+
[
387+
annotation,
388+
generate_pydantic_field(
389+
{DISCRIMINATOR_KEYWORD: generate_constant(TYPENAME_ALIAS)}
390+
),
391+
]
392+
),
393+
)
394+
else:
395+
# When include_typename=False, return the union without discriminator
396+
return annotation
390397

391-
annotation.slice = annotate_nested_unions(cast(AnnotationSlice, annotation.slice))
398+
annotation.slice = annotate_nested_unions(
399+
cast(AnnotationSlice, annotation.slice), include_typename
400+
)
392401
return annotation
393402

394403

ariadne_codegen/client_generators/result_types.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(
8585
convert_to_snake_case: bool = True,
8686
custom_scalars: Optional[Dict[str, ScalarData]] = None,
8787
plugin_manager: Optional[PluginManager] = None,
88+
include_typename: bool = True,
8889
) -> None:
8990
self.schema = schema
9091
self.operation_definition = operation_definition
@@ -99,6 +100,7 @@ def __init__(
99100
self.custom_scalars = custom_scalars if custom_scalars else {}
100101
self.convert_to_snake_case = convert_to_snake_case
101102
self.plugin_manager = plugin_manager
103+
self.include_typename = include_typename
102104

103105
self._imports: List[ast.ImportFrom] = [
104106
generate_import_from(
@@ -262,6 +264,7 @@ def _parse_type_definition(
262264
typename_values=typename_values,
263265
custom_scalars=self.custom_scalars,
264266
fragments_definitions=self.fragments_definitions,
267+
include_typename=self.include_typename,
265268
)
266269

267270
field_implementation = generate_ann_assign(
@@ -382,6 +385,10 @@ def _unpack_fragment(
382385
def _add_typename_field_to_selections(
383386
self, resolved_fields: List[FieldNode], selection_set: SelectionSetNode
384387
) -> Tuple[List[FieldNode], Tuple[SelectionNode, ...]]:
388+
if not self.include_typename:
389+
# Don't add __typename to fields or selections when include_typename=False
390+
return resolved_fields, selection_set.selections
391+
385392
field_names = {f.name.value for f in resolved_fields}
386393
if TYPENAME_FIELD_NAME not in field_names:
387394
typename_field = FieldNode(name=NameNode(value=TYPENAME_FIELD_NAME))
@@ -436,7 +443,7 @@ def _process_field_implementation(
436443
):
437444
keywords[ALIAS_KEYWORD] = generate_constant(field_schema_name)
438445

439-
if is_union(field_implementation.annotation):
446+
if is_union(field_implementation.annotation) and self.include_typename:
440447
keywords[DISCRIMINATOR_KEYWORD] = generate_constant(TYPENAME_ALIAS)
441448

442449
if keywords and isinstance(field_implementation.value, ast.Constant):

ariadne_codegen/settings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class ClientSettings(BaseSettings):
7373
opentelemetry_client: bool = False
7474
files_to_include: List[str] = field(default_factory=list)
7575
scalars: Dict[str, ScalarData] = field(default_factory=dict)
76+
include_typename: bool = True
7677

7778
def __post_init__(self):
7879
if not self.queries_path and not self.enable_custom_operations:
@@ -167,6 +168,11 @@ def used_settings_message(self) -> str:
167168
if self.plugins
168169
else "No plugin is being used."
169170
)
171+
include_typename_msg = (
172+
"Including __typename fields in generated queries."
173+
if self.include_typename
174+
else "Not including __typename fields in generated queries."
175+
)
170176
return dedent(
171177
f"""\
172178
Selected strategy: {Strategy.CLIENT}
@@ -183,6 +189,7 @@ def used_settings_message(self) -> str:
183189
Comments type: {self.include_comments.value}
184190
{snake_case_msg}
185191
{async_client_msg}
192+
{include_typename_msg}
186193
{files_to_include_msg}
187194
{plugins_msg}
188195
"""

0 commit comments

Comments
 (0)