Skip to content

Commit b93a01d

Browse files
authored
Merge pull request #448 from ecmwf-ifs/naml-fix-module-enrichment
Module: Fix enrichment of type info via `Module` imports
2 parents bcef931 + 83222f4 commit b93a01d

File tree

2 files changed

+93
-39
lines changed

2 files changed

+93
-39
lines changed

loki/program_unit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ def enrich(self, definitions, recurse=False):
327327
"""
328328
definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions))
329329

330-
for imprt in self.imports:
330+
# Enrich type info from all known imports (including parent scopes)
331+
for imprt in self.all_imports:
331332
if not (module := definitions_map.get(imprt.module)):
332333
# Skip modules that are not available in the definitions list
333334
continue

loki/tests/test_modules.py

Lines changed: 91 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@
77

88
import pytest
99

10-
from loki import (
11-
Module, Subroutine, VariableDeclaration, TypeDef, fexprgen,
12-
BasicType, Assignment, FindNodes, FindInlineCalls, FindTypedSymbols,
13-
Transformer, fgen, SymbolAttributes, Variable, Import, Section, Intrinsic,
14-
Scalar, DeferredTypeSymbol, FindVariables, SubstituteExpressions, Literal
15-
)
10+
from loki import Module, Subroutine, fexprgen, fgen
1611
from loki.build import jit_compile, clean_test
12+
from loki.expression import symbols as sym
1713
from loki.frontend import available_frontends, OMNI
14+
from loki.ir import (
15+
nodes as ir, FindNodes, FindInlineCalls, FindTypedSymbols,
16+
FindVariables, SubstituteExpressions, Transformer
17+
)
1818
from loki.sourcefile import Sourcefile
19+
from loki.types import BasicType, DerivedType, SymbolAttributes
1920

2021

2122
@pytest.mark.parametrize('frontend', available_frontends())
@@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
4041
end module a_module
4142
""".strip()
4243
module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path])
43-
assert len([o for o in module.spec.body if isinstance(o, VariableDeclaration)]) == 2
44-
assert len([o for o in module.spec.body if isinstance(o, TypeDef)]) == 1
44+
assert len([o for o in module.spec.body if isinstance(o, ir.VariableDeclaration)]) == 2
45+
assert len([o for o in module.spec.body if isinstance(o, ir.TypeDef)]) == 1
4546
assert 'derived_type' in module.typedef_map
4647
assert len(module.routines) == 1
4748
assert module.routines[0].name == 'my_routine'
@@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
100101
assert fexprgen(a.shape) == exptected_array_shape
101102

102103
# Check the LHS of the assignment has correct meta-data
103-
stmt = FindNodes(Assignment).visit(routine.body)[0]
104+
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
104105
pt_ext_arr = stmt.lhs
105106
assert pt_ext_arr.type.dtype == BasicType.REAL
106107
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
@@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):
177178

178179
# Verify correct attachment of type information
179180
assert 'ext_type' in module.symbol_attrs
180-
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, TypeDef)
181-
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, TypeDef)
182-
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
183-
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, TypeDef)
181+
assert isinstance(module.symbol_attrs['ext_type'].dtype.typedef, ir.TypeDef)
182+
assert isinstance(nested.symbol_attrs['ext'].dtype.typedef, ir.TypeDef)
183+
assert isinstance(module['my_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
184+
assert isinstance(module['my_routine'].symbol_attrs['pt%ext'].dtype.typedef, ir.TypeDef)
184185
assert 'other_type' in module.symbol_attrs
185186
assert 'other_type' not in module['other_routine'].symbol_attrs
186-
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, TypeDef)
187-
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, TypeDef)
187+
assert isinstance(module.symbol_attrs['other_type'].dtype.typedef, ir.TypeDef)
188+
assert isinstance(module['other_routine'].symbol_attrs['pt'].dtype.typedef, ir.TypeDef)
188189

189190
# OMNI resolves explicit shape parameters in the frontend parser
190191
exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
@@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
206207
assert fexprgen(pt_ext_a.shape) == exptected_array_shape
207208

208209
# Check the LHS of the assignment has correct meta-data
209-
stmt = FindNodes(Assignment).visit(routine.body)[0]
210+
stmt = FindNodes(ir.Assignment).visit(routine.body)[0]
210211
pt_ext_arr = stmt.lhs
211212
assert pt_ext_arr.type.dtype == BasicType.REAL
212213
assert fexprgen(pt_ext_arr.shape) == exptected_array_shape
@@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
412413
x = module.variable_map['x'] # That's the symbol for variable 'x'
413414
real_type = SymbolAttributes('real', kind=module.variable_map['jprb'])
414415
int_type = SymbolAttributes('integer')
415-
a = Variable(name='a', type=real_type, scope=module)
416-
b = Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
417-
c = Variable(name='c', type=int_type, scope=module)
416+
a = sym.Variable(name='a', type=real_type, scope=module)
417+
b = sym.Variable(name='b', dimensions=(x, ), type=real_type, scope=module)
418+
c = sym.Variable(name='c', type=int_type, scope=module)
418419

419420
# Add new variables and check that they are all in the module spec
420421
module.variables += (a, b, c)
@@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
554555
new_module = module.clone()
555556

556557
n = [v for v in FindVariables().visit(new_module.spec) if v.name == 'n'][0]
557-
n_decl = FindNodes(VariableDeclaration).visit(new_module.spec)[0]
558+
n_decl = FindNodes(ir.VariableDeclaration).visit(new_module.spec)[0]
558559

559560
# Remove the declaration of `n` and replace it with `3`
560561
new_module.spec = Transformer({n_decl: None}).visit(new_module.spec)
561-
new_module.spec = SubstituteExpressions({n: Literal(3)}).visit(new_module.spec)
562+
new_module.spec = SubstituteExpressions({n: sym.Literal(3)}).visit(new_module.spec)
562563

563564
# Check the new module has been changed
564-
assert len(FindNodes(VariableDeclaration).visit(new_module.spec)) == 1
565-
new_type_decls = FindNodes(VariableDeclaration).visit(new_module['my_type'].body)
565+
assert len(FindNodes(ir.VariableDeclaration).visit(new_module.spec)) == 1
566+
new_type_decls = FindNodes(ir.VariableDeclaration).visit(new_module['my_type'].body)
566567
assert len(new_type_decls) == 2
567568
assert new_type_decls[0].symbols[0] == 'vector(3)'
568569
assert new_type_decls[1].symbols[0] == 'matrix(3, 3)'
569570

570571
# Check the old one has not changed
571-
assert len(FindNodes(VariableDeclaration).visit(module.spec)) == 2
572-
type_decls = FindNodes(VariableDeclaration).visit(module['my_type'].body)
572+
assert len(FindNodes(ir.VariableDeclaration).visit(module.spec)) == 2
573+
type_decls = FindNodes(ir.VariableDeclaration).visit(module['my_type'].body)
573574
assert len(type_decls) == 2
574575
assert type_decls[0].symbols[0] == 'vector(n)'
575576
assert type_decls[1].symbols[0] == 'matrix(n, n)'
@@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
831832
assert mod3.symbol_attrs[s].compare(mod2.symbol_attrs[use_name or s], ignore=('imported', 'module', 'use_name'))
832833

833834
# Verify Import IR node
834-
for imprt in FindNodes(Import).visit(mod3.spec):
835+
for imprt in FindNodes(ir.Import).visit(mod3.spec):
835836
if imprt.module == 'test_rename_mod':
836837
assert imprt.rename_list
837838
assert not imprt.symbols
@@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
915916
assert mod3.symbol_attrs[s].use_name == use_name
916917

917918
# Verify Import IR node
918-
for imprt in FindNodes(Import).visit(mod3.spec):
919+
for imprt in FindNodes(ir.Import).visit(mod3.spec):
919920
if imprt.module == 'test_rename_mod':
920921
assert imprt.rename_list
921922
assert not imprt.symbols
@@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):
969970

970971
# Check properties on the Import IR node in the external module
971972
assert ext_mod.imported_symbols == ('int16',)
972-
imprt = FindNodes(Import).visit(ext_mod.spec)[0]
973+
imprt = FindNodes(ir.Import).visit(ext_mod.spec)[0]
973974
assert imprt.nature.lower() == 'intrinsic'
974975
assert imprt.module.lower() == 'iso_c_binding'
975976
assert ext_mod.imported_symbol_map['int16'].type.imported is True
@@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
988989
assert set(my_kinds.imported_symbols) == {'int8', 'int16'}
989990
assert set(kinds.imported_symbols) == {'int8', 'int16'}
990991

991-
my_import_map = {s.name: imprt for imprt in FindNodes(Import).visit(my_kinds.spec) for s in imprt.symbols}
992-
import_map = {s.name: imprt for imprt in FindNodes(Import).visit(kinds.spec) for s in imprt.symbols}
992+
my_import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(my_kinds.spec) for s in imprt.symbols}
993+
import_map = {s.name: imprt for imprt in FindNodes(ir.Import).visit(kinds.spec) for s in imprt.symbols}
993994

994995
assert my_import_map['int8'] is my_import_map['int16']
995996
assert import_map['int8'] is import_map['int16']
@@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
11941195
assert routine1.contains is None
11951196

11961197
routine1 = routine1.clone(contains=routine2)
1197-
assert isinstance(routine1.contains, Section)
1198-
assert isinstance(routine1.contains.body[0], Intrinsic)
1198+
assert isinstance(routine1.contains, ir.Section)
1199+
assert isinstance(routine1.contains.body[0], ir.Intrinsic)
11991200
assert routine1.contains.body[0].text == 'CONTAINS'
12001201

12011202
module = module.clone(contains=routine1)
1202-
assert isinstance(module.contains, Section)
1203-
assert isinstance(module.contains.body[0], Intrinsic)
1203+
assert isinstance(module.contains, ir.Section)
1204+
assert isinstance(module.contains.body[0], ir.Intrinsic)
12041205
assert module.contains.body[0].text == 'CONTAINS'
12051206

12061207

@@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
12431244
b = driver.symbol_map['b']
12441245

12451246
if complete_tree:
1246-
assert isinstance(a, Scalar)
1247+
assert isinstance(a, sym.Scalar)
12471248
assert a.type.dtype is BasicType.INTEGER
1248-
assert isinstance(b, Scalar)
1249+
assert isinstance(b, sym.Scalar)
12491250
assert b.type.dtype is BasicType.INTEGER
12501251
else:
1251-
assert isinstance(a, DeferredTypeSymbol)
1252+
assert isinstance(a, sym.DeferredTypeSymbol)
12521253
assert a.type.dtype is BasicType.DEFERRED
1253-
assert isinstance(b, DeferredTypeSymbol)
1254+
assert isinstance(b, sym.DeferredTypeSymbol)
12541255
assert b.type.dtype is BasicType.DEFERRED
12551256

12561257
assert a.type.imported
@@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path):
13711372
assert calls[0].arguments[0].type.parameter
13721373
assert calls[0].arguments[0].type.initial == 16
13731374
assert calls[0].arguments[0].type.module is source['foo']
1375+
1376+
1377+
@pytest.mark.parametrize('frontend', available_frontends())
1378+
def test_module_enrichment_typdefs(frontend, tmp_path):
1379+
""" Test that module-level enrihcment is propagated correctly """
1380+
1381+
fcode_state_mod = """
1382+
module state_type_mod
1383+
implicit none
1384+
1385+
type state_type
1386+
real, pointer, dimension(:,:) :: a
1387+
end type state_type
1388+
1389+
end module state_type_mod
1390+
"""
1391+
1392+
fcode_driver_mod = """
1393+
module driver_mod
1394+
use state_type_mod, only: state_type
1395+
implicit none
1396+
1397+
contains
1398+
subroutine driver_routine(state)
1399+
type(state_type), intent(inout) :: state
1400+
1401+
state%a = 1
1402+
1403+
end subroutine driver_routine
1404+
end module driver_mod
1405+
"""
1406+
state_mod = Sourcefile.from_source(fcode_state_mod, frontend=frontend, xmods=[tmp_path])['state_type_mod']
1407+
driver_mod = Sourcefile.from_source(fcode_driver_mod, frontend=frontend, xmods=[tmp_path])['driver_mod']
1408+
driver = driver_mod['driver_routine']
1409+
1410+
state = driver.variable_map['state']
1411+
assert isinstance(state.type.dtype, DerivedType)
1412+
assert state.type.dtype.typedef == BasicType.DEFERRED
1413+
1414+
# Enrich typedef on the outer module Import
1415+
driver_mod.enrich([state_mod], recurse=True)
1416+
1417+
state = driver.variable_map['state']
1418+
1419+
# Ensure type info has been propagated to inner subroutine
1420+
assert isinstance(state.type.dtype, DerivedType)
1421+
assert isinstance(state.type.dtype.typedef, ir.TypeDef)
1422+
1423+
assigns = FindNodes(ir.Assignment).visit(driver.body)
1424+
assert len(assigns) == 1
1425+
assert assigns[0].lhs.type.dtype == BasicType.REAL
1426+
assert assigns[0].lhs.type.shape == (':', ':')

0 commit comments

Comments
 (0)