@@ -1127,6 +1127,147 @@ def test_single_column_coalesced_hoist_nested_openacc(frontend, horizontal, vert
11271127 assert outer_kernel_pragmas [2 ].keyword == 'acc'
11281128 assert outer_kernel_pragmas [2 ].content == 'end data'
11291129
1130+
@pytest.mark.parametrize('frontend', available_frontends())
def test_single_column_coalesced_hoist_nested_inline_openacc(frontend, horizontal, vertical, blocking):
    """
    Test the correct addition of OpenACC pragmas to SCC format code
    when hoisting array temporaries to driver, in the case where the
    inner kernel is marked with ``!$loki inline`` and is inlined into
    the outer kernel before the SCC hoist pipeline is applied.
    """

    fcode_driver = """
  SUBROUTINE column_driver(nlon, nz, q, nb)
    INTEGER, INTENT(IN) :: nlon, nz, nb ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz,nb)
    INTEGER :: b, start, end

    start = 1
    end = nlon
    do b=1, nb
      call compute_column(start, end, nlon, nz, q(:,:,b))
    end do
  END SUBROUTINE column_driver
"""

    fcode_outer_kernel = """
  SUBROUTINE compute_column(start, end, nlon, nz, q)
    INTEGER, INTENT(IN) :: start, end ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    INTEGER :: jl, jk
    REAL :: c

    c = 5.345
    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) + 1.0
    END DO

    !$loki inline
    call update_q(start, end, nlon, nz, q, c)

    DO JL = START, END
      Q(JL, NZ) = Q(JL, NZ) * C
    END DO
  END SUBROUTINE compute_column
"""

    fcode_inner_kernel = """
  SUBROUTINE update_q(start, end, nlon, nz, q, c)
    INTEGER, INTENT(IN) :: start, end ! Iteration indices
    INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
    REAL, INTENT(INOUT) :: q(nlon,nz)
    REAL, INTENT(IN) :: c
    REAL :: t(nlon,nz)
    INTEGER :: jl, jk

    DO jk = 2, nz
      DO jl = start, end
        t(jl, jk) = c * jk
        q(jl, jk) = q(jl, jk-1) + t(jl, jk) * c
      END DO
    END DO
  END SUBROUTINE update_q
"""

    # Mimic the scheduler internal mechanism to apply the transformation cascade
    outer_kernel_source = Sourcefile.from_source(fcode_outer_kernel, frontend=frontend)
    inner_kernel_source = Sourcefile.from_source(fcode_inner_kernel, frontend=frontend)
    driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend)
    driver = driver_source['column_driver']
    outer_kernel = outer_kernel_source['compute_column']
    inner_kernel = inner_kernel_source['update_q']
    outer_kernel.enrich(inner_kernel)  # Attach inner kernel source to outer kernel call
    driver.enrich(outer_kernel)  # Attach outer kernel source to driver call

    driver_item = ProcedureItem(name='#column_driver', source=driver)
    outer_kernel_item = ProcedureItem(name='#compute_column', source=outer_kernel)
    inner_kernel_item = ProcedureItem(name='#update_q', source=inner_kernel)

    scc_hoist = SCCHoistPipeline(
        horizontal=horizontal, block_dim=blocking,
        dim_vars=(vertical.size,), directive='openacc'
    )

    # Inline the pragma-marked call to ``update_q`` into the outer kernel
    # before any SCC processing takes place
    InlineTransformation(allowed_aliases=horizontal.index).apply(outer_kernel)

    # Apply in reverse order to ensure hoisting analysis gets run on kernel first
    scc_hoist.apply(inner_kernel, role='kernel', item=inner_kernel_item)
    # ``update_q`` is the only routine called from the outer kernel; it has
    # been inlined above, so no successor items are handed to the pipeline
    scc_hoist.apply(
        outer_kernel, role='kernel', item=outer_kernel_item,
        targets=['update_q'], successors=()
    )
    scc_hoist.apply(
        driver, role='driver', item=driver_item,
        targets=['compute_column'], successors=(outer_kernel_item,)
    )

    # Ensure calls have correct arguments
    # driver: the hoisted temporary is passed as an additional argument
    calls = FindNodes(CallStatement).visit(driver.body)
    assert len(calls) == 1
    assert calls[0].arguments == ('start', 'end', 'nlon', 'nz', 'q(:, :, b)',
                                  'compute_column_t(:, :, b)')

    # Ensure a single outer parallel loop in driver
    with pragmas_attached(driver, Loop):
        driver_loops = FindNodes(Loop).visit(driver.body)
        assert len(driver_loops) == 1
        assert driver_loops[0].variable == 'b'
        assert driver_loops[0].bounds == '1:nb'
        assert driver_loops[0].pragma[0].keyword == 'acc'
        assert driver_loops[0].pragma[0].content == 'parallel loop gang vector_length(nlon)'

        # Ensure we have a kernel call in the driver loop
        kernel_calls = FindNodes(CallStatement).visit(driver_loops[0])
        assert len(kernel_calls) == 1
        assert kernel_calls[0].name == 'compute_column'

    # Ensure that the outer kernel contains exactly two loops: a single
    # vector-annotated horizontal loop with the second loop nested inside it
    with pragmas_attached(outer_kernel, Loop):
        outer_kernel_loops = FindNodes(Loop).visit(outer_kernel.body)
        assert len(outer_kernel_loops) == 2
        assert outer_kernel_loops[0].variable == 'jl'
        assert outer_kernel_loops[0].bounds == 'start:end'
        assert outer_kernel_loops[0].pragma[0].keyword == 'acc'
        assert outer_kernel_loops[0].pragma[0].content == 'loop vector'

        # check correctly nested vertical loop from inlined routine
        assert outer_kernel_loops[1] in FindNodes(Loop).visit(outer_kernel_loops[0].body)

    # Ensure the call was inlined
    assert not FindNodes(CallStatement).visit(outer_kernel.body)

    # Ensure the routine has been marked properly
    outer_kernel_pragmas = FindNodes(Pragma).visit(outer_kernel.ir)
    assert len(outer_kernel_pragmas) == 3
    assert outer_kernel_pragmas[0].keyword == 'acc'
    assert outer_kernel_pragmas[0].content == 'routine vector'
    assert outer_kernel_pragmas[1].keyword == 'acc'
    assert outer_kernel_pragmas[1].content == 'data present(q, t)'
    assert outer_kernel_pragmas[2].keyword == 'acc'
    assert outer_kernel_pragmas[2].content == 'end data'
1269+
1270+
11301271@pytest .mark .parametrize ('frontend' , available_frontends ())
11311272def test_single_column_coalesced_nested (frontend , horizontal , blocking ):
11321273 """
0 commit comments