Skip to content

Commit d5a8e6c

Browse files
authored
Merge pull request #321 from ecmwf-ifs/naan-import-var-sizes
Add transformation generated imports
2 parents 8a9c4ec + 78d9f45 commit d5a8e6c

File tree

6 files changed

+138
-18
lines changed

6 files changed

+138
-18
lines changed

loki/transformations/hoist_variables.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,17 @@
8080
scheduler.process(transformation=HoistTemporaryArraysTransformationAllocatable(key=key))
8181
"""
8282

83+
from collections import defaultdict
84+
8385
from loki.batch import Transformation, ProcedureItem
8486
from loki.expression import (
8587
symbols as sym, FindVariables, FindInlineCalls,
8688
SubstituteExpressions, is_dimension_constant
8789
)
8890
from loki.ir import (
89-
CallStatement, Allocation, Deallocation, Transformer, FindNodes
91+
CallStatement, Allocation, Deallocation, Transformer, FindNodes, Comment, Import
9092
)
91-
from loki.tools.util import is_iterable, as_tuple, CaseInsensitiveDict
93+
from loki.tools.util import is_iterable, as_tuple, CaseInsensitiveDict, flatten
9294

9395
from loki.transformations.utilities import single_variable_declaration
9496

@@ -139,9 +141,14 @@ def transform_subroutine(self, routine, **kwargs):
139141
if role != 'driver':
140142
variables = self.find_variables(routine)
141143
item.trafo_data[self._key]["to_hoist"] = variables
144+
dims = flatten([getattr(v, 'shape', []) for v in variables])
145+
import_map = routine.import_map
146+
item.trafo_data[self._key]["imported_sizes"] = [(d.type.module, d) for d in dims
147+
if str(d) in import_map]
142148
item.trafo_data[self._key]["hoist_variables"] = [var.clone(name=f'{routine.name}_{var.name}')
143149
for var in variables]
144150
else:
151+
item.trafo_data[self._key]["imported_sizes"] = []
145152
item.trafo_data[self._key]["to_hoist"] = []
146153
item.trafo_data[self._key]["hoist_variables"] = []
147154

@@ -166,6 +173,7 @@ def transform_subroutine(self, routine, **kwargs):
166173
item.trafo_data[self._key]["hoist_variables"].extend(hoist_variables)
167174
item.trafo_data[self._key]["hoist_variables"] = list(dict.fromkeys(
168175
item.trafo_data[self._key]["hoist_variables"]))
176+
item.trafo_data[self._key]["imported_sizes"] += child.trafo_data[self._key]["imported_sizes"]
169177

170178
def find_variables(self, routine):
171179
"""
@@ -273,6 +281,25 @@ def transform_subroutine(self, routine, **kwargs):
273281
routine=routine, call=call, variables=hoisted_variables
274282
)
275283

284+
# Add imports used to define hoisted
285+
missing_imports_map = defaultdict(set)
286+
import_map = routine.import_map
287+
for module, var in item.trafo_data[self._key]["imported_sizes"]:
288+
if not var.name in import_map:
289+
missing_imports_map[module] |= {var}
290+
291+
if missing_imports_map:
292+
routine.spec.prepend(Comment(text=(
293+
'![Loki::HoistVariablesTransformation] ---------------------------------------'
294+
)))
295+
for module, variables in missing_imports_map.items():
296+
routine.spec.prepend(Import(module=module.name, symbols=variables))
297+
298+
routine.spec.prepend(Comment(text=(
299+
'![Loki::HoistVariablesTransformation] '
300+
'-------- Added hoisted temporary size imports -------------------------------'
301+
)))
302+
276303
routine.body = Transformer(call_map).visit(routine.body)
277304

278305
def driver_variable_declaration(self, routine, variables):

loki/transformations/inline.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from loki.batch import Transformation
1616
from loki.ir import (
1717
Import, Comment, Assignment, VariableDeclaration, CallStatement,
18-
Transformer, FindNodes, pragmas_attached, is_loki_pragma
18+
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface
1919
)
2020
from loki.expression import (
2121
symbols as sym, FindVariables, FindInlineCalls, FindLiterals,
@@ -199,9 +199,9 @@ def map_inline_call(self, expr, *args, **kwargs):
199199

200200
def resolve_sequence_association_for_inlined_calls(routine, inline_internals, inline_marked):
201201
"""
202-
Resolve sequence association in calls to all member procedures (if `inline_internals = True`)
203-
or in calls to procedures that have been marked with an inline pragma (if `inline_marked = True`).
204-
If both `inline_internals` and `inline_marked` are `False`, no processing is done.
202+
Resolve sequence association in calls to all member procedures (if ``inline_internals = True``)
203+
or in calls to procedures that have been marked with an inline pragma (if ``inline_marked = True``).
204+
If both ``inline_internals`` and ``inline_marked`` are ``False``, no processing is done.
205205
"""
206206
call_map = {}
207207
with pragmas_attached(routine, node_type=CallStatement):
@@ -635,13 +635,26 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
635635

636636
# If we're importing the same module, check for missing symbols
637637
if m := imported_module_map.get(impt.module):
638-
if not all(s in m.symbols for s in impt.symbols):
638+
_m = import_map.get(m, m)
639+
if not all(s in _m.symbols for s in impt.symbols):
639640
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
640-
import_map[m] = m.clone(symbols=tuple(set(m.symbols + new_symbols)))
641+
import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols)))
641642

642643
# Finally, apply the import remapping
643644
routine.spec = Transformer(import_map).visit(routine.spec)
644645

646+
# Add missing explicit interfaces from inlined subroutines
647+
new_intfs = []
648+
intf_symbols = routine.interface_symbols
649+
for callee in call_sets.keys():
650+
for intf in callee.interfaces:
651+
for s in intf.symbols:
652+
if not s in intf_symbols:
653+
new_intfs += [s.type.dtype.procedure,]
654+
655+
if new_intfs:
656+
routine.spec.append(Interface(body=as_tuple(new_intfs)))
657+
645658
# Add Fortran imports to the top, and C-style interface headers at the bottom
646659
c_imports = tuple(im for im in new_imports if im.c_import)
647660
f_imports = tuple(im for im in new_imports if not im.c_import)

loki/transformations/pool_allocator.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -731,10 +731,14 @@ def apply_pool_allocator_to_temporaries(self, routine, item=None):
731731
stack_size, stack_storage)
732732
allocations += allocation
733733

734-
# Store type information of temporary allocation
735-
if item and (_kind := arr.type.kind):
736-
if _kind in routine.imported_symbols:
737-
item.trafo_data[self._key]['kind_imports'][_kind] = routine.import_map[_kind.name].module.lower()
734+
# Store type and size information of temporary allocation
735+
if item:
736+
if (kind := arr.type.kind):
737+
if kind in routine.imported_symbols:
738+
item.trafo_data[self._key]['kind_imports'][kind] = routine.import_map[kind.name].module.lower()
739+
dims = [d for d in arr.shape if d in routine.imported_symbols]
740+
for d in dims:
741+
item.trafo_data[self._key]['kind_imports'][d] = routine.import_map[d.name].module.lower()
738742

739743
routine.spec.append(declarations)
740744
routine.body.prepend(allocations)

loki/transformations/tests/test_hoist_variables.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,11 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
557557
subroutine kernel(start, end, klon, klev, nclv, field1, field2)
558558
use, intrinsic :: iso_c_binding, only : c_size_t
559559
implicit none
560+
interface
561+
subroutine another_kernel(klev)
562+
integer, intent(in) :: klev
563+
end subroutine another_kernel
564+
end interface
560565
integer, parameter :: jprb = selected_real_kind(13,300)
561566
integer, intent(in) :: nclv
562567
integer, intent(in) :: start, end, klon, klev
@@ -582,19 +587,38 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
582587
tmp5(jl, jm, :) = field1(jl)
583588
enddo
584589
enddo
590+
591+
call another_kernel(klev)
585592
end subroutine kernel
586593
end module kernel_mod
587594
""".strip()
595+
fcode_mod = """
596+
module size_mod
597+
implicit none
598+
integer :: n
599+
end module size_mod
600+
""".strip()
601+
fcode_another_kernel = """
602+
subroutine another_kernel(klev)
603+
use size_mod, only : n
604+
implicit none
605+
integer, intent(in) :: klev
606+
real :: another_tmp(klev,n)
607+
end subroutine another_kernel
608+
""".strip()
588609

589610
(tmp_path/'driver.F90').write_text(fcode_driver)
590611
(tmp_path/'kernel_mod.F90').write_text(fcode_kernel)
612+
(tmp_path/'size_mod.F90').write_text(fcode_mod)
613+
(tmp_path/'another_kernel.F90').write_text(fcode_another_kernel)
591614

592615
config = {
593616
'default': {
594617
'mode': 'idem',
595618
'role': 'kernel',
596619
'expand': True,
597-
'strict': True
620+
'strict': True,
621+
'enable_imports': True
598622
},
599623
'routines': {
600624
'driver': {'role': 'driver'}
@@ -613,11 +637,12 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
613637
driver_variables = (
614638
'jprb', 'nlon', 'nz', 'nb', 'b',
615639
'field1(nlon, nb)', 'field2(nlon, nz, nb)',
616-
'kernel_tmp2(:,:)', 'kernel_tmp5(:,:,:)'
640+
'kernel_tmp2(:,:)', 'kernel_tmp5(:,:,:)', 'another_kernel_another_tmp(:,:)'
617641
)
618642
kernel_arguments = (
619643
'start', 'end', 'klon', 'klev', 'nclv',
620-
'field1(klon)', 'field2(klon,klev)', 'tmp2(klon,klev)', 'tmp5(klon,nclv,klev)'
644+
'field1(klon)', 'field2(klon,klev)', 'tmp2(klon,klev)', 'tmp5(klon,nclv,klev)',
645+
'another_kernel_another_tmp(klev,n)'
621646
)
622647

623648
# Check hoisting and declaration in driver
@@ -629,8 +654,17 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
629654
assert len(calls) == 1
630655
assert calls[0].arguments == (
631656
'1', 'nlon', 'nlon', 'nz', '2', 'field1(:,b)', 'field2(:,:,b)',
632-
'kernel_tmp2', 'kernel_tmp5'
657+
'kernel_tmp2', 'kernel_tmp5', 'another_kernel_another_tmp'
633658
)
634659

635660
# Check that fgen works
636661
assert scheduler['kernel_mod#kernel'].source.to_fortran()
662+
663+
# Check that imports were updated
664+
imports = FindNodes(ir.Import).visit(scheduler['kernel_mod#kernel'].ir.spec)
665+
assert len(imports) == 2
666+
assert 'n' in scheduler['kernel_mod#kernel'].ir.imported_symbols
667+
668+
imports = FindNodes(ir.Import).visit(scheduler['#driver'].ir.spec)
669+
assert len(imports) == 2
670+
assert 'n' in scheduler['#driver'].ir.imported_symbols

loki/transformations/tests/test_inline.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,11 +738,19 @@ def test_inline_marked_subroutines(frontend, adjust_imports):
738738
739739
contains
740740
subroutine add_one(a)
741+
interface
742+
subroutine do_something()
743+
end subroutine do_something
744+
end interface
741745
real(kind=8), intent(inout) :: a
742746
a = a + 1
743747
end subroutine add_one
744748
745749
subroutine add_a_to_b(a, b, n)
750+
interface
751+
subroutine do_something_else()
752+
end subroutine do_something_else
753+
end interface
746754
real(kind=8), intent(inout) :: a(:), b(:)
747755
integer, intent(in) :: n
748756
integer :: i
@@ -786,6 +794,14 @@ def test_inline_marked_subroutines(frontend, adjust_imports):
786794
else:
787795
assert imports[0].symbols == ('add_one', 'add_a_to_b')
788796

797+
if adjust_imports:
798+
# check that explicit interfaces were imported
799+
intfs = driver.interfaces
800+
assert len(intfs) == 1
801+
assert all(isinstance(s, sym.ProcedureSymbol) for s in driver.interface_symbols)
802+
assert 'do_something' in driver.interface_symbols
803+
assert 'do_something_else' in driver.interface_symbols
804+
789805

790806
@pytest.mark.parametrize('frontend', available_frontends())
791807
def test_inline_marked_subroutines_with_interfaces(frontend):
@@ -1361,6 +1377,7 @@ def test_inline_transformation_adjust_imports(frontend):
13611377
module bnds_module
13621378
integer :: m
13631379
integer :: n
1380+
integer :: l
13641381
end module bnds_module
13651382
"""
13661383

@@ -1374,10 +1391,13 @@ def test_inline_transformation_adjust_imports(frontend):
13741391
subroutine test_inline_outer(a, b)
13751392
use bnds_module, only: n
13761393
use test_inline_mod, only: test_inline_inner
1394+
use test_inline_another_mod, only: test_inline_another_inner
13771395
implicit none
13781396
13791397
real(kind=8), intent(inout) :: a(n), b(n)
13801398
1399+
!$loki inline
1400+
call test_inline_another_inner()
13811401
!$loki inline
13821402
call test_inline_inner(a, b)
13831403
end subroutine test_inline_outer
@@ -1403,11 +1423,25 @@ def test_inline_transformation_adjust_imports(frontend):
14031423
end subroutine test_inline_inner
14041424
end module test_inline_mod
14051425
"""
1426+
1427+
fcode_another_inner = """
1428+
module test_inline_another_mod
1429+
implicit none
1430+
contains
1431+
1432+
subroutine test_inline_another_inner()
1433+
use BNDS_module, only: n, m, l
1434+
1435+
end subroutine test_inline_another_inner
1436+
end module test_inline_another_mod
1437+
"""
1438+
14061439
_ = Module.from_source(fcode_another, frontend=frontend)
14071440
_ = Module.from_source(fcode_module, frontend=frontend)
14081441
inner = Module.from_source(fcode_inner, frontend=frontend)
1442+
another_inner = Module.from_source(fcode_another_inner, frontend=frontend)
14091443
outer = Subroutine.from_source(
1410-
fcode_outer, definitions=inner, frontend=frontend
1444+
fcode_outer, definitions=(inner, another_inner), frontend=frontend
14111445
)
14121446

14131447
trafo = InlineTransformation(
@@ -1430,7 +1464,7 @@ def test_inline_transformation_adjust_imports(frontend):
14301464
assert imports[0].module == 'another_module'
14311465
assert imports[0].symbols == ('x',)
14321466
assert imports[1].module == 'bnds_module'
1433-
assert 'm' in imports[1].symbols and 'n' in imports[1].symbols
1467+
assert all(_ in imports[1].symbols for _ in ['l', 'm', 'n'])
14341468

14351469

14361470
@pytest.mark.parametrize('frontend', available_frontends())

loki/transformations/tests/test_pool_allocator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,8 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
11961196
type geom_type
11971197
type(dim_type) :: blk_dim
11981198
end type geom_type
1199+
1200+
integer :: n
11991201
end module geom_mod
12001202
"""
12011203

@@ -1228,6 +1230,7 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
12281230
contains
12291231
subroutine kernel(start, end, klon, klev, field1, field2)
12301232
use parkind1, only : jpim, jplm
1233+
use geom_mod, only : n
12311234
implicit none
12321235
integer, parameter :: jwrb = selected_real_kind(13,300)
12331236
integer, intent(in) :: start, end, klon, klev
@@ -1237,6 +1240,7 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
12371240
real(kind=jwrb) :: tmp2(klon, klev)
12381241
integer(kind=jpim) :: tmp3(klon*2)
12391242
logical(kind=jplm) :: tmp4(klev)
1243+
logical(kind=jplm) :: tmp5(klev,n)
12401244
integer :: jk, jl
12411245
12421246
do jk=1,klev
@@ -1247,6 +1251,7 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
12471251
end do
12481252
field1(jl) = tmp1(jl)
12491253
tmp4(jk) = .true.
1254+
tmp5(jk,1:n) = .true.
12501255
end do
12511256
12521257
do jl=start,end
@@ -1344,3 +1349,6 @@ def test_pool_allocator_args_vs_kwargs(tmp_path, frontend, block_dim_alt, cray_p
13441349
# check stack size allocation
13451350
allocations = FindNodes(Allocation).visit(driver.body)
13461351
assert len(allocations) == 1 and 'zstack(istsz,geom%blk_dim%nb)' in allocations[0].variables
1352+
1353+
# check that array size was imported to the driver
1354+
assert 'n' in driver.imported_symbols

0 commit comments

Comments
 (0)