1111import argparse
1212from pathlib import Path
1313import shutil
14- from typing import Dict , Optional , List , Any , TypeVar , Callable
14+ from typing import Dict , Optional , List , Any , TypeVar , Callable , Set
1515
1616from jinja2 import PackageLoader , Environment
1717
@@ -34,6 +34,10 @@ def modify_relative_imports(regex: str, file: str) -> str:
3434 return file .replace (original_str , new_str )
3535
3636
37+ def strip_version_from_docs (input : str ) -> str :
38+ return re .sub (r".v20[^.]*" , "" , input )
39+
40+
3741class VersionedObject :
3842 """An object that can be added / removed in an api version"""
3943
@@ -77,6 +81,19 @@ def _combine_helper(
7781 return objs
7882
7983
84+ def _sort_models_helper (current : "ModelAndEnum" , seen_model_names : Set [str ]) -> List ["ModelAndEnum" ]:
85+ if current .name in seen_model_names :
86+ return []
87+ ancestors : List ["ModelAndEnum" ] = [current ]
88+ for parent in current .parents :
89+ if parent .name in seen_model_names :
90+ continue
91+ ancestors = _sort_models_helper (parent , seen_model_names ) + ancestors
92+ seen_model_names .add (parent .name )
93+ seen_model_names .add (current .name )
94+ return ancestors
95+
96+
8097class Parameter (VersionedObject ):
8198 def __init__ (
8299 self ,
@@ -105,7 +122,7 @@ def __init__(
105122 self ._request_builder : Optional [str ] = None
106123
107124 def source_code (self , async_mode : bool ) -> str :
108- return inspect .getsource (self ._get_op (self .api_versions [- 1 ], async_mode ))
125+ return strip_version_from_docs ( inspect .getsource (self ._get_op (self .api_versions [- 1 ], async_mode ) ))
109126
110127 @property
111128 def request_builder_name (self ) -> Optional [str ]:
@@ -256,6 +273,9 @@ def _get_operation(code_model: "CodeModel", name: str) -> Operation:
256273 get_names_by_api_version = _get_names_by_api_version ,
257274 )
258275
276+ def doc (self , async_mode : bool ) -> str :
277+ return strip_version_from_docs (self .generated_class (async_mode ).__doc__ )
278+
259279
260280class Client :
261281 def __init__ (self , code_model : "CodeModel" ) -> None :
@@ -279,6 +299,29 @@ def name(self) -> str:
279299 return list (self .code_model .api_version_to_metadata .values ())[- 1 ]["client" ]["name" ]
280300
281301
302+ class ModelAndEnum (VersionedObject ):
303+ def __init__ (self , code_model : "CodeModel" , name : str ) -> None :
304+ super ().__init__ (code_model , name )
305+ self ._parents : List ["ModelAndEnum" ] = []
306+
307+ @property
308+ def generated_class (self ):
309+ folder_api_version = self .code_model .api_version_to_folder_api_version [self .api_versions [- 1 ]]
310+ module = importlib .import_module (f"{ self .code_model .module_name } .{ folder_api_version } .models" )
311+ return getattr (module , self .name )
312+
313+ @property
314+ def source_code (self ) -> str :
315+ return strip_version_from_docs (inspect .getsource (self .generated_class ))
316+
317+ @property
318+ def parents (self ) -> List ["ModelAndEnum" ]:
319+ if not self ._parents :
320+ for parent in self .generated_class .__mro__ [1 : len (self .generated_class .__mro__ ) - 2 ]:
321+ self ._parents .append (self .code_model .models [parent .__name__ ])
322+ return self ._parents
323+
324+
282325class CodeModel :
283326 def __init__ (self , pkg_path : Path ):
284327 self ._root_of_code = pkg_path
@@ -297,6 +340,9 @@ def __init__(self, pkg_path: Path):
297340 self .default_folder_api_version = self .api_version_to_folder_api_version [self .default_api_version ]
298341 self .module_name = pkg_path .stem .replace ("-" , "." )
299342 self .operation_groups = self ._combine_operation_groups ()
343+ self .models : Dict [str , ModelAndEnum ] = {}
344+ self .enums : List [ModelAndEnum ] = []
345+ self ._combine_models_and_enums ()
300346 self .client = Client (self )
301347
302348 def get_root_of_code (self , async_mode : bool ) -> Path :
@@ -343,6 +389,35 @@ def _get_operation_group(code_model: "CodeModel", name: str):
343389 operation .combine_parameters ()
344390 return ogs
345391
392+ def _combine_models_and_enums (self ) -> None :
393+ def _get_model (code_model : "CodeModel" , name : str ) -> ModelAndEnum :
394+ return ModelAndEnum (code_model , name )
395+
396+ def _get_names_by_api_version (api_version : str ):
397+ folder_api_version = self .api_version_to_folder_api_version [api_version ]
398+ module = importlib .import_module (f"{ self .module_name } .{ folder_api_version } .models" )
399+ return [m for m in dir (module ) if m [0 ] != "_" ]
400+
401+ models_and_enums = _combine_helper (
402+ code_model = self ,
403+ sorted_api_versions = self .sorted_api_versions ,
404+ get_cls = _get_model ,
405+ get_names_by_api_version = _get_names_by_api_version ,
406+ )
407+ for m in models_and_enums :
408+ if hasattr (m .generated_class , "from_dict" ):
409+ self .models [m .name ] = m
410+ else :
411+ self .enums .append (m )
412+ self ._sort_models ()
413+
414+ def _sort_models (self ) -> None :
415+ seen_model_names : Set [str ] = set ()
416+ sorted_models : Dict [str , ModelAndEnum ] = {}
417+ for model in self .models .values ():
418+ sorted_models .update ({m .name : m for m in _sort_models_helper (model , seen_model_names )})
419+ self .models = sorted_models
420+
346421
347422class Serializer :
348423 def __init__ (self , code_model : "CodeModel" ) -> None :
@@ -486,7 +561,9 @@ def serialize_client(self, async_mode: bool):
486561
487562 main_client_source = "class" + "class" .join (split_main_client_source [1 :])
488563
489- client_initialization = re .search (r"((?s).*?) @classmethod" , main_client_source ).group (1 )
564+ client_initialization = strip_version_from_docs (
565+ re .search (r"((?s).*?) @classmethod" , main_client_source ).group (1 )
566+ )
490567
491568 # TODO: switch to current file path
492569 with open (f"{ self .code_model .get_root_of_code (async_mode )} /_client.py" , "w" ) as fd :
@@ -532,12 +609,43 @@ def serialize_general(self):
532609 with open (f"{ self .code_model .get_root_of_code (False )} /_validation.py" , "w" ) as fd :
533610 fd .write (self .env .get_template ("validation.py.jinja2" ).render ())
534611
612+ def serialize_models_folder (self ):
613+ # serialize init file
614+ models_folder = self .code_model .get_root_of_code (False ) / "models"
615+ Path (models_folder ).mkdir (parents = True , exist_ok = True )
616+ with open (f"{ models_folder } /__init__.py" , "w" ) as fd :
617+ fd .write (self .env .get_template ("models_init.py.jinja2" ).render (code_model = self .code_model ))
618+ default_api_version = self .code_model .default_folder_api_version
619+ default_models_folder_name = f"{ self .code_model .module_name } .{ default_api_version } .models"
620+
621+ # serialize models file
622+ default_models_module = importlib .import_module (f"{ default_models_folder_name } ._models_py3" )
623+ imports = inspect .getsource (default_models_module ).split ("class" )[0 ]
624+ imports = modify_relative_imports (r"from (.*) import _serialization" , imports )
625+ with open (f"{ models_folder } /_models.py" , "w" ) as fd :
626+ fd .write (self .env .get_template ("models.py.jinja2" ).render (code_model = self .code_model , imports = imports ))
627+
628+ # serialize enums file
629+ default_enums_module = importlib .import_module (
630+ f"{ default_models_folder_name } .{ self .code_model .client .generated_filename } _enums"
631+ )
632+ imports = inspect .getsource (default_enums_module ).split ("class" )[0 ]
633+ if self .code_model .enums :
634+ with open (f"{ models_folder } /_enums.py" , "w" ) as fd :
635+ fd .write (self .env .get_template ("enums.py.jinja2" ).render (code_model = self .code_model , imports = imports ))
636+
637+ # serialize patch file
638+ with open (f"{ models_folder } /_patch.py" , "w" ) as wfd :
639+ with open (f"{ self .code_model .get_root_of_code (False )} /{ default_api_version } /models/_patch.py" , "r" ) as rfd :
640+ wfd .write (rfd .read ())
641+
535642 def remove_versioned_files (self ):
536643 root_of_code = self .code_model .get_root_of_code (False )
537644 for api_version_folder_stem in self .code_model .api_version_to_folder_api_version .values ():
538645 api_version_folder = root_of_code / api_version_folder_stem
539646 shutil .rmtree (api_version_folder / Path ("operations" ), ignore_errors = True )
540647 shutil .rmtree (api_version_folder / Path ("aio" ), ignore_errors = True )
648+ shutil .rmtree (api_version_folder / Path ("models" ), ignore_errors = True )
541649 files_to_remove = [
542650 "__init__.py" ,
543651 "_configuration.py" ,
@@ -551,12 +659,13 @@ def remove_versioned_files(self):
551659 for file in files_to_remove :
552660 os .remove (f"{ api_version_folder } /{ file } " )
553661
554- # add empty init file so we can still see the models folder
555- with open (f"{ api_version_folder } /__init__.py" , "w" ) as f :
556- f .write ("" )
557-
558662 def remove_top_level_files (self , async_mode : bool ):
559- top_level_files = [self .code_model .client .generated_filename , "_operations_mixin" ]
663+ top_level_files = [
664+ self .code_model .client .generated_filename ,
665+ "_operations_mixin" ,
666+ ]
667+ if not async_mode :
668+ top_level_files .append ("models" )
560669 for file in top_level_files :
561670 os .remove (f"{ self .code_model .get_root_of_code (async_mode )} /{ file } .py" )
562671
@@ -571,8 +680,8 @@ def serialize(self):
571680 self .serialize_client (async_mode = False )
572681 self .serialize_client (async_mode = True )
573682 self .serialize_general ()
683+ self .serialize_models_folder ()
574684 self .remove_old_code ()
575- # self.serialize_models_file()
576685
577686
578687def get_args () -> argparse .Namespace :
0 commit comments