77
88import pytest
99
10- from loki import (
11- Module , Subroutine , VariableDeclaration , TypeDef , fexprgen ,
12- BasicType , Assignment , FindNodes , FindInlineCalls , FindTypedSymbols ,
13- Transformer , fgen , SymbolAttributes , Variable , Import , Section , Intrinsic ,
14- Scalar , DeferredTypeSymbol , FindVariables , SubstituteExpressions , Literal
15- )
10+ from loki import Module , Subroutine , fexprgen , fgen
1611from loki .build import jit_compile , clean_test
12+ from loki .expression import symbols as sym
1713from loki .frontend import available_frontends , OMNI
14+ from loki .ir import (
15+ nodes as ir , FindNodes , FindInlineCalls , FindTypedSymbols ,
16+ FindVariables , SubstituteExpressions , Transformer
17+ )
1818from loki .sourcefile import Sourcefile
19+ from loki .types import BasicType , DerivedType , SymbolAttributes
1920
2021
2122@pytest .mark .parametrize ('frontend' , available_frontends ())
@@ -40,8 +41,8 @@ def test_module_from_source(frontend, tmp_path):
4041end module a_module
4142""" .strip ()
4243 module = Module .from_source (fcode , frontend = frontend , xmods = [tmp_path ])
43- assert len ([o for o in module .spec .body if isinstance (o , VariableDeclaration )]) == 2
44- assert len ([o for o in module .spec .body if isinstance (o , TypeDef )]) == 1
44+ assert len ([o for o in module .spec .body if isinstance (o , ir . VariableDeclaration )]) == 2
45+ assert len ([o for o in module .spec .body if isinstance (o , ir . TypeDef )]) == 1
4546 assert 'derived_type' in module .typedef_map
4647 assert len (module .routines ) == 1
4748 assert module .routines [0 ].name == 'my_routine'
@@ -100,7 +101,7 @@ def test_module_external_typedefs_subroutine(frontend, tmp_path):
100101 assert fexprgen (a .shape ) == exptected_array_shape
101102
102103 # Check the LHS of the assignment has correct meta-data
103- stmt = FindNodes (Assignment ).visit (routine .body )[0 ]
104+ stmt = FindNodes (ir . Assignment ).visit (routine .body )[0 ]
104105 pt_ext_arr = stmt .lhs
105106 assert pt_ext_arr .type .dtype == BasicType .REAL
106107 assert fexprgen (pt_ext_arr .shape ) == exptected_array_shape
@@ -177,14 +178,14 @@ def test_module_external_typedefs_type(frontend, tmp_path):
177178
178179 # Verify correct attachment of type information
179180 assert 'ext_type' in module .symbol_attrs
180- assert isinstance (module .symbol_attrs ['ext_type' ].dtype .typedef , TypeDef )
181- assert isinstance (nested .symbol_attrs ['ext' ].dtype .typedef , TypeDef )
182- assert isinstance (module ['my_routine' ].symbol_attrs ['pt' ].dtype .typedef , TypeDef )
183- assert isinstance (module ['my_routine' ].symbol_attrs ['pt%ext' ].dtype .typedef , TypeDef )
181+ assert isinstance (module .symbol_attrs ['ext_type' ].dtype .typedef , ir . TypeDef )
182+ assert isinstance (nested .symbol_attrs ['ext' ].dtype .typedef , ir . TypeDef )
183+ assert isinstance (module ['my_routine' ].symbol_attrs ['pt' ].dtype .typedef , ir . TypeDef )
184+ assert isinstance (module ['my_routine' ].symbol_attrs ['pt%ext' ].dtype .typedef , ir . TypeDef )
184185 assert 'other_type' in module .symbol_attrs
185186 assert 'other_type' not in module ['other_routine' ].symbol_attrs
186- assert isinstance (module .symbol_attrs ['other_type' ].dtype .typedef , TypeDef )
187- assert isinstance (module ['other_routine' ].symbol_attrs ['pt' ].dtype .typedef , TypeDef )
187+ assert isinstance (module .symbol_attrs ['other_type' ].dtype .typedef , ir . TypeDef )
188+ assert isinstance (module ['other_routine' ].symbol_attrs ['pt' ].dtype .typedef , ir . TypeDef )
188189
189190 # OMNI resolves explicit shape parameters in the frontend parser
190191 exptected_array_shape = '(2, 3)' if frontend == OMNI else '(x, y)'
@@ -206,7 +207,7 @@ def test_module_external_typedefs_type(frontend, tmp_path):
206207 assert fexprgen (pt_ext_a .shape ) == exptected_array_shape
207208
208209 # Check the LHS of the assignment has correct meta-data
209- stmt = FindNodes (Assignment ).visit (routine .body )[0 ]
210+ stmt = FindNodes (ir . Assignment ).visit (routine .body )[0 ]
210211 pt_ext_arr = stmt .lhs
211212 assert pt_ext_arr .type .dtype == BasicType .REAL
212213 assert fexprgen (pt_ext_arr .shape ) == exptected_array_shape
@@ -412,9 +413,9 @@ def test_module_variables_add_remove(frontend, tmp_path):
412413 x = module .variable_map ['x' ] # That's the symbol for variable 'x'
413414 real_type = SymbolAttributes ('real' , kind = module .variable_map ['jprb' ])
414415 int_type = SymbolAttributes ('integer' )
415- a = Variable (name = 'a' , type = real_type , scope = module )
416- b = Variable (name = 'b' , dimensions = (x , ), type = real_type , scope = module )
417- c = Variable (name = 'c' , type = int_type , scope = module )
416+ a = sym . Variable (name = 'a' , type = real_type , scope = module )
417+ b = sym . Variable (name = 'b' , dimensions = (x , ), type = real_type , scope = module )
418+ c = sym . Variable (name = 'c' , type = int_type , scope = module )
418419
419420 # Add new variables and check that they are all in the module spec
420421 module .variables += (a , b , c )
@@ -554,22 +555,22 @@ def test_module_deep_clone(frontend, tmp_path):
554555 new_module = module .clone ()
555556
556557 n = [v for v in FindVariables ().visit (new_module .spec ) if v .name == 'n' ][0 ]
557- n_decl = FindNodes (VariableDeclaration ).visit (new_module .spec )[0 ]
558+ n_decl = FindNodes (ir . VariableDeclaration ).visit (new_module .spec )[0 ]
558559
559560 # Remove the declaration of `n` and replace it with `3`
560561 new_module .spec = Transformer ({n_decl : None }).visit (new_module .spec )
561- new_module .spec = SubstituteExpressions ({n : Literal (3 )}).visit (new_module .spec )
562+ new_module .spec = SubstituteExpressions ({n : sym . Literal (3 )}).visit (new_module .spec )
562563
563564 # Check the new module has been changed
564- assert len (FindNodes (VariableDeclaration ).visit (new_module .spec )) == 1
565- new_type_decls = FindNodes (VariableDeclaration ).visit (new_module ['my_type' ].body )
565+ assert len (FindNodes (ir . VariableDeclaration ).visit (new_module .spec )) == 1
566+ new_type_decls = FindNodes (ir . VariableDeclaration ).visit (new_module ['my_type' ].body )
566567 assert len (new_type_decls ) == 2
567568 assert new_type_decls [0 ].symbols [0 ] == 'vector(3)'
568569 assert new_type_decls [1 ].symbols [0 ] == 'matrix(3, 3)'
569570
570571 # Check the old one has not changed
571- assert len (FindNodes (VariableDeclaration ).visit (module .spec )) == 2
572- type_decls = FindNodes (VariableDeclaration ).visit (module ['my_type' ].body )
572+ assert len (FindNodes (ir . VariableDeclaration ).visit (module .spec )) == 2
573+ type_decls = FindNodes (ir . VariableDeclaration ).visit (module ['my_type' ].body )
573574 assert len (type_decls ) == 2
574575 assert type_decls [0 ].symbols [0 ] == 'vector(n)'
575576 assert type_decls [1 ].symbols [0 ] == 'matrix(n, n)'
@@ -831,7 +832,7 @@ def test_module_rename_imports_with_definitions(frontend, tmp_path):
831832 assert mod3 .symbol_attrs [s ].compare (mod2 .symbol_attrs [use_name or s ], ignore = ('imported' , 'module' , 'use_name' ))
832833
833834 # Verify Import IR node
834- for imprt in FindNodes (Import ).visit (mod3 .spec ):
835+ for imprt in FindNodes (ir . Import ).visit (mod3 .spec ):
835836 if imprt .module == 'test_rename_mod' :
836837 assert imprt .rename_list
837838 assert not imprt .symbols
@@ -915,7 +916,7 @@ def test_module_rename_imports_no_definitions(frontend, tmp_path):
915916 assert mod3 .symbol_attrs [s ].use_name == use_name
916917
917918 # Verify Import IR node
918- for imprt in FindNodes (Import ).visit (mod3 .spec ):
919+ for imprt in FindNodes (ir . Import ).visit (mod3 .spec ):
919920 if imprt .module == 'test_rename_mod' :
920921 assert imprt .rename_list
921922 assert not imprt .symbols
@@ -969,7 +970,7 @@ def test_module_use_module_nature(frontend, tmp_path):
969970
970971 # Check properties on the Import IR node in the external module
971972 assert ext_mod .imported_symbols == ('int16' ,)
972- imprt = FindNodes (Import ).visit (ext_mod .spec )[0 ]
973+ imprt = FindNodes (ir . Import ).visit (ext_mod .spec )[0 ]
973974 assert imprt .nature .lower () == 'intrinsic'
974975 assert imprt .module .lower () == 'iso_c_binding'
975976 assert ext_mod .imported_symbol_map ['int16' ].type .imported is True
@@ -988,8 +989,8 @@ def test_module_use_module_nature(frontend, tmp_path):
988989 assert set (my_kinds .imported_symbols ) == {'int8' , 'int16' }
989990 assert set (kinds .imported_symbols ) == {'int8' , 'int16' }
990991
991- my_import_map = {s .name : imprt for imprt in FindNodes (Import ).visit (my_kinds .spec ) for s in imprt .symbols }
992- import_map = {s .name : imprt for imprt in FindNodes (Import ).visit (kinds .spec ) for s in imprt .symbols }
992+ my_import_map = {s .name : imprt for imprt in FindNodes (ir . Import ).visit (my_kinds .spec ) for s in imprt .symbols }
993+ import_map = {s .name : imprt for imprt in FindNodes (ir . Import ).visit (kinds .spec ) for s in imprt .symbols }
993994
994995 assert my_import_map ['int8' ] is my_import_map ['int16' ]
995996 assert import_map ['int8' ] is import_map ['int16' ]
@@ -1194,13 +1195,13 @@ def test_module_contains_auto_insert(frontend, tmp_path):
11941195 assert routine1 .contains is None
11951196
11961197 routine1 = routine1 .clone (contains = routine2 )
1197- assert isinstance (routine1 .contains , Section )
1198- assert isinstance (routine1 .contains .body [0 ], Intrinsic )
1198+ assert isinstance (routine1 .contains , ir . Section )
1199+ assert isinstance (routine1 .contains .body [0 ], ir . Intrinsic )
11991200 assert routine1 .contains .body [0 ].text == 'CONTAINS'
12001201
12011202 module = module .clone (contains = routine1 )
1202- assert isinstance (module .contains , Section )
1203- assert isinstance (module .contains .body [0 ], Intrinsic )
1203+ assert isinstance (module .contains , ir . Section )
1204+ assert isinstance (module .contains .body [0 ], ir . Intrinsic )
12041205 assert module .contains .body [0 ].text == 'CONTAINS'
12051206
12061207
@@ -1243,14 +1244,14 @@ def test_module_missing_imported_symbol(frontend, only_list, complete_tree, tmp_
12431244 b = driver .symbol_map ['b' ]
12441245
12451246 if complete_tree :
1246- assert isinstance (a , Scalar )
1247+ assert isinstance (a , sym . Scalar )
12471248 assert a .type .dtype is BasicType .INTEGER
1248- assert isinstance (b , Scalar )
1249+ assert isinstance (b , sym . Scalar )
12491250 assert b .type .dtype is BasicType .INTEGER
12501251 else :
1251- assert isinstance (a , DeferredTypeSymbol )
1252+ assert isinstance (a , sym . DeferredTypeSymbol )
12521253 assert a .type .dtype is BasicType .DEFERRED
1253- assert isinstance (b , DeferredTypeSymbol )
1254+ assert isinstance (b , sym . DeferredTypeSymbol )
12541255 assert b .type .dtype is BasicType .DEFERRED
12551256
12561257 assert a .type .imported
@@ -1371,3 +1372,55 @@ def test_module_enrichment_within_file(frontend, tmp_path):
13711372 assert calls [0 ].arguments [0 ].type .parameter
13721373 assert calls [0 ].arguments [0 ].type .initial == 16
13731374 assert calls [0 ].arguments [0 ].type .module is source ['foo' ]
1375+
1376+
1377+ @pytest .mark .parametrize ('frontend' , available_frontends ())
1378+ def test_module_enrichment_typdefs (frontend , tmp_path ):
1379+ """ Test that module-level enrihcment is propagated correctly """
1380+
1381+ fcode_state_mod = """
1382+ module state_type_mod
1383+ implicit none
1384+
1385+ type state_type
1386+ real, pointer, dimension(:,:) :: a
1387+ end type state_type
1388+
1389+ end module state_type_mod
1390+ """
1391+
1392+ fcode_driver_mod = """
1393+ module driver_mod
1394+ use state_type_mod, only: state_type
1395+ implicit none
1396+
1397+ contains
1398+ subroutine driver_routine(state)
1399+ type(state_type), intent(inout) :: state
1400+
1401+ state%a = 1
1402+
1403+ end subroutine driver_routine
1404+ end module driver_mod
1405+ """
1406+ state_mod = Sourcefile .from_source (fcode_state_mod , frontend = frontend , xmods = [tmp_path ])['state_type_mod' ]
1407+ driver_mod = Sourcefile .from_source (fcode_driver_mod , frontend = frontend , xmods = [tmp_path ])['driver_mod' ]
1408+ driver = driver_mod ['driver_routine' ]
1409+
1410+ state = driver .variable_map ['state' ]
1411+ assert isinstance (state .type .dtype , DerivedType )
1412+ assert state .type .dtype .typedef == BasicType .DEFERRED
1413+
1414+ # Enrich typedef on the outer module Import
1415+ driver_mod .enrich ([state_mod ], recurse = True )
1416+
1417+ state = driver .variable_map ['state' ]
1418+
1419+ # Ensure type info has been propagated to inner subroutine
1420+ assert isinstance (state .type .dtype , DerivedType )
1421+ assert isinstance (state .type .dtype .typedef , ir .TypeDef )
1422+
1423+ assigns = FindNodes (ir .Assignment ).visit (driver .body )
1424+ assert len (assigns ) == 1
1425+ assert assigns [0 ].lhs .type .dtype == BasicType .REAL
1426+ assert assigns [0 ].lhs .type .shape == (':' , ':' )
0 commit comments