1111import pytest
1212
1313from loki .frontend import available_frontends , OMNI
14- from loki import Sourcefile , FindNodes , CallStatement , fgen , Conditional
14+ from loki import Sourcefile , FindNodes , CallStatement , fgen , Conditional , ProcedureItem
1515
1616from transformations .parallel_routine_dispatch import ParallelRoutineDispatchTransformation
1717
@@ -25,13 +25,14 @@ def fixture_here():
2525def test_parallel_routine_dispatch_dr_hook (here , frontend ):
2626
2727 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
28+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
2829 routine = source ['dispatch_routine' ]
2930
3031 calls = FindNodes (CallStatement ).visit (routine .body )
3132 assert len (calls ) == 3
3233
3334 transformation = ParallelRoutineDispatchTransformation ()
34- transformation .apply (source ['dispatch_routine' ])
35+ transformation .apply (source ['dispatch_routine' ], item = item )
3536
3637 calls = [call for call in FindNodes (CallStatement ).visit (routine .body ) if call .name .name == 'DR_HOOK' ]
3738 assert len (calls ) == 8
@@ -40,10 +41,11 @@ def test_parallel_routine_dispatch_dr_hook(here, frontend):
4041def test_parallel_routine_dispatch_decl_local_arrays (here , frontend ):
4142
4243 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
44+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
4345 routine = source ['dispatch_routine' ]
4446
4547 transformation = ParallelRoutineDispatchTransformation ()
46- transformation .apply (source ['dispatch_routine' ])
48+ transformation .apply (source ['dispatch_routine' ], item = item )
4749 var_lst = ["YL_ZRDG_CVGQ" , "ZRDG_CVGQ" , "YL_ZRDG_MU0LU" , "ZRDG_MU0LU" , "YL_ZRDG_MU0M" , "ZRDG_MU0M" , "YL_ZRDG_MU0N" , "ZRDG_MU0N" , "YL_ZRDG_MU0" , "ZRDG_MU0" ]
4850 dcls = [dcl for dcl in routine .declarations if dcl .symbols [0 ].name in var_lst ]
4951 str_dcls = ""
@@ -65,10 +67,11 @@ def test_parallel_routine_dispatch_decl_local_arrays(here, frontend):
6567def test_parallel_routine_dispatch_decl_field_create_delete (here , frontend ):
6668
6769 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
70+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
6871 routine = source ['dispatch_routine' ]
6972
7073 transformation = ParallelRoutineDispatchTransformation ()
71- transformation .apply (source ['dispatch_routine' ])
74+ transformation .apply (source ['dispatch_routine' ], item = item )
7275
7376 var_lst = ["YL_ZRDG_CVGQ" , "ZRDG_CVGQ" , "YL_ZRDG_MU0LU" , "ZRDG_MU0LU" , "YL_ZRDG_MU0M" , "ZRDG_MU0M" , "YL_ZRDG_MU0N" , "ZRDG_MU0N" , "YL_ZRDG_MU0" , "ZRDG_MU0" ]
7477 field_create = ["CALL FIELD_NEW(YL_ZRDG_CVGQ, UBOUNDS=(/ YDCPG_OPTS%KLON, YDCPG_OPTS%KFLEVG, YDCPG_OPTS%KGPBLKS /), LBOUNDS=(/ 0, 1 /), &\n & PERSISTENT=.true.)" ,
@@ -105,10 +108,11 @@ def test_parallel_routine_dispatch_decl_field_create_delete(here, frontend):
105108def test_parallel_routine_dispatch_derived_dcl (here , frontend ):
106109
107110 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
111+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
108112 routine = source ['dispatch_routine' ]
109113
110114 transformation = ParallelRoutineDispatchTransformation ()
111- transformation .apply (source ['dispatch_routine' ])
115+ transformation .apply (source ['dispatch_routine' ], item = item )
112116
113117 dcls = [fgen (dcl ) for dcl in routine .spec .body [- 13 :- 1 ]]
114118
@@ -132,10 +136,11 @@ def test_parallel_routine_dispatch_derived_dcl(here, frontend):
132136def test_parallel_routine_dispatch_derived_var (here , frontend ):
133137
134138 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
139+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
135140 routine = source ['dispatch_routine' ]
136141
137142 transformation = ParallelRoutineDispatchTransformation ()
138- transformation .apply (source ['dispatch_routine' ])
143+ transformation .apply (source ['dispatch_routine' ], item = item )
139144
140145
141146 test_map = {
@@ -153,8 +158,9 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
153158 "YDCPG_DYN0%CTY%EVEL" : ["YDCPG_DYN0%CTY%F_EVEL" , "Z_YDCPG_DYN0_CTY_EVEL" ],
154159 "YDMF_PHYS_SURF%GSD_VF%PZ0F" : ["YDMF_PHYS_SURF%GSD_VF%F_Z0F" , "Z_YDMF_PHYS_SURF_GSD_VF_PZ0F" ]
155160 }
156- for var_name in transformation .routine_map_derived :
157- value = transformation .routine_map_derived [var_name ]
161+ routine_map_derived = item .trafo_data ['create_parallel' ]['map_routine' ]['routine_map_derived' ]
162+ for var_name in routine_map_derived :
163+ value = routine_map_derived [var_name ]
158164 field_ptr = value [0 ]
159165 ptr = value [1 ]
160166
@@ -165,12 +171,13 @@ def test_parallel_routine_dispatch_derived_var(here, frontend):
165171def test_parallel_routine_dispatch_get_data (here , frontend ):
166172
167173 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
174+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
168175 routine = source ['dispatch_routine' ]
169176
170177 transformation = ParallelRoutineDispatchTransformation ()
171- transformation .apply (source ['dispatch_routine' ])
178+ transformation .apply (source ['dispatch_routine' ], item = item )
172179
173- get_data = transformation . get_data
180+ get_data = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' get_data' ]
174181
175182 test_get_data = {}
176183# test_get_data["OpenMP"] = """
@@ -276,7 +283,7 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
276283### routine = source['dispatch_routine']
277284###
278285### transformation = ParallelRoutineDispatchTransformation()
279- ### transformation.apply(source['dispatch_routine'])
286+ ### transformation.apply(source['dispatch_routine'], item=item )
280287###
281288### get_data = transformation.get_data
282289###
@@ -294,12 +301,13 @@ def test_parallel_routine_dispatch_get_data(here, frontend):
294301def test_parallel_routine_dispatch_synchost (here , frontend ):
295302
296303 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
304+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
297305 routine = source ['dispatch_routine' ]
298306
299307 transformation = ParallelRoutineDispatchTransformation ()
300- transformation .apply (source ['dispatch_routine' ])
308+ transformation .apply (source ['dispatch_routine' ], item = item )
301309
302- synchost = transformation . synchost [ 0 ]
310+ synchost = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ 'synchost' ]
303311
304312 test_synchost = """IF (LSYNCHOST('DISPATCH_ROUTINE:CPPHINP')) THEN
305313 IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:SYNCHOST', 0, ZHOOK_HANDLE_FIELD_API)
@@ -332,12 +340,13 @@ def test_parallel_routine_dispatch_synchost(here, frontend):
332340def test_parallel_routine_dispatch_nullify (here , frontend ):
333341
334342 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
343+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
335344 routine = source ['dispatch_routine' ]
336345
337346 transformation = ParallelRoutineDispatchTransformation ()
338- transformation .apply (source ['dispatch_routine' ])
347+ transformation .apply (source ['dispatch_routine' ], item = item )
339348
340- nullify = transformation . nullify
349+ nullify = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' nullify' ]
341350
342351 test_nullify = """
343352IF (LHOOK) CALL DR_HOOK('DISPATCH_ROUTINE:CPPHINP:NULLIFY', 0, ZHOOK_HANDLE_FIELD_API)
@@ -370,12 +379,13 @@ def test_parallel_routine_dispatch_nullify(here, frontend):
370379def test_parallel_routine_dispatch_compute_openmp (here , frontend ):
371380
372381 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
382+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
373383 routine = source ['dispatch_routine' ]
374384
375385 transformation = ParallelRoutineDispatchTransformation ()
376- transformation .apply (source ['dispatch_routine' ])
386+ transformation .apply (source ['dispatch_routine' ], item = item )
377387
378- map_compute = transformation . compute
388+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
379389 compute_openmp = map_compute ['OpenMP' ]
380390
381391 test_compute = """
@@ -423,12 +433,13 @@ def test_parallel_routine_dispatch_compute_openmp(here, frontend):
423433def test_parallel_routine_dispatch_compute_openmpscc (here , frontend ):
424434
425435 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
436+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
426437 routine = source ['dispatch_routine' ]
427438
428439 transformation = ParallelRoutineDispatchTransformation ()
429- transformation .apply (source ['dispatch_routine' ])
440+ transformation .apply (source ['dispatch_routine' ], item = item )
430441
431- map_compute = transformation . compute
442+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
432443 compute_openmpscc = map_compute ['OpenMPSingleColumn' ]
433444
434445 test_compute = """
@@ -490,12 +501,13 @@ def test_parallel_routine_dispatch_compute_openmpscc(here, frontend):
490501def test_parallel_routine_dispatch_compute_openaccscc (here , frontend ):
491502
492503 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
504+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
493505 routine = source ['dispatch_routine' ]
494506
495507 transformation = ParallelRoutineDispatchTransformation ()
496- transformation .apply (source ['dispatch_routine' ])
508+ transformation .apply (source ['dispatch_routine' ], item = item )
497509
498- map_compute = transformation . compute
510+ map_compute = item . trafo_data [ 'create_parallel' ][ 'map_region' ][ ' compute' ]
499511 compute_openaccscc = map_compute ['OpenACCSingleColumn' ]
500512
501513 test_compute = """
@@ -576,12 +588,13 @@ def test_parallel_routine_dispatch_compute_openaccscc(here, frontend):
576588def test_parallel_routine_dispatch_variables (here , frontend ):
577589
578590 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
591+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
579592 routine = source ['dispatch_routine' ]
580593
581594 transformation = ParallelRoutineDispatchTransformation ()
582- transformation .apply (source ['dispatch_routine' ])
595+ transformation .apply (source ['dispatch_routine' ], item = item )
583596
584- variables = transformation . dcls
597+ variables = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' dcls' ]
585598
586599 test_variables = '''TYPE(CPG_BNDS_TYPE), INTENT(IN) :: YLCPG_BNDS
587600TYPE(STACK) :: YLSTACK
@@ -597,12 +610,13 @@ def test_parallel_routine_dispatch_imports(here, frontend):
597610 #TODO : add imports to _parallel routines
598611
599612 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
613+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
600614 routine = source ['dispatch_routine' ]
601615
602616 transformation = ParallelRoutineDispatchTransformation ()
603- transformation .apply (source ['dispatch_routine' ])
617+ transformation .apply (source ['dispatch_routine' ], item = item )
604618
605- imports = transformation . imports
619+ imports = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' imports' ]
606620
607621 test_imports = """
608622USE ACPY_MOD
@@ -621,11 +635,12 @@ def test_parallel_routine_dispatch_new_callee_imports(here, frontend):
621635 #TODO : add imports to _parallel routines
622636
623637 source = Sourcefile .from_file (here / 'sources/projParallelRoutineDispatch/dispatch_routine.F90' , frontend = frontend )
638+ item = ProcedureItem (name = 'parallel_routine_dispatch' , source = source )
624639 routine = source ['dispatch_routine' ]
625640
626641 transformation = ParallelRoutineDispatchTransformation ()
627- transformation .apply (source ['dispatch_routine' ])
642+ transformation .apply (source ['dispatch_routine' ], item = item )
628643
629- imports = transformation . callee_imports
644+ imports = item . trafo_data [ 'create_parallel' ][ 'map_routine' ][ ' callee_imports' ]
630645
631646 assert fgen (imports ) == '#include "cpphinp_openacc.intfb.h"'
0 commit comments