Skip to content

Commit b9be6f6

Browse files
authored
Merge pull request #320 from ecmwf-ifs/naan-dataflow-fixes
`DataFlowAnalysis` bug fixes
2 parents 844c5d1 + cf8c8e5 commit b9be6f6

File tree

4 files changed

+67
-17
lines changed

4 files changed

+67
-17
lines changed

loki/analyse/analyse_dataflow.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from contextlib import contextmanager
13-
from loki.expression import FindVariables, Array, FindInlineCalls
13+
from loki.expression import FindVariables, Array, FindInlineCalls, ProcedureSymbol, FindTypedSymbols
1414
from loki.tools import as_tuple, flatten
1515
from loki.types import BasicType
1616
from loki.ir import Visitor, Transformer
@@ -232,9 +232,9 @@ def visit_CallStatement(self, o, **kwargs):
232232
outvals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'out')]
233233
invals = [val for arg, val in o.arg_iter() if str(arg.type.intent).lower() in ('inout', 'in')]
234234

235+
arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)]
236+
dims = set(v for a in arrays for v in self._symbols_from_expr(a.dimensions))
235237
for val in outvals:
236-
arrays = [v for v in FindVariables().visit(outvals) if isinstance(v, Array)]
237-
dims = set(v for a in arrays for v in FindVariables().visit(a.dimensions))
238238
exprs = self._symbols_from_expr(val)
239239
defines |= {e for e in exprs if not e in dims}
240240
uses |= dims
@@ -269,12 +269,14 @@ def visit_Deallocation(self, o, **kwargs):
269269
visit_Nullify = visit_Deallocation
270270

271271
def visit_Import(self, o, **kwargs):
272-
defines = self._symbols_from_expr(o.symbols or ())
272+
defines = set(s.clone(dimensions=None) for s in FindTypedSymbols().visit(o.symbols or ())
273+
if isinstance(s, ProcedureSymbol))
273274
return self.visit_Node(o, defines_symbols=defines, **kwargs)
274275

275276
def visit_VariableDeclaration(self, o, **kwargs):
276277
defines = self._symbols_from_expr(o.symbols, condition=lambda v: v.type.initial is not None)
277-
return self.visit_Node(o, defines_symbols=defines, **kwargs)
278+
uses = {v for a in o.symbols if isinstance(a, Array) for v in self._symbols_from_expr(a.dimensions)}
279+
return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)
278280

279281

280282
class DataflowAnalysisDetacher(Transformer):

loki/analyse/tests/test_analyse_dataflow.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from loki import (
1111
Subroutine, FindNodes, Assignment, Loop, Conditional, Pragma, fgen, Sourcefile,
1212
CallStatement, MultiConditional, MaskedStatement, ProcedureSymbol, WhileLoop,
13-
Associate
13+
Associate, Module
1414
)
1515
from loki.analyse import (
1616
dataflow_analysis_attached, read_after_write_vars, loop_carried_dependencies
@@ -281,6 +281,41 @@ def test_analyse_interface(frontend):
281281
assert isinstance(list(routine.spec.defines_symbols)[0], ProcedureSymbol)
282282
assert 'random_call' in routine.spec.defines_symbols
283283

284+
285+
@pytest.mark.parametrize('frontend', available_frontends())
286+
def test_analyse_imports(frontend):
287+
fcode_module = """
288+
module some_mod
289+
implicit none
290+
real :: my_global
291+
contains
292+
subroutine random_call(v_out,v_in,v_inout)
293+
294+
real,intent(in) :: v_in
295+
real,intent(out) :: v_out
296+
real,intent(inout) :: v_inout
297+
298+
299+
end subroutine random_call
300+
end module some_mod
301+
""".strip()
302+
303+
fcode = """
304+
subroutine test()
305+
use some_mod, only: my_global, random_call
306+
implicit none
307+
308+
end subroutine test
309+
""".strip()
310+
311+
module = Module.from_source(fcode_module, frontend=frontend)
312+
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=module)
313+
314+
with dataflow_analysis_attached(routine):
315+
assert len(routine.spec.defines_symbols) == 1
316+
assert 'random_call' in routine.spec.defines_symbols
317+
318+
284319
@pytest.mark.parametrize('frontend', available_frontends())
285320
def test_analyse_enriched_call(frontend):
286321
fcode = """
@@ -429,26 +464,30 @@ def test_analyse_call_args_array_slicing(frontend):
429464
430465
end subroutine random_call
431466
432-
subroutine test(v,n)
467+
subroutine test(v,n,b)
433468
implicit none
434469
435470
integer,intent(out) :: v(:)
436471
integer,intent( in) :: n
472+
integer,intent( in) :: b(n)
437473
438474
call random_call(v(n))
475+
call random_call(v(b(1)))
439476
440477
end subroutine test
441478
""".strip()
442479

443480
source = Sourcefile.from_source(fcode, frontend=frontend)
444481
routine = source['test']
445482

446-
call = FindNodes(CallStatement).visit(routine.body)[0]
483+
calls = FindNodes(CallStatement).visit(routine.body)
447484
routine.enrich(source.all_subroutines)
448485

449486
with dataflow_analysis_attached(routine):
450-
assert 'n' in call.uses_symbols
451-
assert not 'n' in call.defines_symbols
487+
assert 'n' in calls[0].uses_symbols
488+
assert not 'n' in calls[0].defines_symbols
489+
assert 'b' in calls[1].uses_symbols
490+
assert not 'b' in calls[0].defines_symbols
452491

453492

454493
@pytest.mark.parametrize('frontend', available_frontends())

loki/transformations/data_offload.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ def transform_subroutine(self, routine, **kwargs):
296296
var for var in routine.body.uses_symbols
297297
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
298298
}
299+
uses_imported_symbols |= {
300+
var for var in routine.spec.uses_symbols
301+
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
302+
}
299303
defines_imported_symbols = {
300304
var for var in routine.body.defines_symbols
301305
if var.name in import_map or (var.parent and var.parents[0].name in import_map)

loki/transformations/tests/test_data_offload.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def fixture_global_variable_analysis_code(tmp_path):
256256
integer, parameter :: nval = 5
257257
integer, parameter :: nfld = 3
258258
259+
integer :: n
260+
259261
integer :: iarr(nfld)
260262
real :: rarr(nval, nfld)
261263
end module global_var_analysis_header_mod
@@ -297,10 +299,11 @@ def fixture_global_variable_analysis_code(tmp_path):
297299
298300
contains
299301
subroutine kernel_a(arg, tt)
300-
use global_var_analysis_header_mod, only: iarr, nval, nfld
302+
use global_var_analysis_header_mod, only: iarr, nval, nfld, n
301303
302304
real, intent(inout) :: arg(:,:)
303305
type(some_type), intent(in) :: tt
306+
real :: tmp(n)
304307
integer :: i, j
305308
306309
do i=1,nfld
@@ -390,7 +393,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
390393

391394
expected_trafo_data = {
392395
'global_var_analysis_header_mod': {
393-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
396+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
394397
'offload': {}
395398
},
396399
'global_var_analysis_data_mod': {
@@ -402,6 +405,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
402405
'defines_symbols': set(),
403406
'uses_symbols': nval_data | nfld_data | {
404407
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
408+
('n', 'global_var_analysis_header_mod'),
405409
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
406410
}
407411
},
@@ -416,6 +420,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
416420
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
417421
'uses_symbols': nval_data | nfld_data | {
418422
('rdata(:, :, :)', 'global_var_analysis_data_mod'),
423+
('n', 'global_var_analysis_header_mod'),
419424
('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'),
420425
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
421426
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
@@ -465,8 +470,8 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
465470

466471
expected_trafo_data = {
467472
'global_var_analysis_header_mod': {
468-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
469-
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}
473+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
474+
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'}
470475
},
471476
'global_var_analysis_data_mod': {
472477
'declares': {'rdata(:, :, :)', 'tt'},
@@ -486,7 +491,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
486491

487492
# Verify imports have been added to the driver
488493
expected_imports = {
489-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
494+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
490495
'global_var_analysis_data_mod': {'rdata'}
491496
}
492497

@@ -495,7 +500,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
495500
assert {var.name.lower() for var in import_.symbols} == expected_imports[import_.module.lower()]
496501

497502
expected_h2d_pragmas = {
498-
'update device': {'iarr', 'rdata', 'rarr'},
503+
'update device': {'iarr', 'rdata', 'rarr', 'n'},
499504
'enter data copyin': {'tt%vals'}
500505
}
501506
expected_d2h_pragmas = {
@@ -515,7 +520,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
515520

516521
# Verify declarations have been added to the header modules
517522
expected_declarations = {
518-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
523+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
519524
'global_var_analysis_data_mod': {'rdata', 'tt'}
520525
}
521526

0 commit comments

Comments
 (0)