Skip to content

Commit 815aa12

Browse files
authored
Merge pull request #319 from ecmwf-ifs/naan-hoist-after-inline
`HoistVariablesAnalysis`: remove unused explicit interfaces after inlining
2 parents eb793e2 + b9132db commit 815aa12

File tree

5 files changed

+263
-5
lines changed

5 files changed

+263
-5
lines changed

loki/transformations/hoist_variables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def transform_subroutine(self, routine, **kwargs):
152152
for child in successors:
153153
if not isinstance(child, ProcedureItem):
154154
continue
155+
155156
arg_map = dict(call_map[child.local_name].arg_iter())
156157
hoist_variables = []
157158
for var in child.trafo_data[self._key]["hoist_variables"]:

loki/transformations/inline.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,18 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
611611
# Remove import if no further symbols used, otherwise clone with new symbols
612612
import_map[impt] = impt.clone(symbols=new_symbols) if new_symbols else None
613613

614+
# Remove explicit interfaces of inlined routines
615+
for intf in routine.interfaces:
616+
if not intf.spec:
617+
_body = tuple(
618+
s.type.dtype.procedure for s in intf.symbols
619+
if s.name not in callees or s.name in not_inlined
620+
)
621+
if _body:
622+
import_map[intf] = intf.clone(body=_body)
623+
else:
624+
import_map[intf] = None
625+
614626
# Now move any callee imports we might need over to the caller
615627
new_imports = set()
616628
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)

loki/transformations/single_column/hoist.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ def driver_variable_declaration(self, routine, variables):
6565

6666
# Add explicit device-side allocations/deallocations for hoisted temporaries
6767
vnames = ', '.join(v.name for v in variables)
68-
pragma = ir.Pragma(keyword='acc', content=f'enter data create({vnames})')
69-
pragma_post = ir.Pragma(keyword='acc', content=f'exit data delete({vnames})')
68+
if vnames:
69+
pragma = ir.Pragma(keyword='acc', content=f'enter data create({vnames})')
70+
pragma_post = ir.Pragma(keyword='acc', content=f'exit data delete({vnames})')
7071

71-
# Add comments around standalone pragmas to avoid false attachment
72-
routine.body.prepend((ir.Comment(''), pragma, ir.Comment('')))
73-
routine.body.append((ir.Comment(''), pragma_post, ir.Comment('')))
72+
# Add comments around standalone pragmas to avoid false attachment
73+
routine.body.prepend((ir.Comment(''), pragma, ir.Comment('')))
74+
routine.body.append((ir.Comment(''), pragma_post, ir.Comment('')))
7475

7576
def driver_call_argument_remapping(self, routine, call, variables):
7677
"""

loki/transformations/single_column/tests/test_single_column_coalesced.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,147 @@ def test_single_column_coalesced_hoist_nested_openacc(frontend, horizontal, vert
11271127
assert outer_kernel_pragmas[2].keyword == 'acc'
11281128
assert outer_kernel_pragmas[2].content == 'end data'
11291129

1130+
1131+
@pytest.mark.parametrize('frontend', available_frontends())
1132+
def test_single_column_coalesced_hoist_nested_inline_openacc(frontend, horizontal, vertical, blocking):
1133+
"""
1134+
Test the correct addition of OpenACC pragmas to SCC format code
1135+
when hoisting array temporaries to driver.
1136+
"""
1137+
1138+
fcode_driver = """
1139+
SUBROUTINE column_driver(nlon, nz, q, nb)
1140+
INTEGER, INTENT(IN) :: nlon, nz, nb ! Size of the horizontal and vertical
1141+
REAL, INTENT(INOUT) :: q(nlon,nz,nb)
1142+
INTEGER :: b, start, end
1143+
1144+
start = 1
1145+
end = nlon
1146+
do b=1, nb
1147+
call compute_column(start, end, nlon, nz, q(:,:,b))
1148+
end do
1149+
END SUBROUTINE column_driver
1150+
"""
1151+
1152+
fcode_outer_kernel = """
1153+
SUBROUTINE compute_column(start, end, nlon, nz, q)
1154+
INTEGER, INTENT(IN) :: start, end ! Iteration indices
1155+
INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
1156+
REAL, INTENT(INOUT) :: q(nlon,nz)
1157+
INTEGER :: jl, jk
1158+
REAL :: c
1159+
1160+
c = 5.345
1161+
DO JL = START, END
1162+
Q(JL, NZ) = Q(JL, NZ) + 1.0
1163+
END DO
1164+
1165+
!$loki inline
1166+
call update_q(start, end, nlon, nz, q, c)
1167+
1168+
DO JL = START, END
1169+
Q(JL, NZ) = Q(JL, NZ) * C
1170+
END DO
1171+
END SUBROUTINE compute_column
1172+
"""
1173+
1174+
fcode_inner_kernel = """
1175+
SUBROUTINE update_q(start, end, nlon, nz, q, c)
1176+
INTEGER, INTENT(IN) :: start, end ! Iteration indices
1177+
INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
1178+
REAL, INTENT(INOUT) :: q(nlon,nz)
1179+
REAL, INTENT(IN) :: c
1180+
REAL :: t(nlon,nz)
1181+
INTEGER :: jl, jk
1182+
1183+
DO jk = 2, nz
1184+
DO jl = start, end
1185+
t(jl, jk) = c * jk
1186+
q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
1187+
END DO
1188+
END DO
1189+
END SUBROUTINE update_q
1190+
"""
1191+
1192+
# Mimic the scheduler's internal mechanism to apply the transformation cascade
1193+
outer_kernel_source = Sourcefile.from_source(fcode_outer_kernel, frontend=frontend)
1194+
inner_kernel_source = Sourcefile.from_source(fcode_inner_kernel, frontend=frontend)
1195+
driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend)
1196+
driver = driver_source['column_driver']
1197+
outer_kernel = outer_kernel_source['compute_column']
1198+
inner_kernel = inner_kernel_source['update_q']
1199+
outer_kernel.enrich(inner_kernel) # Attach inner kernel source to outer kernel call
1200+
driver.enrich(outer_kernel) # Attach kernel source to driver call
1201+
1202+
driver_item = ProcedureItem(name='#column_driver', source=driver)
1203+
outer_kernel_item = ProcedureItem(name='#compute_column', source=outer_kernel)
1204+
inner_kernel_item = ProcedureItem(name='#update_q', source=inner_kernel)
1205+
1206+
scc_hoist = SCCHoistPipeline(
1207+
horizontal=horizontal, block_dim=blocking,
1208+
dim_vars=(vertical.size,), directive='openacc'
1209+
)
1210+
1211+
InlineTransformation(allowed_aliases=horizontal.index).apply(outer_kernel)
1212+
1213+
# Apply in reverse order to ensure hoisting analysis gets run on kernel first
1214+
scc_hoist.apply(inner_kernel, role='kernel', item=inner_kernel_item)
1215+
scc_hoist.apply(
1216+
outer_kernel, role='kernel', item=outer_kernel_item,
1217+
targets=['compute_q'], successors=()
1218+
)
1219+
scc_hoist.apply(
1220+
driver, role='driver', item=driver_item,
1221+
targets=['compute_column'], successors=(outer_kernel_item,)
1222+
)
1223+
1224+
# Ensure calls have correct arguments
1225+
# driver
1226+
calls = FindNodes(CallStatement).visit(driver.body)
1227+
assert len(calls) == 1
1228+
assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q(:, :, b)',
1229+
'compute_column_t(:, :, b)')
1230+
1231+
# Ensure a single outer parallel loop in driver
1232+
with pragmas_attached(driver, Loop):
1233+
driver_loops = FindNodes(Loop).visit(driver.body)
1234+
assert len(driver_loops) == 1
1235+
assert driver_loops[0].variable == 'b'
1236+
assert driver_loops[0].bounds == '1:nb'
1237+
assert driver_loops[0].pragma[0].keyword == 'acc'
1238+
assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'
1239+
1240+
# Ensure we have a kernel call in the driver loop
1241+
kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
1242+
assert len(kernel_calls) == 1
1243+
assert kernel_calls[0].name == 'compute_column'
1244+
1245+
# Ensure that the intermediate kernel contains two loops: the wrapped horizontal loop and the loop from the inlined routine
1246+
with pragmas_attached(outer_kernel, Loop):
1247+
outer_kernel_loops = FindNodes(Loop).visit(outer_kernel.body)
1248+
assert len(outer_kernel_loops) == 2
1249+
assert outer_kernel_loops[0].variable == 'jl'
1250+
assert outer_kernel_loops[0].bounds == 'start:end'
1251+
assert outer_kernel_loops[0].pragma[0].keyword == 'acc'
1252+
assert outer_kernel_loops[0].pragma[0].content == 'loop vector'
1253+
1254+
# check correctly nested vertical loop from inlined routine
1255+
assert outer_kernel_loops[1] in FindNodes(Loop).visit(outer_kernel_loops[0].body)
1256+
1257+
# Ensure the call was inlined
1258+
assert not FindNodes(CallStatement).visit(outer_kernel.body)
1259+
1260+
# Ensure the routine has been marked properly
1261+
outer_kernel_pragmas = FindNodes(Pragma).visit(outer_kernel.ir)
1262+
assert len(outer_kernel_pragmas) == 3
1263+
assert outer_kernel_pragmas[0].keyword == 'acc'
1264+
assert outer_kernel_pragmas[0].content == 'routine vector'
1265+
assert outer_kernel_pragmas[1].keyword == 'acc'
1266+
assert outer_kernel_pragmas[1].content == 'data present(q, t)'
1267+
assert outer_kernel_pragmas[2].keyword == 'acc'
1268+
assert outer_kernel_pragmas[2].content == 'end data'
1269+
1270+
11301271
@pytest.mark.parametrize('frontend', available_frontends())
11311272
def test_single_column_coalesced_nested(frontend, horizontal, blocking):
11321273
"""

loki/transformations/tests/test_inline.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,109 @@ def test_inline_marked_subroutines(frontend, adjust_imports):
787787
assert imports[0].symbols == ('add_one', 'add_a_to_b')
788788

789789

790+
@pytest.mark.parametrize('frontend', available_frontends())
791+
def test_inline_marked_subroutines_with_interfaces(frontend):
792+
""" Test inlining of subroutines with explicit interfaces via marker pragmas. """
793+
794+
fcode_driver = """
795+
subroutine test_pragma_inline(a, b)
796+
implicit none
797+
798+
interface
799+
subroutine add_a_to_b(a, b, n)
800+
real(kind=8), intent(inout) :: a(:), b(:)
801+
integer, intent(in) :: n
802+
end subroutine add_a_to_b
803+
subroutine add_one(a)
804+
real(kind=8), intent(inout) :: a
805+
end subroutine add_one
806+
end interface
807+
808+
interface
809+
subroutine add_two(a)
810+
real(kind=8), intent(inout) :: a
811+
end subroutine add_two
812+
end interface
813+
814+
real(kind=8), intent(inout) :: a(3), b(3)
815+
integer, parameter :: n = 3
816+
integer :: i
817+
818+
do i=1, n
819+
!$loki inline
820+
call add_one(a(i))
821+
end do
822+
823+
!$loki inline
824+
call add_a_to_b(a(:), b(:), 3)
825+
826+
do i=1, n
827+
call add_one(b(i))
828+
!$loki inline
829+
call add_two(b(i))
830+
end do
831+
832+
end subroutine test_pragma_inline
833+
"""
834+
835+
fcode_module = """
836+
module util_mod
837+
implicit none
838+
839+
contains
840+
subroutine add_one(a)
841+
real(kind=8), intent(inout) :: a
842+
a = a + 1
843+
end subroutine add_one
844+
845+
subroutine add_two(a)
846+
real(kind=8), intent(inout) :: a
847+
a = a + 2
848+
end subroutine add_two
849+
850+
subroutine add_a_to_b(a, b, n)
851+
real(kind=8), intent(inout) :: a(:), b(:)
852+
integer, intent(in) :: n
853+
integer :: i
854+
855+
do i = 1, n
856+
a(i) = a(i) + b(i)
857+
end do
858+
end subroutine add_a_to_b
859+
end module util_mod
860+
"""
861+
862+
module = Module.from_source(fcode_module, frontend=frontend)
863+
driver = Subroutine.from_source(fcode_driver, frontend=frontend)
864+
driver.enrich(module.subroutines)
865+
866+
calls = FindNodes(ir.CallStatement).visit(driver.body)
867+
assert calls[0].routine == module['add_one']
868+
assert calls[1].routine == module['add_a_to_b']
869+
assert calls[2].routine == module['add_one']
870+
assert calls[3].routine == module['add_two']
871+
872+
inline_marked_subroutines(routine=driver, allowed_aliases=('I',))
873+
874+
# Check inlined loops and assignments
875+
assert len(FindNodes(ir.Loop).visit(driver.body)) == 3
876+
assign = FindNodes(ir.Assignment).visit(driver.body)
877+
assert len(assign) == 3
878+
assert assign[0].lhs == 'a(i)' and assign[0].rhs == 'a(i) + 1'
879+
assert assign[1].lhs == 'a(i)' and assign[1].rhs == 'a(i) + b(i)'
880+
assert assign[2].lhs == 'b(i)' and assign[2].rhs == 'b(i) + 2'
881+
882+
# Check that the call without an inline pragma is left untouched
883+
calls = FindNodes(ir.CallStatement).visit(driver.body)
884+
assert len(calls) == 1
885+
assert calls[0].routine.name == 'add_one'
886+
assert calls[0].arguments == ('b(i)',)
887+
888+
intfs = FindNodes(ir.Interface).visit(driver.spec)
889+
assert len(intfs) == 1
890+
assert intfs[0].symbols == ('add_one',)
891+
892+
790893
@pytest.mark.parametrize('frontend', available_frontends())
791894
@pytest.mark.parametrize('adjust_imports', [True, False])
792895
def test_inline_marked_routine_with_optionals(frontend, adjust_imports):

0 commit comments

Comments
 (0)