1414from loki import (
1515 FindVariables , DerivedType , SymbolAttributes ,
1616 Array , single_variable_declaration , Transformer ,
17- BasicType
17+ BasicType , as_tuple
1818)
1919import pickle
2020import os
@@ -31,6 +31,12 @@ def __init__(self):
3131 "KLON" , "YDCPG_OPTS%KLON" , "YDGEOMETRY%YRDIM%NPROMA" ,
3232 "KPROMA" , "YDDIM%NPROMA" , "NPROMA"
3333 ]
34+ self .map_compute = {
35+ "OpenMP" : self .create_compute_openmp ,
36+ "OpenMPSingleColumn" : self .create_compute_openmpscc ,
37+ "OpenACCSingleColumn" : self .create_compute_openaccscc
38+ }
39+
3440 #TODO : do smthg for opening field_index.pkl
3541 with open (os .getcwd ()+ "/transformations/transformations/field_index.pkl" , 'rb' ) as fp :
3642 self .map_index = pickle .load (fp )
@@ -46,6 +52,7 @@ def __init__(self):
4652 self .routine_map_derived = {}
4753
4854 def transform_subroutine (self , routine , ** kwargs ):
55+ self .get_cpg (routine )
4956 with pragma_regions_attached (routine ):
5057 for region in FindNodes (ir .PragmaRegion ).visit (routine .body ):
5158 if is_loki_pragma (region .pragma ):
@@ -56,6 +63,7 @@ def transform_subroutine(self, routine, **kwargs):
5663 self .add_derived (routine )
5764 #call add_arrays etc...
5865
66+
5967 def process_parallel_region (self , routine , region ):
6068 pragma_content = region .pragma .content .split (maxsplit = 1 )
6169 pragma_content = [entry .split ('=' , maxsplit = 1 ) for entry in pragma_content [1 ].split (',' )]
@@ -79,22 +87,21 @@ def process_parallel_region(self, routine, region):
7987 region_map_derived = self .decl_derived_types (routine , region )
8088
8189 self .get_data = {}
90+ self .compute = {}
8291### self.synchost = {} #synchost same for all the targets
8392### self.nullify = {} #synchost same for all the targets
8493
8594 self .synchost = self .create_synchost (routine , region_name , region_map_derived , region_map_temp )
8695 self .nullify = self .create_nullify (routine , region_name , region_map_derived , region_map_temp )
8796
88-
8997 for target in pragma_attrs ['target' ]:
90-
9198# Q : I would like get_data, synchost and nullify not be members of the Transformation object, however, I need them to run the test...
9299# A : maybe have them as members of the routine while
93100# Is there an object to handle data that is needed for tests ?
94101# get_data = self.create_pt_sync(routine, region_name, True, region_map_derived, region_map_temp)
95102# synchost = self.create_synchost(routine, region_name, True, region_map_derived, region_map_temp)
96103# nullify = self.create_nullify(routine, region_name, True, region_map_derived, region_map_temp)
97- self .process_target (routine , target , region_name , region_map_temp , region_map_derived )
104+ self .process_target (routine , region , region_name , region_map_temp , region_map_derived , target )
98105 for var_name in region_map_temp :
99106 if var_name not in self .routine_map_temp :
100107 self .routine_map_temp [var_name ]= region_map_temp [var_name ]
@@ -103,12 +110,9 @@ def process_parallel_region(self, routine, region):
103110 if var_name not in self .routine_map_derived :
104111 self .routine_map_derived [var_name ]= region_map_derived [var_name ]
105112
106- def process_target (self , routine , target , region_name , region_map_temp , region_map_derived ):
107-
113+ def process_target (self , routine , region , region_name , region_map_temp , region_map_derived , target ):
108114 self .get_data [target ] = self .create_pt_sync (routine , target , region_name , True , region_map_derived , region_map_temp )
109- ### self.synchost[target] = self.create_synchost(routine, target, region_name, region_map_derived, region_map_temp)
110- ### self.nullify[target] = self.create_nullify(routine, target, region_name, region_map_derived, region_map_temp)
111-
115+ self .compute [target ] = self .map_compute [target ](routine , region , region_name , region_map_temp , region_map_derived )
112116
113117 @staticmethod
114118 def create_dr_hook_calls (scope , cdname , handle ):
@@ -232,12 +236,12 @@ def decl_derived_types(self, routine, region):
232236 # Creating the pointer on the data : YL_A
233237 data_name = f"Z_{ var .name .replace ('%' , '_' )} "
234238 if "REAL" and "JPRB" in value [0 ]:
239+ data_dim = value [2 ] + 1
240+ data_shape = (sym .RangeIndex ((None , None )),) * data_dim
235241 data_type = SymbolAttributes (
236242 dtype = BasicType .REAL , kind = routine .symbol_map ['JPRB' ],
237- pointer = True
243+ pointer = True , shape = data_shape
238244 )
239- data_dim = value [2 ] + 1
240- data_shape = (sym .RangeIndex ((None , None )),) * data_dim
241245 ptr_var = sym .Variable (name = data_name , type = data_type , dimensions = data_shape , scope = routine )
242246
243247 else :
@@ -293,7 +297,6 @@ def create_pt_sync(self, routine, target, region_name, is_get_data, region_map_d
293297
294298 call = sym .InlineCall (sym .Variable (name = f"{ sync_name } _{ intent } " ), parameters = (var [0 ],))
295299 sync_data += [ir .Assignment (lhs = var [1 ].clone (dimensions = None ), rhs = call , ptr = True )]
296- #sync_data += [ir.Assignment(lhs=(var[1],), rhs=(call,), ptr=True)]
297300
298301 sync_data .append (dr_hook_calls [1 ])
299302
@@ -315,3 +318,96 @@ def create_nullify(self, routine, region_name, region_map_derived, region_map_te
315318 nullify .append (dr_hook_calls [1 ])
316319 return nullify
317320
321+ def get_cpg (self ,routine ):
322+ #Assuming CPG_OPTS_TYPE and CPG_BNDS_TYPE are the same in all the routine.
323+ found_opts = False
324+ found_bnds = False
325+ for var in FindVariables ().visit (routine .spec ):
326+ if var .type .dtype .name == "CPG_OPTS_TYPE" :
327+ self .cpg_opts = var
328+ found_opts = True
329+ if var .type .dtype .name == "CPG_BNDS_TYPE" :
330+ self .cpg_bnds = var
331+ found_bnds = True
332+ if (found_opts and found_bnds ) :
333+ if "YD" in self .cpg_bnds .name :
334+ lcpg_bnds_name = self .cpg_bnds .name .replace ("YD" , "YL" )
335+ self .lcpg_bnds = sym .Variable (name = lcpg_bnds_name , scope = routine )
336+ dcl = ir .VariableDeclaration (symbols = as_tuple (self .lcpg_bnds ))
337+ routine .spec .append (dcl )
338+ data_type = SymbolAttributes (
339+ dtype = BasicType .INTEGER , kind = routine .symbol_map ['JPIM' ]
340+ )
341+ self .jblk = sym .Variable (name = "JBLK" , type = data_type , scope = routine )
342+ routine .spec .append (self .jblk )
343+ return
344+ else :
345+ raise Exception (f"cpg_bnds unexpected name : { self .cpg_bnds .name } " )
346+
347+ def update_args (self , arg , region_map ):
348+ new_arg = region_map [arg .name ][1 ]
349+ dim = len (new_arg .dimensions )
350+ #dim = len(new_arg.shape)
351+ new_dimensions = (sym .RangeIndex ((None , None )),) * (dim - 1 )
352+ new_dimensions += (self .jblk ,)
353+ return new_arg .clone (dimensions = new_dimensions )
354+
355+ def create_compute_openmp (self , routine , region , region_name , region_map_temp , region_map_derived ):
356+ #ylcpg_bnds : new var to add to spec, type(ylcpg)=type(cpg_bnds)=CPG_BNDS_TYPE
357+
358+ #hook_compute 0
359+ #call ylcpg_bnds%init(ydcpg_opts)
360+ #!$omp parallel do private (jblk) firstprivate (ylcpg_bnds)
361+ #do jblk = 1, ydcpg_opts%kgpblks
362+ # call ylcpg_bnds%update(jblk)
363+ # call callee(ydgeometry, ydmodel, ylcpg_bnds%kidia, ... (...BLK))
364+ #enddo
365+ #hook_compute 1
366+
367+ init = ir .CallStatement (
368+ name = routine .resolve_typebound_var (f"{ self .lcpg_bnds .name } %INIT" ),
369+ arguments = (self .cpg_opts ,))
370+ #TODO : generate lst_private !!!!
371+ lst_private = "JBLK"
372+ pragma = ir .Pragma (keyword = "OMP" , content = f"PARALLEL DO PRIVATE { lst_private } FIRSTPRIVATE ({ self .lcpg_bnds } )" )
373+ update = ir .CallStatement (
374+ name = routine .resolve_typebound_var (f"{ self .lcpg_bnds .name } %UPDATE" ),
375+ arguments = (self .jblk ,)
376+ )
377+ #TODO : musn't be call but the body of the region here??
378+
379+ new_calls = []
380+ for call in FindNodes (ir .CallStatement ).visit (region ):
381+ if call .name != "DR_HOOK" :
382+ # for var in chain(region_map_temp.values(), region_map_derived.values()):
383+ new_arguments = []
384+ for arg in call .arguments :
385+ if arg .name in region_map_temp :
386+ new_arguments += [self .update_args (arg , region_map_temp )]
387+ elif arg .name in region_map_derived :
388+ new_arguments += [self .update_args (arg , region_map_derived )]
389+ elif arg .name_parts [0 ]== self .cpg_bnds .name :
390+ new_arguments += [routine .resolve_typebound_var (f"{ self .lcpg_bnds } %{ arg .name_parts [1 ]} " )]
391+ else :
392+ new_arguments += [arg ]
393+ new_calls += [call .clone (arguments = as_tuple (new_arguments ))]
394+
395+ new_calls = tuple (new_calls )
396+
397+ loop_body = (update ,) + new_calls
398+ loop = ir .Loop (variable = self .jblk , bounds = sym .LoopRange ((1 ,routine .resolve_typebound_var (f"{ self .cpg_opts } %KGPBLKS" ))), body = loop_body )
399+ dr_hook_calls = self .create_dr_hook_calls (
400+ routine , f"{ routine .name } :{ region_name } :COMPUTE" ,
401+ sym .Variable (name = 'ZHOOK_HANDLE_COMPUTE' , scope = routine )
402+ )
403+ new_region = (dr_hook_calls [0 ], init , pragma , loop , dr_hook_calls [1 ])
404+ return (new_region )
405+ # TODO : YLCPG_BNDS%INIT
406+ # TODO : OMP PARALLEL
407+ # sym.DeferredTypeSymbol
408+ #call : call.clone(name=..., args= tuple of the region var + dimensions!!!)
409+
410+ def create_compute_openmpscc (self , routine , region , region_name , region_map_temp , region_map_derived ):
411+ pass
412+ def create_compute_openaccscc (self , routine , region , region_name , region_map_temp , region_map_derived ):
413+ pass
0 commit comments