1- from typing import List , Tuple , Dict , Set , Any
1+ from collections import defaultdict
22from difflib import unified_diff
3+ from pathlib import Path
4+ from typing import List , Tuple , Dict , Iterator , Iterable
35
46import click
57from robot .api import get_model
1214 GlobalFormattingConfig
1315)
1416
17+ INCLUDE_EXT = ('.robot' , '.resource' )
18+
1519
1620class Robotidy :
1721 def __init__ (self ,
1822 transformers : List [Tuple [str , List ]],
19- transformers_config : Dict [ str , List ],
20- src : Set ,
23+ transformers_config : List [ Tuple [ str , List ] ],
24+ src : Tuple [ str , ...] ,
2125 overwrite : bool ,
2226 show_diff : bool ,
2327 formatting_config : GlobalFormattingConfig ,
2428 verbose : bool ,
2529 check : bool
2630 ):
27- self .sources = src
31+ self .sources = self . get_paths ( src )
2832 self .overwrite = overwrite
2933 self .show_diff = show_diff
3034 self .check = check
3135 self .verbose = verbose
3236 self .formatting_config = formatting_config
37+ transformers_config = self .convert_configure (transformers_config )
3338 self .transformers = load_transformers (transformers , transformers_config )
39+ for transformer in self .transformers :
40+ # inject global settings TODO: handle it better
41+ setattr (transformer , 'formatting_config' , self .formatting_config )
3442
3543 def transform_files (self ):
3644 changed_files = 0
@@ -39,13 +47,8 @@ def transform_files(self):
3947 if self .verbose :
4048 click .echo (f'Transforming { source } file' )
4149 model = get_model (source )
42- old_model = StatementLinesCollector (model )
43- for transformer in self .transformers :
44- # inject global settings TODO: handle it better
45- setattr (transformer , 'formatting_config' , self .formatting_config )
46- transformer .visit (model )
47- new_model = StatementLinesCollector (model )
48- if new_model != old_model :
50+ diff , old_model , new_model = self .transform (model )
51+ if diff :
4952 changed_files += 1
5053 self .output_diff (model .source , old_model , new_model )
5154 if not self .check :
@@ -59,6 +62,13 @@ def transform_files(self):
5962 return 0
6063 return 1
6164
65+ def transform (self , model ):
66+ old_model = StatementLinesCollector (model )
67+ for transformer in self .transformers :
68+ transformer .visit (model )
69+ new_model = StatementLinesCollector (model )
70+ return new_model != old_model , old_model , new_model
71+
6272 def save_model (self , model ):
6373 if self .overwrite :
6474 model .save ()
@@ -71,3 +81,32 @@ def output_diff(self, path: str, old_model: StatementLinesCollector, new_model:
7181 lines = list (unified_diff (old , new , fromfile = f'{ path } \t before' , tofile = f'{ path } \t after' ))
7282 colorized_output = decorate_diff_with_color (lines )
7383 click .echo (colorized_output .encode ('ascii' , 'ignore' ).decode ('ascii' ), color = True )
84+
85+ def get_paths (self , src : Tuple [str , ...]):
86+ sources = set ()
87+ for s in src :
88+ path = Path (s ).resolve ()
89+ if path .is_file ():
90+ sources .add (path )
91+ elif path .is_dir ():
92+ sources .update (self .iterate_dir (path .iterdir ()))
93+ elif s == '-' :
94+ sources .add (path )
95+
96+ return sources
97+
98+ def iterate_dir (self , paths : Iterable [Path ]) -> Iterator [Path ]:
99+ for path in paths :
100+ if path .is_file ():
101+ if path .suffix not in INCLUDE_EXT :
102+ continue
103+ yield path
104+ elif path .is_dir ():
105+ yield from self .iterate_dir (path .iterdir ())
106+
107+ @staticmethod
108+ def convert_configure (configure : List [Tuple [str , List ]]) -> Dict [str , List ]:
109+ config_map = defaultdict (list )
110+ for transformer , args in configure :
111+ config_map [transformer ].extend (args )
112+ return config_map
0 commit comments