|
| 1 | +import logging |
| 2 | +from collections import defaultdict |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +from openapi_spec_validator import validate |
| 6 | +from openapi_spec_validator.readers import read_from_filename |
| 7 | +from pydantic import BaseModel |
| 8 | +from typer import Typer |
| 9 | + |
| 10 | +all_models = {} |
| 11 | + |
| 12 | +logger = logging.getLogger(__name__) |
| 13 | + |
| 14 | +cli_app = Typer() |
| 15 | + |
| 16 | + |
| 17 | +class RequestEntity(BaseModel): |
| 18 | + method: str |
| 19 | + function_name: str |
| 20 | + path: str |
| 21 | + response: str |
| 22 | + arg_dict: dict |
| 23 | + args: Optional[str] = "" |
| 24 | + |
| 25 | + |
| 26 | +def generate_pydantic_model(schema, model_name): |
| 27 | + properties = schema.get('properties', {}) |
| 28 | + required_fields = schema.get('required', []) |
| 29 | + |
| 30 | + simple_fields = {} |
| 31 | + for field_name, field_schema in properties.items(): |
| 32 | + field_type = field_schema.get('type') |
| 33 | + if field_type == 'object': |
| 34 | + ref_object = field_schema["$ref"].rsplit("/", 1)[-1] |
| 35 | + simple_fields[field_name] = ref_object |
| 36 | + elif field_type == 'array': |
| 37 | + items_schema = field_schema.get('items', {}) |
| 38 | + items_type = items_schema.get('type') |
| 39 | + if items_type == 'object': |
| 40 | + ref_object = field_schema["$ref"].rsplit("/", 1)[-1] |
| 41 | + simple_fields[field_name] = list[ref_object] |
| 42 | + else: |
| 43 | + simple_fields[field_name] = list[ |
| 44 | + map_type_to_string(items_type)] |
| 45 | + else: |
| 46 | + simple_fields[field_name] = map_type_to_string( |
| 47 | + field_type) |
| 48 | + |
| 49 | + if field_name not in required_fields: |
| 50 | + simple_fields[field_name] = f"Optional[{simple_fields[field_name]}]" |
| 51 | + |
| 52 | + if simple_fields: |
| 53 | + all_models[model_name] = simple_fields |
| 54 | + |
| 55 | + |
| 56 | +def map_type_to_string(openapi_type): |
| 57 | + type_mapping = { |
| 58 | + 'string': "str", |
| 59 | + 'integer': "int", |
| 60 | + 'number': "float", |
| 61 | + 'boolean': "bool", |
| 62 | + 'array': list, |
| 63 | + 'object': dict |
| 64 | + } |
| 65 | + return type_mapping.get(openapi_type, str) # 默认返回 str |
| 66 | + |
| 67 | + |
| 68 | +def success_return(status, response): |
| 69 | + return status == "200" or status == "201" and "content" in response |
| 70 | + |
| 71 | + |
| 72 | +def convert_request_name(name: str): |
| 73 | + s = [] |
| 74 | + for k in name: |
| 75 | + if k.isupper(): |
| 76 | + s.append("_") |
| 77 | + s.append(k.lower()) |
| 78 | + else: |
| 79 | + s.append(k) |
| 80 | + return "".join(s[1:]) if s[0] == "_" else "".join(s) |
| 81 | + |
| 82 | + |
| 83 | +def parse_swagger_and_generate_models(swagger_path): |
| 84 | + swagger_dict, _ = read_from_filename(swagger_path) |
| 85 | + validate(swagger_dict) |
| 86 | + |
| 87 | + definitions = swagger_dict.get('components', {}).get('schemas', {}) |
| 88 | + paths = swagger_dict.get('paths', {}) |
| 89 | + |
| 90 | + for model_name, model_schema in definitions.items(): |
| 91 | + generate_pydantic_model(model_schema, model_name) |
| 92 | + |
| 93 | + path_pairs = defaultdict(list) |
| 94 | + |
| 95 | + for path, path_item in paths.items(): |
| 96 | + function_name_suffix = "_".join( |
| 97 | + p for p in path.split("/") if "{" not in p) |
| 98 | + |
| 99 | + for method, operation in path_item.items(): |
| 100 | + request_entity_body = {} |
| 101 | + arg_dict = {} |
| 102 | + |
| 103 | + # 处理路径参数和查询参数 |
| 104 | + if "parameters" in operation: |
| 105 | + for parameter in operation["parameters"]: |
| 106 | + if parameter["in"] == "path": |
| 107 | + for _, value_ in parameter["schema"].items(): |
| 108 | + value_string = map_type_to_string(value_) |
| 109 | + arg_dict[parameter["name"]] = value_string |
| 110 | + break |
| 111 | + elif parameter["in"] == "query": |
| 112 | + for _, value_ in parameter["schema"].items(): |
| 113 | + value_string = map_type_to_string(value_) |
| 114 | + arg_dict[ |
| 115 | + parameter["name"]] = f"Optional[{value_string}]" |
| 116 | + break |
| 117 | + else: |
| 118 | + logger.warning( |
| 119 | + f"unknown parameter in {path}: {parameter}") |
| 120 | + |
| 121 | + # 处理请求体 |
| 122 | + if "requestBody" in operation: |
| 123 | + request_body = operation['requestBody'] |
| 124 | + for _, media_type in request_body.get("content", {}).items(): |
| 125 | + schema = media_type.get("schema", {}) |
| 126 | + if not schema: |
| 127 | + continue |
| 128 | + request_model_name = schema.get("$ref", "").split('/')[-1] |
| 129 | + arg_dict[convert_request_name( |
| 130 | + request_model_name)] = request_model_name |
| 131 | + break |
| 132 | + |
| 133 | + # 处理响应体 |
| 134 | + if 'responses' in operation: |
| 135 | + for status_code, response in operation['responses'].items(): |
| 136 | + if not success_return(status_code, response): |
| 137 | + continue |
| 138 | + for _, media_type in response["content"].items(): |
| 139 | + if "schema" not in media_type: |
| 140 | + continue |
| 141 | + schema = media_type['schema'] |
| 142 | + # 提取响应模型名称 |
| 143 | + response_model_name = \ |
| 144 | + schema.get('$ref', '').split('/')[-1] |
| 145 | + request_entity_body["response"] = response_model_name |
| 146 | + |
| 147 | + request_entity_body["method"] = method |
| 148 | + request_entity_body[ |
| 149 | + "function_name"] = f"{method}{function_name_suffix}" |
| 150 | + request_entity_body["arg_dict"] = arg_dict |
| 151 | + path_pairs[path].append(request_entity_body) |
| 152 | + |
| 153 | + ret = [] |
| 154 | + for path, values in path_pairs.items(): |
| 155 | + for value in values: |
| 156 | + value["path"] = path |
| 157 | + value["function_name"] = value["function_name"].replace("-", "_") |
| 158 | + ret.append(RequestEntity(**value)) |
| 159 | + return ret |
| 160 | + |
| 161 | + |
| 162 | +@cli_app.command() |
| 163 | +def parse(path: str, model_file_name: str = ""): |
| 164 | + import jinja2 |
| 165 | + results = parse_swagger_and_generate_models(path) |
| 166 | + tmpl = jinja2.Template(open("pydantic_client/models.template").read()) |
| 167 | + render_fields = { |
| 168 | + key: [ |
| 169 | + f"{field}: {ftype}" |
| 170 | + for field, ftype in value.items() |
| 171 | + ] for key, value in all_models.items() |
| 172 | + } |
| 173 | + for entity in results: |
| 174 | + args = ", ".join([ |
| 175 | + f"{name}: {type_}" |
| 176 | + for name, type_ in entity.arg_dict.items() |
| 177 | + ]) |
| 178 | + entity.args = args |
| 179 | + |
| 180 | + if render_fields: |
| 181 | + model_string = tmpl.render(models=render_fields, info=results) |
| 182 | + if model_file_name: |
| 183 | + with open(model_file_name, "w") as f: |
| 184 | + f.write(model_string) |
| 185 | + else: |
| 186 | + print(model_string) |
| 187 | + |
| 188 | + |
| 189 | +if __name__ == "__main__": |
| 190 | + cli_app() |
0 commit comments