Skip to content

Commit 6c6e661

Browse files
committed
GlobalVarOffload: globals used to declare array sizes now also offloaded
1 parent 68aa07f commit 6c6e661

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

loki/analyse/analyse_dataflow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,9 @@ def visit_Import(self, o, **kwargs):
275275

276276
def visit_VariableDeclaration(self, o, **kwargs):
277277
defines = self._symbols_from_expr(o.symbols, condition=lambda v: v.type.initial is not None)
278-
return self.visit_Node(o, defines_symbols=defines, **kwargs)
278+
arrays = [a for a in o.symbols if isinstance(a, Array)]
279+
uses = set(v for a in arrays for v in FindVariables().visit(a.dimensions))
280+
return self.visit_Node(o, defines_symbols=defines, uses_symbols=uses, **kwargs)
279281

280282

281283
class DataflowAnalysisDetacher(Transformer):

loki/transformations/data_offload.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ def transform_subroutine(self, routine, **kwargs):
296296
var for var in routine.body.uses_symbols
297297
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
298298
}
299+
uses_imported_symbols |= {
300+
var for var in routine.spec.uses_symbols
301+
if var.name in import_map or (var.parent and var.parents[0].name in import_map)
302+
}
299303
defines_imported_symbols = {
300304
var for var in routine.body.defines_symbols
301305
if var.name in import_map or (var.parent and var.parents[0].name in import_map)

loki/transformations/tests/test_data_offload.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def fixture_global_variable_analysis_code(tmp_path):
256256
integer, parameter :: nval = 5
257257
integer, parameter :: nfld = 3
258258
259+
integer :: n
260+
259261
integer :: iarr(nfld)
260262
real :: rarr(nval, nfld)
261263
end module global_var_analysis_header_mod
@@ -297,10 +299,11 @@ def fixture_global_variable_analysis_code(tmp_path):
297299
298300
contains
299301
subroutine kernel_a(arg, tt)
300-
use global_var_analysis_header_mod, only: iarr, nval, nfld
302+
use global_var_analysis_header_mod, only: iarr, nval, nfld, n
301303
302304
real, intent(inout) :: arg(:,:)
303305
type(some_type), intent(in) :: tt
306+
real :: tmp(n)
304307
integer :: i, j
305308
306309
do i=1,nfld
@@ -390,7 +393,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
390393

391394
expected_trafo_data = {
392395
'global_var_analysis_header_mod': {
393-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
396+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
394397
'offload': {}
395398
},
396399
'global_var_analysis_data_mod': {
@@ -402,6 +405,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
402405
'defines_symbols': set(),
403406
'uses_symbols': nval_data | nfld_data | {
404407
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
408+
('n', 'global_var_analysis_header_mod'),
405409
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
406410
}
407411
},
@@ -416,6 +420,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
416420
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
417421
'uses_symbols': nval_data | nfld_data | {
418422
('rdata(:, :, :)', 'global_var_analysis_data_mod'),
423+
('n', 'global_var_analysis_header_mod'),
419424
('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'),
420425
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
421426
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
@@ -428,6 +433,8 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
428433
if item == 'global_var_analysis_data_mod#some_type':
429434
continue
430435
for trafo_data_key, trafo_data_value in item.trafo_data[key].items():
436+
print(item)
437+
print(trafo_data_value)
431438
assert (
432439
sorted(
433440
tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v)
@@ -465,8 +472,8 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
465472

466473
expected_trafo_data = {
467474
'global_var_analysis_header_mod': {
468-
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
469-
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}
475+
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'},
476+
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})', 'n'}
470477
},
471478
'global_var_analysis_data_mod': {
472479
'declares': {'rdata(:, :, :)', 'tt'},
@@ -486,7 +493,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
486493

487494
# Verify imports have been added to the driver
488495
expected_imports = {
489-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
496+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
490497
'global_var_analysis_data_mod': {'rdata'}
491498
}
492499

@@ -495,7 +502,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
495502
assert {var.name.lower() for var in import_.symbols} == expected_imports[import_.module.lower()]
496503

497504
expected_h2d_pragmas = {
498-
'update device': {'iarr', 'rdata', 'rarr'},
505+
'update device': {'iarr', 'rdata', 'rarr', 'n'},
499506
'enter data copyin': {'tt%vals'}
500507
}
501508
expected_d2h_pragmas = {
@@ -515,7 +522,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
515522

516523
# Verify declarations have been added to the header modules
517524
expected_declarations = {
518-
'global_var_analysis_header_mod': {'iarr', 'rarr'},
525+
'global_var_analysis_header_mod': {'iarr', 'rarr', 'n'},
519526
'global_var_analysis_data_mod': {'rdata', 'tt'}
520527
}
521528

0 commit comments

Comments
 (0)