From e8347c2a758dc73108d4071b8a45b1b54da4ee9b Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Wed, 8 Jan 2025 18:45:02 +0000 Subject: [PATCH] Avoid generating empty ImportFrom statements - When emitting a `ImportFrom()` node, assert that the list of names is not empty, and address all places where this potentially could happen. - Extend the `input_types_generator/test_filtering_names` test to include at least 1 enum so that it would trigger the assertion error if that enum is not included in the used inputs. --- ariadne_codegen/client_generators/client.py | 26 ++++++++++--------- .../client_generators/input_types.py | 4 +-- ariadne_codegen/codegen.py | 1 + .../test_filtering_names.py | 7 +++++ 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 4c0831a5..8977a050 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -111,20 +111,22 @@ def __init__( def generate(self) -> ast.Module: """Generate module with class definition of graphql client.""" - self._add_import( - generate_import_from( - names=self.arguments_generator.get_used_inputs(), - from_=self.input_types_module_name, - level=1, + if used_inputs := self.arguments_generator.get_used_inputs(): + self._add_import( + generate_import_from( + names=used_inputs, + from_=self.input_types_module_name, + level=1, + ) ) - ) - self._add_import( - generate_import_from( - names=self.arguments_generator.get_used_enums(), - from_=self.enums_module_name, - level=1, + if used_enums := self.arguments_generator.get_used_enums(): + self._add_import( + generate_import_from( + names=used_enums, + from_=self.enums_module_name, + level=1, + ) ) - ) for custom_scalar_name in self.arguments_generator.get_used_custom_scalars(): scalar_data = self.custom_scalars[custom_scalar_name] for import_ in generate_scalar_imports(scalar_data): diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 0a5797d3..ec609315 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -81,9 +81,9 @@ def generate(self, types_to_include: Optional[List[str]] = None) -> ast.Module: class_defs = self._filter_class_defs(types_to_include=types_to_include) self._generated_public_names = [class_def.name for class_def in class_defs] - if self._used_enums: + if used_imports := self.get_used_enums(): self._imports.append( - generate_import_from(self.get_used_enums(), self.enums_module, 1) + generate_import_from(used_imports, self.enums_module, 1) ) for scalar_name in self._used_scalars: diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index 6f964ea1..b613789c 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -29,6 +29,7 @@ def generate_import_from( names: List[str], from_: Optional[str] = None, level: int = 0 ) -> ast.ImportFrom: """Generate import from statement.""" + assert names, "Using ImportFrom with no names would produce invalid Python code" return ast.ImportFrom( module=from_, names=[ast.alias(n) for n in names], level=level ) diff --git a/tests/client_generators/input_types_generator/test_filtering_names.py b/tests/client_generators/input_types_generator/test_filtering_names.py index 14106564..a38caa20 100644 --- a/tests/client_generators/input_types_generator/test_filtering_names.py +++ b/tests/client_generators/input_types_generator/test_filtering_names.py @@ -30,9 +30,15 @@ ) def test_generator_returns_module_with_filtered_classes(used_types, expected_classes): schema_str = """ + enum EnumA { + VALUE1 + VALUE2 + } + input InputA { valueAA: InputAA! valueAB: InputAB + valueEnumA: EnumA } input InputAA { @@ -46,6 +52,7 @@ def test_generator_returns_module_with_filtered_classes(used_types, expected_cla input InputAB { val: String! valueA: InputA + valueEnumA: EnumA } input InputX {