Skip to content

Commit b9132db

Browse files
authored
Merge branch 'main' into naan-hoist-after-inline
2 parents 714075a + eb793e2 commit b9132db

File tree

11 files changed

+258
-25
lines changed

11 files changed

+258
-25
lines changed

loki/analyse/analyse_dataflow.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from contextlib import contextmanager
13-
from loki.expression import FindVariables, Array, FindInlineCalls
13+
from loki.expression import FindVariables, Array, FindInlineCalls, ProcedureSymbol, FindTypedSymbols
1414
from loki.tools import as_tuple, flatten
1515
from loki.types import BasicType
1616
from loki.ir import Visitor, Transformer
@@ -232,9 +232,9 @@ def visit_CallStatement(self, o, **kwargs):
232232
outvals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'out')]
233233
invals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'in')]
234234

235+
arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)]
236+
dims = set(v for a in arrays for v in self._symbols_from_expr(a.dimensions))
235237
for val in outvals:
236-
arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)]
237-
dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions))
238238
exprs = self._symbols_from_expr(val)
239239
defines |= {e for e in exprs if not e in dims}
240240
uses |= dims
@@ -269,12 +269,14 @@ def visit_Deallocation(self, o, **kwargs):
269269
visit_Nullify = visit_Deallocation
270270

271271
def visit_Import(self, o, **kwargs):
272-
defines = self._symbols_from_expr(o.symbols or ())
272+
defines = set(s.clone(dimensions=None) for s in FindTypedSymbols().visit(o.symbols or ())
273+
if isinstance(s, ProcedureSymbol))
273274
return self.visit_Node(o, defines_symbols=defines, **kwargs)
274275

275276
def visit_VariableDeclaration(self, o, **kwargs):
276277
defines = self._symbols_from_expr(o.symbols, condition=lambda v: v.type.initial is not None)
277-
return self.visit_Node(o, defines_symbols=defines, **kwargs)
278+
uses = {v for a in o.symbols if isinstance(a, Array) for v in self._symbols_from_expr(a.dimensions)}
279+
return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)
278280

279281

280282
class DataflowAnalysisDetacher(Transformer):

loki/analyse/tests/test_analyse_dataflow.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from loki import (
1111
Subroutine, FindNodes, Assignment, Loop, Conditional, Pragma, fgen, Sourcefile,
1212
CallStatement, MultiConditional, MaskedStatement, ProcedureSymbol, WhileLoop,
13-
Associate
13+
Associate, Module
1414
)
1515
from loki.analyse import (
1616
dataflow_analysis_attached, read_after_write_vars, loop_carried_dependencies
@@ -281,6 +281,41 @@ def test_analyse_interface(frontend):
281281
assert isinstance(list(routine.spec.defines_symbols)[0], ProcedureSymbol)
282282
assert 'random_call' in routine.spec.defines_symbols
283283

284+
285+
@pytest.mark.parametrize('frontend', available_frontends())
286+
def test_analyse_imports(frontend):
287+
fcode_module = """
288+
module some_mod
289+
implicit none
290+
real :: my_global
291+
contains
292+
subroutine random_call(v_out,v_in,v_inout)
293+
294+
real,intent(in) :: v_in
295+
real,intent(out) :: v_out
296+
real,intent(inout) :: v_inout
297+
298+
299+
end subroutine random_call
300+
end module some_mod
301+
""".strip()
302+
303+
fcode = """
304+
subroutine test()
305+
use some_mod, only: my_global, random_call
306+
implicit none
307+
308+
end subroutine test
309+
""".strip()
310+
311+
module = Module.from_source(fcode_module, frontend=frontend)
312+
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module)
313+
314+
with dataflow_analysis_attached(routine):
315+
assert len(routine.spec.defines_symbols) == 1
316+
assert 'random_call' in routine.spec.defines_symbols
317+
318+
284319
@pytest.mark.parametrize('frontend', available_frontends())
285320
def test_analyse_enriched_call(frontend):
286321
fcode = """
@@ -429,26 +464,30 @@ def test_analyse_call_args_array_slicing(frontend):
429464
430465
end subroutine random_call
431466
432-
subroutine test(v,n)
467+
subroutine test(v,n,b)
433468
implicit none
434469
435470
integer,intent(out) :: v(:)
436471
integer,intent( in) :: n
472+
integer,intent( in) :: b(n)
437473
438474
call random_call(v(n))
475+
call random_call(v(b(1)))
439476
440477
end subroutine test
441478
""".strip()
442479

443480
source = Sourcefile.from_source(fcode, frontend=frontend)
444481
routine = source['test']
445482

446-
call = FindNodes(CallStatement).visit(routine.body)[0]
483+
calls = FindNodes(CallStatement).visit(routine.body)
447484
routine.enrich(source.all_subroutines)
448485

449486
with dataflow_analysis_attached(routine):
450-
assert 'n' in call.uses_symbols
451-
assert not 'n' in call.defines_symbols
487+
assert 'n' in calls[0].uses_symbols
488+
assert not 'n' in calls[0].defines_symbols
489+
assert 'b' in calls[1].uses_symbols
490+
assert not 'b' in calls[0].defines_symbols
452491

453492

454493
@pytest.mark.parametrize('frontend', available_frontends())

loki/batch/scheduler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,6 @@ def _get_definition_items(_item, sgraph_items):
503503

504504
if transformation.creates_items:
505505
self._discover()
506-
507506
self._parse_items()
508507

509508
def callgraph(self, path, with_file_graph=False, with_legend=False):

loki/frontend/regex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -900,8 +900,8 @@ def __init__(self):
900900
super().__init__(
901901
r'^(((?:type|class)[ \t]*\([ \t]*(?P<typename>\w+)[ \t]*\))|' # TYPE or CLASS keyword with typename
902902
r'^([ \t]*(?P<basic_type>(logical|real|integer|complex|character))'
903-
r'(?P<param>\((kind|len)=[a-z0-9_-]+\))?[ \t]*))'
904-
r'(?:[ \t]*,[ \t]*[a-z]+(?:\((.(\(.*\))?)*?\))?)*' # Optional attributes
903+
r'[ \t]*(?P<param>\([ \t]*(kind|len)[ \t]*=[ \t]*[a-z0-9_-]+[ \t]*\))?[ \t]*))'
904+
r'(?:[ \t]*,[ \t]*[a-z]+(?:[ \t]*\((.(\(.*\))?)*?\))?)*' # Optional attributes
905905
r'(?:[ \t]*::)?' # Optional `::` delimiter
906906
r'[ \t]*' # Some white space
907907
r'(?P<variables>\w+\b.*?)$', # Variable names
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# (C) Copyright 2018- ECMWF.
2+
# This software is licensed under the terms of the Apache Licence Version 2.0
3+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
4+
# In applying this licence, ECMWF does not waive the privileges and immunities
5+
# granted to it by virtue of its status as an intergovernmental organisation
6+
# nor does it submit to any jurisdiction.
7+
8+
"""
9+
Verify correct parsing behaviour of the REGEX frontend
10+
"""
11+
12+
from loki.frontend import REGEX
13+
from loki.types import BasicType, DerivedType
14+
from loki.subroutine import Subroutine
15+
16+
def test_declaration_whitespace_attributes():
17+
"""
18+
Test correct behaviour with/without white space inside declaration attributes
19+
(reported in #318).
20+
"""
21+
fcode = """
22+
subroutine my_whitespace_declaration_routine(kdim, state_t0, paux)
23+
use type_header, only: dimension_type, STATE_TYPE, aux_type, jprb
24+
implicit none
25+
TYPE( DIMENSION_TYPE) , INTENT (IN) :: KDIM
26+
type (state_type ) , intent ( in ) :: state_t0
27+
TYPE (AUX_TYPE) , InteNT( In) :: PAUX
28+
CHARACTER ( LEN=10) :: STR
29+
REAL( KIND = JPRB ) :: VAR
30+
end subroutine
31+
""".strip()
32+
33+
routine = Subroutine.from_source(fcode, frontend=REGEX)
34+
35+
# Verify that variables and dtype information has been extracted correctly
36+
assert routine.variables == ('kdim', 'state_t0', 'paux', 'str', 'var')
37+
assert isinstance(routine.variable_map['kdim'].type.dtype, DerivedType)
38+
assert routine.variable_map['kdim'].type.dtype.name.lower() == 'dimension_type'
39+
assert isinstance(routine.variable_map['state_t0'].type.dtype, DerivedType)
40+
assert routine.variable_map['state_t0'].type.dtype.name.lower() == 'state_type'
41+
assert isinstance(routine.variable_map['paux'].type.dtype, DerivedType)
42+
assert routine.variable_map['paux'].type.dtype.name.lower() == 'aux_type'
43+
assert routine.variable_map['str'].type.dtype == BasicType.CHARACTER
44+
assert routine.variable_map['var'].type.dtype == BasicType.REAL
45+
46+
routine.make_complete()
47+
48+
# Verify that additional type attributes are correct after full parse
49+
assert routine.variables == ('kdim', 'state_t0', 'paux', 'str', 'var')
50+
assert isinstance(routine.variable_map['kdim'].type.dtype, DerivedType)
51+
assert routine.variable_map['kdim'].type.dtype.name.lower() == 'dimension_type'
52+
assert routine.variable_map['kdim'].type.intent == 'in'
53+
assert isinstance(routine.variable_map['state_t0'].type.dtype, DerivedType)
54+
assert routine.variable_map['state_t0'].type.dtype.name.lower() == 'state_type'
55+
assert routine.variable_map['state_t0'].type.intent == 'in'
56+
assert isinstance(routine.variable_map['paux'].type.dtype, DerivedType)
57+
assert routine.variable_map['paux'].type.dtype.name.lower() == 'aux_type'
58+
assert routine.variable_map['paux'].type.intent == 'in'
59+
assert routine.variable_map['str'].type.dtype == BasicType.CHARACTER
60+
assert routine.variable_map['str'].type.length == 10
61+
assert routine.variable_map['var'].type.dtype == BasicType.REAL
62+
assert routine.variable_map['var'].type.kind == 'jprb'

loki/ir/nodes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,7 +1535,10 @@ def parent_type(self):
15351535
return None
15361536
if not self.parent:
15371537
return BasicType.DEFERRED
1538-
return self.parent.symbol_attrs[self.extends].dtype.typedef
1538+
parent_type = self.parent.symbol_attrs.lookup(self.extends)
1539+
if not (parent_type and isinstance(parent_type.dtype, DerivedType)):
1540+
return BasicType.DEFERRED
1541+
return parent_type.dtype.typedef
15391542

15401543
@property
15411544
def declarations(self):

loki/tests/test_derived_types.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,3 +1520,42 @@ def test_derived_type_symbol_inheritance(frontend):
15201520
assert not extended_extended_type.imported_symbol_map
15211521
#check for non-empty declarations
15221522
assert all(decl.symbols for decl in extended_extended_type.declarations)
1523+
1524+
1525+
@pytest.mark.parametrize('frontend', available_frontends())
1526+
@pytest.mark.parametrize('qualified_import', (True, False))
1527+
def test_derived_type_inheritance_missing_parent(frontend, qualified_import, tmp_path):
1528+
fcode_parent = """
1529+
module parent_mod
1530+
implicit none
1531+
type, abstract, public :: parent_type
1532+
integer :: val
1533+
end type parent_type
1534+
end module parent_mod
1535+
""".strip()
1536+
1537+
fcode_derived = f"""
1538+
module derived_mod
1539+
use parent_mod{", only: parent_type" if qualified_import else ""}
1540+
implicit none
1541+
type, public, extends(parent_type) :: derived_type
1542+
integer :: val2
1543+
end type derived_type
1544+
contains
1545+
subroutine do_something(this)
1546+
class(derived_type), intent(inout) :: this
1547+
this%val = 1
1548+
this%val2 = 2
1549+
end subroutine do_something
1550+
end module derived_mod
1551+
""".strip()
1552+
1553+
parent = Module.from_source(fcode_parent, frontend=frontend, xmods=[tmp_path])
1554+
1555+
# Without enrichment we obtain only DEFERRED type information (but don't fail!)
1556+
derived = Module.from_source(fcode_derived, frontend=frontend, xmods=[tmp_path])
1557+
assert derived['derived_type'].parent_type == BasicType.DEFERRED
1558+
1559+
# With enrichment we obtain the parent type from the parent module
1560+
derived = Module.from_source(fcode_derived, frontend=frontend, xmods=[tmp_path], definitions=[parent])
1561+
assert derived['derived_type'].parent_type is parent['parent_type']

loki/transformations/data_offload.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ def transform_subroutine(self, routine, **kwargs):
296296
var for var in routine.body.uses_symbols
297297
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
298298
}
299+
uses_imported_symbols |= {
300+
var for var in routine.spec.uses_symbols
301+
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
302+
}
299303
defines_imported_symbols = {
300304
var for var in routine.body.defines_symbols
301305
if var.name in import_map or (var.parent and var.parents[0].name in import_map)

loki/transformations/inline.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ class InlineTransformation(Transformation):
7979
# Ensure correct recursive inlining by traversing from the leaves
8080
reverse_traversal = True
8181

82+
# This transformation will potentially change the edges in the callgraph
83+
creates_items = False
84+
8285
def __init__(
8386
self, inline_constants=False, inline_elementals=True,
8487
inline_internals=False, inline_marked=True,
@@ -95,6 +98,8 @@ def __init__(
9598
self.adjust_imports = adjust_imports
9699
self.external_only = external_only
97100
self.resolve_sequence_association = resolve_sequence_association
101+
if self.inline_marked:
102+
self.creates_items = True
98103

99104
def transform_subroutine(self, routine, **kwargs):
100105

@@ -211,9 +216,9 @@ def resolve_sequence_association_for_inlined_calls(routine, inline_internals, in
211216
# asked sequence assoc to happen with inlining, so source for routine should be
212217
# found in calls to be inlined.
213218
raise ValueError(
214-
f"Cannot resolve sequence association for call to `{call.name}` " +
215-
f"to be inlined in routine `{routine.name}`, because " +
216-
f"the `CallStatement` referring to `{call.name}` does not contain " +
219+
f"Cannot resolve sequence association for call to ``{call.name}`` " +
220+
f"to be inlined in routine ``{routine.name}``, because " +
221+
f"the ``CallStatement`` referring to ``{call.name}`` does not contain " +
217222
"the source code of the procedure. " +
218223
"If running in batch processing mode, please recheck Scheduler configuration."
219224
)

loki/transformations/tests/test_data_offload.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def fixture_global_variable_analysis_code(tmp_path):
256256
integer, parameter :: nval = 5
257257
integer, parameter :: nfld = 3
258258
259+
integer :: n
260+
259261
integer :: iarr(nfld)
260262
real :: rarr(nval, nfld)
261263
end module global_var_analysis_header_mod
@@ -297,10 +299,11 @@ def fixture_global_variable_analysis_code(tmp_path):
297299
298300
contains
299301
subroutine kernel_a(arg, tt)
300-
use global_var_analysis_header_mod, only: iarr, nval, nfld
302+
use global_var_analysis_header_mod, only: iarr, nval, nfld, n
301303
302304
real, intent(inout) :: arg(:,:)
303305
type(some_type), intent(in) :: tt
306+
real :: tmp(n)
304307
integer :: i, j
305308
306309
do i=1,nfld
@@ -390,7 +393,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
390393

391394
expected_trafo_data = {
392395
'global_var_analysis_header_mod': {
393-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
396+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
394397
'offload': {}
395398
},
396399
'global_var_analysis_data_mod': {
@@ -402,6 +405,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
402405
'defines_symbols': set(),
403406
'uses_symbols': nval_data | nfld_data | {
404407
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
408+
('n', 'global_var_analysis_header_mod'),
405409
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
406410
}
407411
},
@@ -416,6 +420,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
416420
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
417421
'uses_symbols': nval_data | nfld_data | {
418422
('rdata(:, :, :)', 'global_var_analysis_data_mod'),
423+
('n', 'global_var_analysis_header_mod'),
419424
('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'),
420425
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
421426
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
@@ -465,8 +470,8 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
465470

466471
expected_trafo_data = {
467472
'global_var_analysis_header_mod': {
468-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
469-
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}
473+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
474+
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'}
470475
},
471476
'global_var_analysis_data_mod': {
472477
'declares': {'rdata(:, :, :)', 'tt'},
@@ -486,7 +491,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
486491

487492
# Verify imports have been added to the driver
488493
expected_imports = {
489-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
494+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
490495
'global_var_analysis_data_mod': {'rdata'}
491496
}
492497

@@ -495,7 +500,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
495500
assert {var.name.lower() for var in import_.symbols} == expected_imports[import_.module.lower()]
496501

497502
expected_h2d_pragmas = {
498-
'update device': {'iarr', 'rdata', 'rarr'},
503+
'update device': {'iarr', 'rdata', 'rarr', 'n'},
499504
'enter data copyin': {'tt%vals'}
500505
}
501506
expected_d2h_pragmas = {
@@ -515,7 +520,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
515520

516521
# Verify declarations have been added to the header modules
517522
expected_declarations = {
518-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
523+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
519524
'global_var_analysis_data_mod': {'rdata', 'tt'}
520525
}
521526

0 commit comments

Comments
 (0)