Skip to content

Commit 9773a57

Browse files
PaleNeutronJohn Lyu
andauthored
Add docstring to generated code based on schema description (#362)
* Add docstring to generated code based on schema description * add method description * add docstring for input field * add field docstring to custom_fields class * fix docstring of method * add tests * lint --------- Co-authored-by: John Lyu <lvjunhong@citics.com>
1 parent 33f1a57 commit 9773a57

File tree

8 files changed

+297
-9
lines changed

8 files changed

+297
-9
lines changed

ariadne_codegen/client_generators/custom_fields.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def _parse_object_type_definitions(
114114
class_def = self._generate_class_def_body(
115115
definition=graphql_type,
116116
class_name=f"{graphql_type.name}{self._get_suffix(graphql_type)}",
117+
description=graphql_type.description,
117118
)
118119
if isinstance(graphql_type, GraphQLInterfaceType):
119120
class_def.body.append(
@@ -129,17 +130,20 @@ def _generate_class_def_body(
129130
self,
130131
definition: Union[GraphQLObjectType, GraphQLInterfaceType],
131132
class_name: str,
133+
description: Optional[str] = None,
132134
) -> ast.ClassDef:
133135
"""
134136
Generates the body of a class definition for a given GraphQL object
135137
or interface type.
136138
"""
137139
base_names = [GRAPHQL_BASE_FIELD_CLASS]
138140
additional_fields_typing = set()
139-
class_def = generate_class_def(name=class_name, base_names=base_names)
140-
for lineno, (org_name, field) in enumerate(
141-
self._get_combined_fields(definition).items(), start=1
142-
):
141+
class_def = generate_class_def(
142+
name=class_name, base_names=base_names, description=description
143+
)
144+
lineno = 0
145+
for org_name, field in self._get_combined_fields(definition).items():
146+
lineno += 1
143147
name = process_name(
144148
org_name, convert_to_snake_case=self.convert_to_snake_case
145149
)
@@ -154,6 +158,11 @@ def _generate_class_def_body(
154158
name, field_name, org_name, field, method_required, lineno
155159
)
156160
)
161+
# Add field docstring for class attributes (not methods)
162+
if not getattr(field, "args") and field.description and not method_required:
163+
lineno += 1
164+
docstring = ast.Expr(value=ast.Constant(field.description))
165+
class_def.body.append(docstring)
157166

158167
class_def.body.append(
159168
self._generate_fields_method(
@@ -216,7 +225,11 @@ def _generate_class_field(
216225
"""Handles the generation of field types."""
217226
if getattr(field, "args") or method_required:
218227
return self.generate_product_type_method(
219-
name, field_name, org_name, getattr(field, "args")
228+
name,
229+
field_name,
230+
org_name,
231+
getattr(field, "args"),
232+
description=getattr(field, "description"),
220233
)
221234
return generate_ann_assign(
222235
target=generate_name(name),
@@ -316,6 +329,7 @@ def generate_product_type_method(
316329
class_name: str,
317330
org_name: str,
318331
arguments: Optional[Dict[str, Any]] = None,
332+
description: Optional[str] = None,
319333
) -> ast.FunctionDef:
320334
"""Generates a method for a product type."""
321335
arguments = arguments or {}
@@ -355,6 +369,7 @@ def generate_product_type_method(
355369
),
356370
return_type=generate_name(f'"{class_name}"'),
357371
decorator_list=[generate_name("classmethod")],
372+
description=description,
358373
)
359374

360375
def _get_suffix(

ariadne_codegen/client_generators/custom_operation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def generate(self) -> ast.Module:
8484
operation_name=name,
8585
operation_args=field.args,
8686
final_type=final_type,
87+
description=field.description,
8788
)
8889
method_def.lineno = len(self._class_def.body) + 1
8990
self._class_def.body.append(method_def)
@@ -115,6 +116,7 @@ def _generate_method(
115116
operation_name: str,
116117
operation_args,
117118
final_type,
119+
description: Optional[str] = None,
118120
) -> ast.FunctionDef:
119121
"""Generates a method definition for a given operation."""
120122
(
@@ -141,6 +143,7 @@ def _generate_method(
141143
name=str_to_snake_case(operation_name),
142144
arguments=method_arguments,
143145
return_type=generate_name(return_type_name),
146+
description=description,
144147
body=[
145148
*arguments_body,
146149
generate_return(

ariadne_codegen/client_generators/input_types.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,10 +157,13 @@ def _parse_input_definition(
157157
self, definition: GraphQLInputObjectType
158158
) -> ast.ClassDef:
159159
class_def = generate_class_def(
160-
name=definition.name, base_names=[BASE_MODEL_CLASS_NAME]
160+
name=definition.name,
161+
base_names=[BASE_MODEL_CLASS_NAME],
162+
description=definition.description,
161163
)
162-
163-
for lineno, (org_name, field) in enumerate(definition.fields.items(), start=1):
164+
lineno = 0
165+
for org_name, field in definition.fields.items():
166+
lineno += 1
164167
name = process_name(
165168
org_name,
166169
convert_to_snake_case=self.convert_to_snake_case,
@@ -190,6 +193,10 @@ def _parse_input_definition(
190193
field_implementation, input_field=field, field_name=org_name
191194
)
192195
class_def.body.append(field_implementation)
196+
if field.description:
197+
lineno += 1
198+
docstring = ast.Expr(value=ast.Constant(value=field.description))
199+
class_def.body.append(docstring)
193200
self._save_dependencies(root_type=definition.name, field_type=field_type)
194201

195202
if self.plugin_manager:

ariadne_codegen/codegen.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,21 @@ def generate_class_def(
111111
name: str,
112112
base_names: Optional[List[str]] = None,
113113
body: Optional[List[ast.stmt]] = None,
114+
description: str = "",
114115
) -> ast.ClassDef:
115116
"""Generate class definition."""
116117
bases = cast(
117118
List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else []
118119
)
120+
body = body if body else []
121+
if description:
122+
docstring = ast.Expr(value=ast.Constant(value=description))
123+
body.insert(0, docstring)
119124
params: Dict[str, Any] = {
120125
"name": name,
121126
"bases": bases,
122127
"keywords": [],
123-
"body": body if body else [],
128+
"body": body,
124129
"decorator_list": [],
125130
}
126131
if sys.version_info >= (3, 12):
@@ -354,10 +359,15 @@ def generate_method_definition(
354359
name: str,
355360
arguments: ast.arguments,
356361
return_type: Union[ast.Name, ast.Subscript],
362+
description: str = "",
357363
body: Optional[List[ast.stmt]] = None,
358364
lineno: int = 1,
359365
decorator_list: Optional[List[ast.expr]] = None,
360366
) -> ast.FunctionDef:
367+
body = body if body else [ast.Pass()]
368+
if description:
369+
docstring = ast.Expr(value=ast.Constant(value=description))
370+
body.insert(0, docstring)
361371
params: Dict[str, Any] = {
362372
"name": name,
363373
"args": arguments,
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import ast
2+
from unittest import mock
3+
4+
from ariadne_codegen.codegen import generate_arguments, generate_method_definition
5+
6+
7+
def test_generate_method_definition_with_minimal_parameters():
8+
"""Test generate_method_definition with only required parameters."""
9+
name = "test_method"
10+
arguments = generate_arguments()
11+
return_type = ast.Name(id="str")
12+
13+
result = generate_method_definition(name, arguments, return_type)
14+
15+
assert isinstance(result, ast.FunctionDef)
16+
assert result.name == name
17+
assert result.args == arguments
18+
assert result.returns == return_type
19+
assert result.lineno == 1
20+
assert len(result.body) == 1
21+
assert isinstance(result.body[0], ast.Pass)
22+
assert result.decorator_list == []
23+
24+
25+
def test_generate_method_definition_with_description():
26+
"""Test generate_method_definition with description adds docstring."""
27+
name = "test_method"
28+
arguments = generate_arguments()
29+
return_type = ast.Name(id="str")
30+
description = "This is a test method"
31+
32+
result = generate_method_definition(
33+
name, arguments, return_type, description=description
34+
)
35+
36+
assert len(result.body) == 2
37+
assert isinstance(result.body[0], ast.Expr)
38+
assert isinstance(result.body[0].value, ast.Constant)
39+
assert result.body[0].value.value == description
40+
assert isinstance(result.body[1], ast.Pass)
41+
42+
43+
def test_generate_method_definition_with_custom_body():
44+
"""Test generate_method_definition with custom body."""
45+
name = "test_method"
46+
arguments = generate_arguments()
47+
return_type = ast.Name(id="str")
48+
body = [ast.Return(value=ast.Constant(value="test"))]
49+
50+
result = generate_method_definition(name, arguments, return_type, body=body)
51+
52+
assert result.body == body
53+
assert len(result.body) == 1
54+
assert isinstance(result.body[0], ast.Return)
55+
56+
57+
def test_generate_method_definition_with_description_and_custom_body():
58+
"""Test generate_method_definition with both description and custom body."""
59+
name = "test_method"
60+
arguments = generate_arguments()
61+
return_type = ast.Name(id="str")
62+
description = "Test description"
63+
body = [ast.Return(value=ast.Constant(value="test"))]
64+
65+
result = generate_method_definition(
66+
name, arguments, return_type, description=description, body=body
67+
)
68+
69+
assert len(result.body) == 2
70+
assert isinstance(result.body[0], ast.Expr)
71+
assert result.body[0].value.value == description
72+
assert isinstance(result.body[1], ast.Return)
73+
74+
75+
def test_generate_method_definition_with_custom_lineno():
76+
"""Test generate_method_definition with custom line number."""
77+
name = "test_method"
78+
arguments = generate_arguments()
79+
return_type = ast.Name(id="str")
80+
lineno = 42
81+
82+
result = generate_method_definition(name, arguments, return_type, lineno=lineno)
83+
84+
assert result.lineno == lineno
85+
86+
87+
def test_generate_method_definition_with_decorators():
88+
"""Test generate_method_definition with decorator list."""
89+
name = "test_method"
90+
arguments = generate_arguments()
91+
return_type = ast.Name(id="str")
92+
decorators = [ast.Name(id="property"), ast.Name(id="staticmethod")]
93+
94+
result = generate_method_definition(
95+
name, arguments, return_type, decorator_list=decorators
96+
)
97+
98+
assert result.decorator_list == decorators
99+
100+
101+
def test_generate_method_definition_with_subscript_return_type():
102+
"""Test generate_method_definition with Subscript return type."""
103+
name = "test_method"
104+
arguments = generate_arguments()
105+
return_type = ast.Subscript(value=ast.Name(id="List"), slice=ast.Name(id="str"))
106+
107+
result = generate_method_definition(name, arguments, return_type)
108+
109+
assert result.returns == return_type
110+
assert isinstance(result.returns, ast.Subscript)
111+
112+
113+
@mock.patch("sys.version_info", (3, 12, 0))
114+
def test_generate_method_definition_python_312_or_later():
115+
"""Test generate_method_definition includes type_params for Python 3.12+."""
116+
name = "test_method"
117+
arguments = generate_arguments()
118+
return_type = ast.Name(id="str")
119+
120+
result = generate_method_definition(name, arguments, return_type)
121+
122+
assert hasattr(result, "type_params")
123+
assert result.type_params == []
124+
125+
126+
@mock.patch("sys.version_info", (3, 11, 0))
127+
def test_generate_method_definition_python_311_or_earlier():
128+
"""Test generate_method_definition doesn't include type_params for Python < 3.12."""
129+
name = "test_method"
130+
arguments = generate_arguments()
131+
return_type = ast.Name(id="str")
132+
133+
result = generate_method_definition(name, arguments, return_type)
134+
135+
# In Python < 3.12, type_params shouldn't be set
136+
# We check that the function still works correctly
137+
assert isinstance(result, ast.FunctionDef)
138+
assert result.name == name
139+
140+
141+
def test_generate_method_definition_all_parameters():
142+
"""Test generate_method_definition with all parameters specified."""
143+
name = "complex_method"
144+
arguments = generate_arguments()
145+
return_type = ast.Subscript(value=ast.Name(id="Optional"), slice=ast.Name(id="int"))
146+
description = "A complex method with all parameters"
147+
body = [
148+
ast.Assign(targets=[ast.Name(id="x")], value=ast.Constant(value=1)),
149+
ast.Return(value=ast.Name(id="x")),
150+
]
151+
lineno = 10
152+
decorators = [ast.Name(id="classmethod")]
153+
154+
result = generate_method_definition(
155+
name, arguments, return_type, description, body, lineno, decorators
156+
)
157+
158+
assert result.name == name
159+
assert result.args == arguments
160+
assert result.returns == return_type
161+
assert result.lineno == lineno
162+
assert result.decorator_list == decorators
163+
assert len(result.body) == 3 # docstring + 2 body statements
164+
assert isinstance(result.body[0], ast.Expr) # docstring
165+
assert result.body[0].value.value == description

tests/codegen/test_generated_classes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,17 @@ def test_generate_class_def_returns_class_def_without_base():
2424
assert isinstance(result, ast.ClassDef)
2525
assert result.name == name
2626
assert not result.bases
27+
28+
29+
def test_generate_class_def_with_description_adds_docstring():
30+
name = "Xyz"
31+
description = "This is a test class. \nWith multiple lines."
32+
33+
result = generate_class_def(name, description=description)
34+
35+
assert isinstance(result, ast.ClassDef)
36+
assert result.name == name
37+
docstring = result.body[0]
38+
assert isinstance(docstring, ast.Expr)
39+
assert isinstance(docstring.value, ast.Constant)
40+
assert docstring.value.value == description

0 commit comments

Comments
 (0)