Skip to content

Commit 4101345

Browse files
authored
Merge pull request #91 from ponytailer/swagger-cli
Add swagger cli
2 parents 9f9af13 + 7ade4b3 commit 4101345

File tree

8 files changed

+986
-6
lines changed

8 files changed

+986
-6
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
strategy:
1313
matrix:
1414
os: [ubuntu-latest, windows-latest, macos-latest]
15-
python-version: ["3.8", "3.9", "3.10", "3.11"]
15+
python-version: ["3.9", "3.10", "3.11", "3.12"]
1616
fail-fast: true
1717

1818
steps:

poetry.lock

Lines changed: 585 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pydantic_client/cli.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import logging
2+
from collections import defaultdict
3+
from typing import Optional
4+
5+
from openapi_spec_validator import validate
6+
from openapi_spec_validator.readers import read_from_filename
7+
from pydantic import BaseModel
8+
from typer import Typer
9+
10+
all_models = {}
11+
12+
logger = logging.getLogger(__name__)
13+
14+
cli_app = Typer()
15+
16+
17+
class RequestEntity(BaseModel):
18+
method: str
19+
function_name: str
20+
path: str
21+
response: str
22+
arg_dict: dict
23+
args: Optional[str] = ""
24+
25+
26+
def generate_pydantic_model(schema, model_name):
27+
properties = schema.get('properties', {})
28+
required_fields = schema.get('required', [])
29+
30+
simple_fields = {}
31+
for field_name, field_schema in properties.items():
32+
field_type = field_schema.get('type')
33+
if field_type == 'object':
34+
ref_object = field_schema["$ref"].rsplit("/", 1)[-1]
35+
simple_fields[field_name] = ref_object
36+
elif field_type == 'array':
37+
items_schema = field_schema.get('items', {})
38+
items_type = items_schema.get('type')
39+
if items_type == 'object':
40+
ref_object = field_schema["$ref"].rsplit("/", 1)[-1]
41+
simple_fields[field_name] = list[ref_object]
42+
else:
43+
simple_fields[field_name] = list[
44+
map_type_to_string(items_type)]
45+
else:
46+
simple_fields[field_name] = map_type_to_string(
47+
field_type)
48+
49+
if field_name not in required_fields:
50+
simple_fields[field_name] = f"Optional[{simple_fields[field_name]}]"
51+
52+
if simple_fields:
53+
all_models[model_name] = simple_fields
54+
55+
56+
def map_type_to_string(openapi_type):
57+
type_mapping = {
58+
'string': "str",
59+
'integer': "int",
60+
'number': "float",
61+
'boolean': "bool",
62+
'array': list,
63+
'object': dict
64+
}
65+
return type_mapping.get(openapi_type, str) # 默认返回 str
66+
67+
68+
def success_return(status, response):
69+
return status == "200" or status == "201" and "content" in response
70+
71+
72+
def convert_request_name(name: str):
73+
s = []
74+
for k in name:
75+
if k.isupper():
76+
s.append("_")
77+
s.append(k.lower())
78+
else:
79+
s.append(k)
80+
return "".join(s[1:]) if s[0] == "_" else "".join(s)
81+
82+
83+
def parse_swagger_and_generate_models(swagger_path):
84+
swagger_dict, _ = read_from_filename(swagger_path)
85+
validate(swagger_dict)
86+
87+
definitions = swagger_dict.get('components', {}).get('schemas', {})
88+
paths = swagger_dict.get('paths', {})
89+
90+
for model_name, model_schema in definitions.items():
91+
generate_pydantic_model(model_schema, model_name)
92+
93+
path_pairs = defaultdict(list)
94+
95+
for path, path_item in paths.items():
96+
function_name_suffix = "_".join(
97+
p for p in path.split("/") if "{" not in p)
98+
99+
for method, operation in path_item.items():
100+
request_entity_body = {}
101+
arg_dict = {}
102+
103+
# 处理路径参数和查询参数
104+
if "parameters" in operation:
105+
for parameter in operation["parameters"]:
106+
if parameter["in"] == "path":
107+
for _, value_ in parameter["schema"].items():
108+
value_string = map_type_to_string(value_)
109+
arg_dict[parameter["name"]] = value_string
110+
break
111+
elif parameter["in"] == "query":
112+
for _, value_ in parameter["schema"].items():
113+
value_string = map_type_to_string(value_)
114+
arg_dict[
115+
parameter["name"]] = f"Optional[{value_string}]"
116+
break
117+
else:
118+
logger.warning(
119+
f"unknown parameter in {path}: {parameter}")
120+
121+
# 处理请求体
122+
if "requestBody" in operation:
123+
request_body = operation['requestBody']
124+
for _, media_type in request_body.get("content", {}).items():
125+
schema = media_type.get("schema", {})
126+
if not schema:
127+
continue
128+
request_model_name = schema.get("$ref", "").split('/')[-1]
129+
arg_dict[convert_request_name(
130+
request_model_name)] = request_model_name
131+
break
132+
133+
# 处理响应体
134+
if 'responses' in operation:
135+
for status_code, response in operation['responses'].items():
136+
if not success_return(status_code, response):
137+
continue
138+
for _, media_type in response["content"].items():
139+
if "schema" not in media_type:
140+
continue
141+
schema = media_type['schema']
142+
# 提取响应模型名称
143+
response_model_name = \
144+
schema.get('$ref', '').split('/')[-1]
145+
request_entity_body["response"] = response_model_name
146+
147+
request_entity_body["method"] = method
148+
request_entity_body[
149+
"function_name"] = f"{method}{function_name_suffix}"
150+
request_entity_body["arg_dict"] = arg_dict
151+
path_pairs[path].append(request_entity_body)
152+
153+
ret = []
154+
for path, values in path_pairs.items():
155+
for value in values:
156+
value["path"] = path
157+
value["function_name"] = value["function_name"].replace("-", "_")
158+
ret.append(RequestEntity(**value))
159+
return ret
160+
161+
162+
@cli_app.command()
163+
def parse(path: str, model_file_name: str = ""):
164+
import jinja2
165+
results = parse_swagger_and_generate_models(path)
166+
tmpl = jinja2.Template(open("pydantic_client/models.template").read())
167+
render_fields = {
168+
key: [
169+
f"{field}: {ftype}"
170+
for field, ftype in value.items()
171+
] for key, value in all_models.items()
172+
}
173+
for entity in results:
174+
args = ", ".join([
175+
f"{name}: {type_}"
176+
for name, type_ in entity.arg_dict.items()
177+
])
178+
entity.args = args
179+
180+
if render_fields:
181+
model_string = tmpl.render(models=render_fields, info=results)
182+
if model_file_name:
183+
with open(model_file_name, "w") as f:
184+
f.write(model_string)
185+
else:
186+
print(model_string)
187+
188+
189+
if __name__ == "__main__":
190+
cli_app()

pydantic_client/models.template

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Optional
2+
from pydantic import BaseModel
3+
{% for model_name, fields in models.items() %}
4+
5+
class {{model_name}}(BaseModel):
6+
{% for field in fields -%}
7+
{{ field }}
8+
{% endfor %}
9+
{% endfor %}
10+
11+
class WebClient:
12+
{% for entity in info %}
13+
@{{entity.method}}("{{entity.path}}")
14+
def {{entity.function_name}}(self{%if entity.args%}, {%endif%}{{entity.args}}):
15+
...
16+
{% endfor %}

pyproject.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pydantic-client"
3-
version = "1.0.6"
3+
version = "1.0.7"
44
description = "Http client base pydantic, with requests or aiohttp"
55
authors = ["ponytailer <huangxiaohen2738@gmail.com>"]
66
readme = "README.md"
@@ -13,19 +13,25 @@ url = "https://pypi.tuna.tsinghua.edu.cn/simple"
1313

1414

1515
[tool.poetry.dependencies]
16-
python = "^3.8"
16+
python = "^3.9"
1717

1818
pydantic = ">=2.1"
1919
requests = "*"
2020

2121
aiohttp = { version = "*", optional = true }
2222
httpx = { version = "*", extras = ["http2"], optional = true }
23+
openapi-spec-validator = "^0.7.1"
24+
jinja2 = "^3.1.5"
25+
typer = "^0.15.1"
2326

2427
[tool.poetry.extras]
2528
httpx = ["httpx"]
2629
aiohttp = ["aiohttp"]
2730
all = ["httpx", "aiohttp"]
2831

32+
[tool.poetry.scripts]
33+
pydantic-client = "pydantic_client.cli:cli_app"
34+
2935
[build-system]
3036
requires = ["poetry-core"]
3137
build-backend = "poetry.core.masonry.api"

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ pydantic==2.5.2
1717
python-multipart==0.0.9
1818
requests==2.31.0
1919
uvicorn==0.29.0
20-
toml
20+
toml
21+
typer
22+
jinja2
23+
openapi-spec-validator

0 commit comments

Comments
 (0)