Skip to content

Commit 10a4783

Browse files
authored
[mulitapi combiner] combine models in multiapi sdk (Azure#28718)
1 parent 9c908dd commit 10a4783

File tree

6 files changed

+166
-12
lines changed

6 files changed

+166
-12
lines changed

tools/azure-sdk-tools/packaging_tools/multiapi_combiner.py

Lines changed: 118 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import argparse
1212
from pathlib import Path
1313
import shutil
14-
from typing import Dict, Optional, List, Any, TypeVar, Callable
14+
from typing import Dict, Optional, List, Any, TypeVar, Callable, Set
1515

1616
from jinja2 import PackageLoader, Environment
1717

@@ -34,6 +34,10 @@ def modify_relative_imports(regex: str, file: str) -> str:
3434
return file.replace(original_str, new_str)
3535

3636

37+
def strip_version_from_docs(input: str) -> str:
38+
return re.sub(r".v20[^.]*", "", input)
39+
40+
3741
class VersionedObject:
3842
"""An object that can be added / removed in an api version"""
3943

@@ -77,6 +81,19 @@ def _combine_helper(
7781
return objs
7882

7983

84+
def _sort_models_helper(current: "ModelAndEnum", seen_model_names: Set[str]) -> List["ModelAndEnum"]:
85+
if current.name in seen_model_names:
86+
return []
87+
ancestors: List["ModelAndEnum"] = [current]
88+
for parent in current.parents:
89+
if parent.name in seen_model_names:
90+
continue
91+
ancestors = _sort_models_helper(parent, seen_model_names) + ancestors
92+
seen_model_names.add(parent.name)
93+
seen_model_names.add(current.name)
94+
return ancestors
95+
96+
8097
class Parameter(VersionedObject):
8198
def __init__(
8299
self,
@@ -105,7 +122,7 @@ def __init__(
105122
self._request_builder: Optional[str] = None
106123

107124
def source_code(self, async_mode: bool) -> str:
108-
return inspect.getsource(self._get_op(self.api_versions[-1], async_mode))
125+
return strip_version_from_docs(inspect.getsource(self._get_op(self.api_versions[-1], async_mode)))
109126

110127
@property
111128
def request_builder_name(self) -> Optional[str]:
@@ -256,6 +273,9 @@ def _get_operation(code_model: "CodeModel", name: str) -> Operation:
256273
get_names_by_api_version=_get_names_by_api_version,
257274
)
258275

276+
def doc(self, async_mode: bool) -> str:
277+
return strip_version_from_docs(self.generated_class(async_mode).__doc__)
278+
259279

260280
class Client:
261281
def __init__(self, code_model: "CodeModel") -> None:
@@ -279,6 +299,29 @@ def name(self) -> str:
279299
return list(self.code_model.api_version_to_metadata.values())[-1]["client"]["name"]
280300

281301

302+
class ModelAndEnum(VersionedObject):
303+
def __init__(self, code_model: "CodeModel", name: str) -> None:
304+
super().__init__(code_model, name)
305+
self._parents: List["ModelAndEnum"] = []
306+
307+
@property
308+
def generated_class(self):
309+
folder_api_version = self.code_model.api_version_to_folder_api_version[self.api_versions[-1]]
310+
module = importlib.import_module(f"{self.code_model.module_name}.{folder_api_version}.models")
311+
return getattr(module, self.name)
312+
313+
@property
314+
def source_code(self) -> str:
315+
return strip_version_from_docs(inspect.getsource(self.generated_class))
316+
317+
@property
318+
def parents(self) -> List["ModelAndEnum"]:
319+
if not self._parents:
320+
for parent in self.generated_class.__mro__[1 : len(self.generated_class.__mro__) - 2]:
321+
self._parents.append(self.code_model.models[parent.__name__])
322+
return self._parents
323+
324+
282325
class CodeModel:
283326
def __init__(self, pkg_path: Path):
284327
self._root_of_code = pkg_path
@@ -297,6 +340,9 @@ def __init__(self, pkg_path: Path):
297340
self.default_folder_api_version = self.api_version_to_folder_api_version[self.default_api_version]
298341
self.module_name = pkg_path.stem.replace("-", ".")
299342
self.operation_groups = self._combine_operation_groups()
343+
self.models: Dict[str, ModelAndEnum] = {}
344+
self.enums: List[ModelAndEnum] = []
345+
self._combine_models_and_enums()
300346
self.client = Client(self)
301347

302348
def get_root_of_code(self, async_mode: bool) -> Path:
@@ -343,6 +389,35 @@ def _get_operation_group(code_model: "CodeModel", name: str):
343389
operation.combine_parameters()
344390
return ogs
345391

392+
def _combine_models_and_enums(self) -> None:
393+
def _get_model(code_model: "CodeModel", name: str) -> ModelAndEnum:
394+
return ModelAndEnum(code_model, name)
395+
396+
def _get_names_by_api_version(api_version: str):
397+
folder_api_version = self.api_version_to_folder_api_version[api_version]
398+
module = importlib.import_module(f"{self.module_name}.{folder_api_version}.models")
399+
return [m for m in dir(module) if m[0] != "_"]
400+
401+
models_and_enums = _combine_helper(
402+
code_model=self,
403+
sorted_api_versions=self.sorted_api_versions,
404+
get_cls=_get_model,
405+
get_names_by_api_version=_get_names_by_api_version,
406+
)
407+
for m in models_and_enums:
408+
if hasattr(m.generated_class, "from_dict"):
409+
self.models[m.name] = m
410+
else:
411+
self.enums.append(m)
412+
self._sort_models()
413+
414+
def _sort_models(self) -> None:
415+
seen_model_names: Set[str] = set()
416+
sorted_models: Dict[str, ModelAndEnum] = {}
417+
for model in self.models.values():
418+
sorted_models.update({m.name: m for m in _sort_models_helper(model, seen_model_names)})
419+
self.models = sorted_models
420+
346421

347422
class Serializer:
348423
def __init__(self, code_model: "CodeModel") -> None:
@@ -486,7 +561,9 @@ def serialize_client(self, async_mode: bool):
486561

487562
main_client_source = "class" + "class".join(split_main_client_source[1:])
488563

489-
client_initialization = re.search(r"((?s).*?) @classmethod", main_client_source).group(1)
564+
client_initialization = strip_version_from_docs(
565+
re.search(r"((?s).*?) @classmethod", main_client_source).group(1)
566+
)
490567

491568
# TODO: switch to current file path
492569
with open(f"{self.code_model.get_root_of_code(async_mode)}/_client.py", "w") as fd:
@@ -532,12 +609,43 @@ def serialize_general(self):
532609
with open(f"{self.code_model.get_root_of_code(False)}/_validation.py", "w") as fd:
533610
fd.write(self.env.get_template("validation.py.jinja2").render())
534611

612+
def serialize_models_folder(self):
613+
# serialize init file
614+
models_folder = self.code_model.get_root_of_code(False) / "models"
615+
Path(models_folder).mkdir(parents=True, exist_ok=True)
616+
with open(f"{models_folder}/__init__.py", "w") as fd:
617+
fd.write(self.env.get_template("models_init.py.jinja2").render(code_model=self.code_model))
618+
default_api_version = self.code_model.default_folder_api_version
619+
default_models_folder_name = f"{self.code_model.module_name}.{default_api_version}.models"
620+
621+
# serialize models file
622+
default_models_module = importlib.import_module(f"{default_models_folder_name}._models_py3")
623+
imports = inspect.getsource(default_models_module).split("class")[0]
624+
imports = modify_relative_imports(r"from (.*) import _serialization", imports)
625+
with open(f"{models_folder}/_models.py", "w") as fd:
626+
fd.write(self.env.get_template("models.py.jinja2").render(code_model=self.code_model, imports=imports))
627+
628+
# serialize enums file
629+
default_enums_module = importlib.import_module(
630+
f"{default_models_folder_name}.{self.code_model.client.generated_filename}_enums"
631+
)
632+
imports = inspect.getsource(default_enums_module).split("class")[0]
633+
if self.code_model.enums:
634+
with open(f"{models_folder}/_enums.py", "w") as fd:
635+
fd.write(self.env.get_template("enums.py.jinja2").render(code_model=self.code_model, imports=imports))
636+
637+
# serialize patch file
638+
with open(f"{models_folder}/_patch.py", "w") as wfd:
639+
with open(f"{self.code_model.get_root_of_code(False)}/{default_api_version}/models/_patch.py", "r") as rfd:
640+
wfd.write(rfd.read())
641+
535642
def remove_versioned_files(self):
536643
root_of_code = self.code_model.get_root_of_code(False)
537644
for api_version_folder_stem in self.code_model.api_version_to_folder_api_version.values():
538645
api_version_folder = root_of_code / api_version_folder_stem
539646
shutil.rmtree(api_version_folder / Path("operations"), ignore_errors=True)
540647
shutil.rmtree(api_version_folder / Path("aio"), ignore_errors=True)
648+
shutil.rmtree(api_version_folder / Path("models"), ignore_errors=True)
541649
files_to_remove = [
542650
"__init__.py",
543651
"_configuration.py",
@@ -551,12 +659,13 @@ def remove_versioned_files(self):
551659
for file in files_to_remove:
552660
os.remove(f"{api_version_folder}/{file}")
553661

554-
# add empty init file so we can still see the models folder
555-
with open(f"{api_version_folder}/__init__.py", "w") as f:
556-
f.write("")
557-
558662
def remove_top_level_files(self, async_mode: bool):
559-
top_level_files = [self.code_model.client.generated_filename, "_operations_mixin"]
663+
top_level_files = [
664+
self.code_model.client.generated_filename,
665+
"_operations_mixin",
666+
]
667+
if not async_mode:
668+
top_level_files.append("models")
560669
for file in top_level_files:
561670
os.remove(f"{self.code_model.get_root_of_code(async_mode)}/{file}.py")
562671

@@ -571,8 +680,8 @@ def serialize(self):
571680
self.serialize_client(async_mode=False)
572681
self.serialize_client(async_mode=True)
573682
self.serialize_general()
683+
self.serialize_models_folder()
574684
self.remove_old_code()
575-
# self.serialize_models_file()
576685

577686

578687
def get_args() -> argparse.Namespace:

tools/azure-sdk-tools/packaging_tools/templates/multiapi_combiner/client.py.jinja2

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ from {{ ".." if async_mode else "." }}_validation import api_version_validation
1313

1414
{{ getsource(generated_client._models_dict) }}
1515

16-
{{ getsource(generated_client.models) }}
17-
1816
{% for operation_group in operation_group_properties %}
1917

2018
@property
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{{ imports }}
2+
3+
{% for enum in code_model.enums %}
4+
{{ enum.source_code }}
5+
{% endfor %}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{{ imports }}
2+
3+
{% for model in code_model.models.values() %}
4+
{{ model.source_code }}
5+
{% endfor %}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# coding=utf-8
2+
# --------------------------------------------------------------------------
3+
# Copyright (c) Microsoft Corporation. All rights reserved.
4+
# Licensed under the MIT License. See License.txt in the project root for license information.
5+
# Code generated by Microsoft (R) AutoRest Code Generator.
6+
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
7+
# --------------------------------------------------------------------------
8+
{% if code_model.models %}
9+
from ._models import (
10+
{% for model in code_model.models.keys() %}
11+
{{ model }},
12+
{% endfor %}
13+
)
14+
{% endif %}
15+
16+
{% if code_model.enums %}
17+
from ._enums import (
18+
{% for enum in code_model.enums %}
19+
{{ enum.name }},
20+
{% endfor %}
21+
)
22+
{% endif %}
23+
24+
from ._patch import __all__ as _patch_all
25+
from ._patch import * # pylint: disable=unused-wildcard-import
26+
from ._patch import patch_sdk as _patch_sdk
27+
28+
__all__ = [
29+
{% for model in code_model.models.keys() %}
30+
"{{ model }}",
31+
{% endfor %}
32+
{% for enum in code_model.enums %}
33+
"{{ enum.name }}",
34+
{% endfor %}
35+
]
36+
__all__.extend([p for p in _patch_all if p not in __all__])
37+
_patch_sdk()

tools/azure-sdk-tools/packaging_tools/templates/multiapi_combiner/operation_group.py.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class {{ operation_group.name }}{{ "(" + operation_group.name.replace("Operations", "") + "ABC)" if operation_group.is_mixin else "" }}:
77
{% if not operation_group.is_mixin %}
88
"""
9-
{{ operation_group.generated_class(async_mode).__doc__ }}
9+
{{ operation_group.doc(async_mode) }}
1010
"""
1111
models = _models
1212

0 commit comments

Comments
 (0)