1111 FindNodes , nodes as ir
1212)
1313from loki .transform import Transformation
14+ from loki import (
15+ FindVariables , DerivedType , SymbolAttributes ,
16+ Array , single_variable_declaration , Transformer
17+ )
1418
1519__all__ = ['ParallelRoutineDispatchTransformation' ]
1620
1721
1822class ParallelRoutineDispatchTransformation (Transformation ):
1923
24+ def __init__ (self ):
25+ self .horizontal = [
26+ "KLON" , "YDCPG_OPTS%KLON" , "YDGEOMETRY%YRDIM%NPROMA" ,
27+ "KPROMA" , "YDDIM%NPROMA" , "NPROMA"
28+ ]
29+ # CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
30+ self .new_calls = []
31+ # IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
32+ self .delete_calls = []
33+ self .routine_map_temp = {}
34+
2035 def transform_subroutine (self , routine , ** kwargs ):
2136 with pragma_regions_attached (routine ):
2237 for region in FindNodes (ir .PragmaRegion ).visit (routine .body ):
2338 if is_loki_pragma (region .pragma ):
2439 self .process_parallel_region (routine , region )
40+ single_variable_declaration (routine )
41+ self .add_temp (routine )
42+ self .add_field (routine )
43+ #call add_arrays etc...
2544
2645 def process_parallel_region (self , routine , region ):
2746 pragma_content = region .pragma .content .split (maxsplit = 1 )
@@ -41,6 +60,14 @@ def process_parallel_region(self, routine, region):
4160 region .prepend (dr_hook_calls [0 ])
4261 region .append (dr_hook_calls [1 ])
4362
63+ region_map_temp = self .decl_local_array (routine , region )
64+
65+ for var_name in region_map_temp :
66+ if var_name not in self .routine_map_temp :
67+ self .routine_map_temp [var_name ]= region_map_temp [var_name ]
68+
69+
70+
4471 @staticmethod
4572 def create_dr_hook_calls (scope , cdname , handle ):
4673 dr_hook_calls = []
@@ -56,3 +83,99 @@ def create_dr_hook_calls(scope, cdname, handle):
5683 )
5784 ]
5885 return dr_hook_calls
86+
87+ # @staticmethod
88+ # def get_local_array(routine):
89+
90+ def create_field_new_delete (self , routine , var , field_ptr_var ):
91+ # Create the FIELD_NEW call
92+ var_shape = var .shape
93+ ubounds = [d .upper if isinstance (d , sym .RangeIndex ) else d for d in var_shape ]
94+ ubounds += [sym .Variable (name = "KGPBLKS" , parent = routine .variable_map ["YDCPG_OPTS" ])]
95+ # NB: This is presumably what the condition should look like, however, the
96+ # logic is flawed in to_parallel and it will only insert lbounds if _the last_
97+ # dimension has an lbound. We emulate this with the second line here to
98+ # generate identical results, but this line should probably not be there
99+ has_lbounds = any (isinstance (d , sym .RangeIndex ) for d in var_shape )
100+ has_lbounds = has_lbounds and isinstance (var_shape [- 1 ], sym .RangeIndex )
101+ if has_lbounds :
102+ lbounds = [
103+ d .lower if isinstance (d , sym .RangeIndex ) else sym .IntLiteral (0 )
104+ for d in var_shape
105+ ]
106+ kwarguments = (
107+ ('UBOUNDS' , sym .LiteralList (ubounds )),
108+ ('LBOUNDS' , sym .LiteralList (lbounds )),
109+ ('PERSISTENT' , sym .LogicLiteral (True ))
110+ )
111+ else :
112+ kwarguments = (
113+ ('UBOUNDS' , sym .LiteralList (ubounds )),
114+ ('PERSISTENT' , sym .LogicLiteral (True ))
115+ )
116+ self .new_calls += [ir .CallStatement (
117+ name = sym .Variable (name = 'FIELD_NEW' , scope = routine ),
118+ arguments = (field_ptr_var ,),
119+ kwarguments = kwarguments
120+ )]
121+
122+ # Create the FIELD_DELETE CALL
123+ call = ir .CallStatement (sym .Variable (name = 'FIELD_DELETE' , scope = routine ), arguments = (field_ptr_var ,))
124+ condition = sym .InlineCall (sym .Variable (name = 'ASSOCIATED' ), parameters = (field_ptr_var ,))
125+ self .delete_calls += [ir .Conditional (condition = condition , inline = True , body = (call ,))]
126+
127+ def decl_local_array (self , routine , region ):
128+ temp_arrays = [var for var in FindVariables (Array ).visit (region ) if isinstance (var , Array ) and not var .name_parts [0 ] in routine .arguments and var .shape [0 ] in self .horizontal ]
129+ region_map_temp = {}
130+ #check if first dim NPROMA ?
131+ for var in temp_arrays :
132+ var_type = var .type
133+ var_shape = var .shape
134+
135+ dim = len (var_shape ) + 1 # Temporary dimensions + block
136+
137+ # The FIELD_{d}RB variable
138+ field_ptr_type = SymbolAttributes (
139+ dtype = DerivedType (f'FIELD_{ dim } RB' ),
140+ pointer = True , polymorphic = True , initial = "NULL()"
141+ )
142+ field_ptr_var = sym .Variable (name = f'YL_{ var .name } ' , type = field_ptr_type , scope = routine )
143+
144+ # Create a pointer instead of the array
145+ shape = (sym .RangeIndex ((None , None )),) * dim
146+ # var.type = var_type.clone(pointer=True, shape=shape)
147+ local_ptr_var = var .clone (dimensions = shape )
148+
149+ region_map_temp [var .name ]= [field_ptr_var ,local_ptr_var ]
150+ self .create_field_new_delete (routine , var , field_ptr_var )
151+ return (region_map_temp )
152+
153+ def add_temp (self , routine ):
154+ # Replace temporary declaration by pointer to array and pointer to field_api object
155+ map_dcl = {}
156+ for decl in routine .declarations :
157+ if len (decl .symbols ) == 1 :
158+ var = decl .symbols [0 ]
159+ if var .name in self .routine_map_temp :
160+ new_vars = self .routine_map_temp [var .name ]
161+ new_vars [1 ].type = new_vars [1 ].type .clone (pointer = True , shape = new_vars [1 ].dimensions , initial = "NULL()" )
162+ map_dcl .update ({decl : (ir .VariableDeclaration (symbols = (new_vars [0 ],)), ir .VariableDeclaration (symbols = (new_vars [1 ],)))})
163+ else :
164+ raise Exception ("Declaration should have only one symbol, please run single_variable_declaration before calling add_temp function." )
165+ routine .spec = Transformer (map_dcl ).visit (routine .spec )
166+ #decls_to_replace = [var for var in decl.var for decl in routine.declarations if var in self.routine_map_temp]
167+
168+ def add_field (self , routine ):
169+ # Insert the field generation wrapped into a DR_HOOK call
170+ dr_hook_calls = self .create_dr_hook_calls (
171+ routine , cdname = 'CREATE_TEMPORARIES' ,
172+ handle = sym .Variable (name = 'ZHOOK_HANDLE_FIELD_API' , scope = routine )
173+ )
174+ routine .body .insert (2 , (dr_hook_calls [0 ], ir .Comment (text = '' ), * self .new_calls , dr_hook_calls [1 ]))
175+
176+ # Insert the field deletion wrapped into a DR_HOOK call
177+ dr_hook_calls = self .create_dr_hook_calls (
178+ routine , cdname = 'DELETE_TEMPORARIES' ,
179+ handle = sym .Variable (name = 'ZHOOK_HANDLE_FIELD_API' , scope = routine )
180+ )
181+ routine .body .insert (- 2 ,(dr_hook_calls [0 ], ir .Comment (text = '' ), * self .delete_calls , dr_hook_calls [1 ]))
0 commit comments