Skip to content

Commit cacb023

Browse files
committed
Add temporary arrays transformation into pointers, and field_new/field_delete
1 parent 99267cb commit cacb023

File tree

2 files changed

+191
-5
lines changed

2 files changed

+191
-5
lines changed

transformations/tests/test_parallel_routine_dispatch.py

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytest
1212

1313
from loki.frontend import available_frontends, OMNI
14-
from loki import Sourcefile, FindNodes, CallStatement
14+
from loki import Sourcefile, FindNodes, CallStatement, fgen, Conditional
1515

1616
from transformations.parallel_routine_dispatch import ParallelRoutineDispatchTransformation
1717

@@ -33,8 +33,71 @@ def test_parallel_routine_dispatch_dr_hook(here, frontend):
3333
transformation = ParallelRoutineDispatchTransformation()
3434
transformation.apply(source['dispatch_routine'])
3535

36-
calls = FindNodes(CallStatement).visit(routine.body)
36+
calls = [call for call in FindNodes(CallStatement).visit(routine.body) if call.name.name=='DR_HOOK']
37+
assert len(calls) == 8
38+
39+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
40+
def test_parallel_routine_dispatch_decl_local_arrays(here, frontend):
41+
42+
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
43+
routine = source['dispatch_routine']
44+
45+
transformation = ParallelRoutineDispatchTransformation()
46+
transformation.apply(source['dispatch_routine'])
47+
var_lst=["YL_ZRDG_CVGQ", "ZRDG_CVGQ", "YL_ZRDG_MU0LU", "ZRDG_MU0LU", "YL_ZRDG_MU0M", "ZRDG_MU0M", "YL_ZRDG_MU0N", "ZRDG_MU0N", "YL_ZRDG_MU0", "ZRDG_MU0"]
48+
dcls = [dcl for dcl in routine.declarations if dcl.symbols[0].name in var_lst]
49+
str_dcls = ""
50+
for dcl in dcls:
51+
str_dcls += fgen(dcl)+"\n"
52+
assert str_dcls == """CLASS(FIELD_3RB), POINTER :: YL_ZRDG_CVGQ => NULL()
53+
REAL(KIND=JPRB), POINTER :: ZRDG_CVGQ(:, :, :) => NULL()
54+
CLASS(FIELD_2RB), POINTER :: YL_ZRDG_MU0LU => NULL()
55+
REAL(KIND=JPRB), POINTER :: ZRDG_MU0LU(:, :) => NULL()
56+
CLASS(FIELD_2RB), POINTER :: YL_ZRDG_MU0M => NULL()
57+
REAL(KIND=JPRB), POINTER :: ZRDG_MU0M(:, :) => NULL()
58+
CLASS(FIELD_2RB), POINTER :: YL_ZRDG_MU0N => NULL()
59+
REAL(KIND=JPRB), POINTER :: ZRDG_MU0N(:, :) => NULL()
60+
CLASS(FIELD_2RB), POINTER :: YL_ZRDG_MU0 => NULL()
61+
REAL(KIND=JPRB), POINTER :: ZRDG_MU0(:, :) => NULL()
62+
"""
63+
64+
@pytest.mark.parametrize('frontend', available_frontends(skip=[OMNI]))
65+
def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
66+
67+
source = Sourcefile.from_file(here/'sources/projParallelRoutineDispatch/dispatch_routine.F90', frontend=frontend)
68+
routine = source['dispatch_routine']
69+
70+
transformation = ParallelRoutineDispatchTransformation()
71+
transformation.apply(source['dispatch_routine'])
72+
73+
var_lst = ["YL_ZRDG_CVGQ", "ZRDG_CVGQ", "YL_ZRDG_MU0LU", "ZRDG_MU0LU", "YL_ZRDG_MU0M", "ZRDG_MU0M", "YL_ZRDG_MU0N", "ZRDG_MU0N", "YL_ZRDG_MU0", "ZRDG_MU0"]
74+
field_create = ["CALL FIELD_NEW(YL_ZRDG_CVGQ, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KFLEVG, YDCPG_OPTS%KGPBLKS /), LBOUNDS=(/ 0, 1 /), &\n& PERSISTENT=.true.)",
75+
"CALL FIELD_NEW(YL_ZRDG_MU0N, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KGPBLKS /), PERSISTENT=.true.)",
76+
"CALL FIELD_NEW(YL_ZRDG_MU0LU, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KGPBLKS /), PERSISTENT=.true.)",
77+
"CALL FIELD_NEW(YL_ZRDG_MU0, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KGPBLKS /), PERSISTENT=.true.)",
78+
"CALL FIELD_NEW(YL_ZRDG_MU0M, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KGPBLKS /), PERSISTENT=.true.)"
79+
]
80+
81+
calls = [call for call in FindNodes(CallStatement).visit(routine.body) if call.name.name=="FIELD_NEW"]
3782
assert len(calls) == 5
38-
assert [str(call.name).lower() for call in calls] == [
39-
'dr_hook', 'dr_hook', 'cpphinp', 'dr_hook', 'dr_hook'
40-
]
83+
for call in calls:
84+
assert fgen(call) in field_create
85+
86+
field_delete = ["IF (ASSOCIATED(YL_ZRDG_CVGQ)) CALL FIELD_DELETE(YL_ZRDG_CVGQ)",
87+
"IF (ASSOCIATED(YL_ZRDG_MU0LU)) CALL FIELD_DELETE(YL_ZRDG_MU0LU)",
88+
"IF (ASSOCIATED(YL_ZRDG_MU0M)) CALL FIELD_DELETE(YL_ZRDG_MU0M)",
89+
"IF (ASSOCIATED(YL_ZRDG_MU0)) CALL FIELD_DELETE(YL_ZRDG_MU0)",
90+
"IF (ASSOCIATED(YL_ZRDG_MU0N)) CALL FIELD_DELETE(YL_ZRDG_MU0N)"
91+
]
92+
93+
conds = [cond for cond in FindNodes(Conditional).visit(routine.body)]
94+
conditional = []
95+
for cond in conds:
96+
for call in FindNodes(CallStatement).visit(cond):
97+
if call.name.name=="FIELD_DELETE":
98+
conditional.append(cond)
99+
100+
assert len(conditional) == 5
101+
for cond in conditional:
102+
assert fgen(cond) in field_delete
103+
breakpoint()

transformations/transformations/parallel_routine_dispatch.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,36 @@
1111
FindNodes, nodes as ir
1212
)
1313
from loki.transform import Transformation
14+
from loki import (
15+
FindVariables, DerivedType, SymbolAttributes,
16+
Array, single_variable_declaration, Transformer
17+
)
1418

1519
__all__ = ['ParallelRoutineDispatchTransformation']
1620

1721

1822
class ParallelRoutineDispatchTransformation(Transformation):
1923

24+
def __init__(self):
25+
self.horizontal = [
26+
"KLON", "YDCPG_OPTS%KLON", "YDGEOMETRY%YRDIM%NPROMA",
27+
"KPROMA", "YDDIM%NPROMA", "NPROMA"
28+
]
29+
# CALL FIELD_NEW (YL_ZA, UBOUNDS=[KLON, KFLEVG, KGPBLKS], LBOUNDS=[1, 0, 1], PERSISTENT=.TRUE.)
30+
self.new_calls = []
31+
# IF (ASSOCIATED (YL_ZA)) CALL FIELD_DELETE (YL_ZA)
32+
self.delete_calls = []
33+
self.routine_map_temp = {}
34+
2035
def transform_subroutine(self, routine, **kwargs):
2136
with pragma_regions_attached(routine):
2237
for region in FindNodes(ir.PragmaRegion).visit(routine.body):
2338
if is_loki_pragma(region.pragma):
2439
self.process_parallel_region(routine, region)
40+
single_variable_declaration(routine)
41+
self.add_temp(routine)
42+
self.add_field(routine)
43+
#call add_arrays etc...
2544

2645
def process_parallel_region(self, routine, region):
2746
pragma_content = region.pragma.content.split(maxsplit=1)
@@ -41,6 +60,14 @@ def process_parallel_region(self, routine, region):
4160
region.prepend(dr_hook_calls[0])
4261
region.append(dr_hook_calls[1])
4362

63+
region_map_temp= self.decl_local_array(routine, region)
64+
65+
for var_name in region_map_temp:
66+
if var_name not in self.routine_map_temp:
67+
self.routine_map_temp[var_name]=region_map_temp[var_name]
68+
69+
70+
4471
@staticmethod
4572
def create_dr_hook_calls(scope, cdname, handle):
4673
dr_hook_calls = []
@@ -56,3 +83,99 @@ def create_dr_hook_calls(scope, cdname, handle):
5683
)
5784
]
5885
return dr_hook_calls
86+
87+
# @staticmethod
88+
# def get_local_array(routine):
89+
90+
def create_field_new_delete(self, routine, var, field_ptr_var):
91+
# Create the FIELD_NEW call
92+
var_shape = var.shape
93+
ubounds = [d.upper if isinstance(d, sym.RangeIndex) else d for d in var_shape]
94+
ubounds += [sym.Variable(name="KGPBLKS", parent=routine.variable_map["YDCPG_OPTS"])]
95+
# NB: This is presumably what the condition should look like, however, the
96+
# logic is flawed in to_parallel and it will only insert lbounds if _the last_
97+
# dimension has an lbound. We emulate this with the second line here to
98+
# generate identical results, but this line should probably not be there
99+
has_lbounds = any(isinstance(d, sym.RangeIndex) for d in var_shape)
100+
has_lbounds = has_lbounds and isinstance(var_shape[-1], sym.RangeIndex)
101+
if has_lbounds:
102+
lbounds = [
103+
d.lower if isinstance(d, sym.RangeIndex) else sym.IntLiteral(0)
104+
for d in var_shape
105+
]
106+
kwarguments = (
107+
('UBOUNDS', sym.LiteralList(ubounds)),
108+
('LBOUNDS', sym.LiteralList(lbounds)),
109+
('PERSISTENT', sym.LogicLiteral(True))
110+
)
111+
else:
112+
kwarguments = (
113+
('UBOUNDS', sym.LiteralList(ubounds)),
114+
('PERSISTENT', sym.LogicLiteral(True))
115+
)
116+
self.new_calls += [ir.CallStatement(
117+
name=sym.Variable(name='FIELD_NEW', scope=routine),
118+
arguments=(field_ptr_var,),
119+
kwarguments=kwarguments
120+
)]
121+
122+
# Create the FIELD_DELETE CALL
123+
call = ir.CallStatement(sym.Variable(name='FIELD_DELETE', scope=routine), arguments=(field_ptr_var,))
124+
condition = sym.InlineCall(sym.Variable(name='ASSOCIATED'), parameters=(field_ptr_var,))
125+
self.delete_calls += [ir.Conditional(condition=condition, inline=True, body=(call,))]
126+
127+
def decl_local_array(self, routine, region):
128+
temp_arrays = [var for var in FindVariables(Array).visit(region) if isinstance(var, Array) and not var.name_parts[0] in routine.arguments and var.shape[0] in self.horizontal]
129+
region_map_temp={}
130+
#check if first dim NPROMA ?
131+
for var in temp_arrays:
132+
var_type = var.type
133+
var_shape = var.shape
134+
135+
dim = len(var_shape) + 1 # Temporary dimensions + block
136+
137+
# The FIELD_{d}RB variable
138+
field_ptr_type = SymbolAttributes(
139+
dtype=DerivedType(f'FIELD_{dim}RB'),
140+
pointer=True, polymorphic=True, initial="NULL()"
141+
)
142+
field_ptr_var = sym.Variable(name=f'YL_{var.name}', type=field_ptr_type, scope=routine)
143+
144+
# Create a pointer instead of the array
145+
shape = (sym.RangeIndex((None, None)),) * dim
146+
# var.type = var_type.clone(pointer=True, shape=shape)
147+
local_ptr_var = var.clone(dimensions=shape)
148+
149+
region_map_temp[var.name]=[field_ptr_var,local_ptr_var]
150+
self.create_field_new_delete(routine, var, field_ptr_var)
151+
return(region_map_temp)
152+
153+
def add_temp(self, routine):
154+
# Replace temporary declaration by pointer to array and pointer to field_api object
155+
map_dcl = {}
156+
for decl in routine.declarations:
157+
if len(decl.symbols) == 1:
158+
var = decl.symbols[0]
159+
if var.name in self.routine_map_temp:
160+
new_vars = self.routine_map_temp[var.name]
161+
new_vars[1].type = new_vars[1].type.clone(pointer=True, shape=new_vars[1].dimensions, initial="NULL()")
162+
map_dcl.update({decl : (ir.VariableDeclaration(symbols=(new_vars[0],)), ir.VariableDeclaration(symbols=(new_vars[1],)))})
163+
else:
164+
raise Exception("Declaration should have only one symbol, please run single_variable_declaration before calling add_temp function.")
165+
routine.spec = Transformer(map_dcl).visit(routine.spec)
166+
#decls_to_replace = [var for var in decl.var for decl in routine.declarations if var in self.routine_map_temp]
167+
168+
def add_field(self, routine):
169+
# Insert the field generation wrapped into a DR_HOOK call
170+
dr_hook_calls = self.create_dr_hook_calls(
171+
routine, cdname='CREATE_TEMPORARIES',
172+
handle=sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine)
173+
)
174+
routine.body.insert(2, (dr_hook_calls[0], ir.Comment(text=''), *self.new_calls, dr_hook_calls[1]))
175+
176+
# Insert the field deletion wrapped into a DR_HOOK call
177+
dr_hook_calls = self.create_dr_hook_calls(
178+
routine, cdname='DELETE_TEMPORARIES',
179+
handle=sym.Variable(name='ZHOOK_HANDLE_FIELD_API', scope=routine)
180+
)
181+
routine.body.insert(-2,(dr_hook_calls[0], ir.Comment(text=''), *self.delete_calls, dr_hook_calls[1]))

0 commit comments

Comments
 (0)