@@ -557,6 +557,11 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
557557 subroutine kernel(start, end, klon, klev, nclv, field1, field2)
558558 use, intrinsic :: iso_c_binding, only : c_size_t
559559 implicit none
560+ interface
561+ subroutine another_kernel(klev)
562+ integer, intent(in) :: klev
563+ end subroutine another_kernel
564+ end interface
560565 integer, parameter :: jprb = selected_real_kind(13,300)
561566 integer, intent(in) :: nclv
562567 integer, intent(in) :: start, end, klon, klev
@@ -582,19 +587,38 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
582587 tmp5(jl, jm, :) = field1(jl)
583588 enddo
584589 enddo
590+
591+ call another_kernel(klev)
585592 end subroutine kernel
586593end module kernel_mod
587594 """ .strip ()
595+ fcode_mod = """
596+ module size_mod
597+ implicit none
598+ integer :: n
599+ end module size_mod
600+ """ .strip ()
601+ fcode_another_kernel = """
602+ subroutine another_kernel(klev)
603+ use size_mod, only : n
604+ implicit none
605+ integer, intent(in) :: klev
606+ real :: another_tmp(klev,n)
607+ end subroutine another_kernel
608+ """ .strip ()
588609
589610 (tmp_path / 'driver.F90' ).write_text (fcode_driver )
590611 (tmp_path / 'kernel_mod.F90' ).write_text (fcode_kernel )
612+ (tmp_path / 'size_mod.F90' ).write_text (fcode_mod )
613+ (tmp_path / 'another_kernel.F90' ).write_text (fcode_another_kernel )
591614
592615 config = {
593616 'default' : {
594617 'mode' : 'idem' ,
595618 'role' : 'kernel' ,
596619 'expand' : True ,
597- 'strict' : True
620+ 'strict' : True ,
621+ 'enable_imports' : True
598622 },
599623 'routines' : {
600624 'driver' : {'role' : 'driver' }
@@ -613,11 +637,12 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
613637 driver_variables = (
614638 'jprb' , 'nlon' , 'nz' , 'nb' , 'b' ,
615639 'field1(nlon, nb)' , 'field2(nlon, nz, nb)' ,
616- 'kernel_tmp2(:,:)' , 'kernel_tmp5(:,:,:)'
640+ 'kernel_tmp2(:,:)' , 'kernel_tmp5(:,:,:)' , 'another_kernel_another_tmp(:,:)'
617641 )
618642 kernel_arguments = (
619643 'start' , 'end' , 'klon' , 'klev' , 'nclv' ,
620- 'field1(klon)' , 'field2(klon,klev)' , 'tmp2(klon,klev)' , 'tmp5(klon,nclv,klev)'
644+ 'field1(klon)' , 'field2(klon,klev)' , 'tmp2(klon,klev)' , 'tmp5(klon,nclv,klev)' ,
645+ 'another_kernel_another_tmp(klev,n)'
621646 )
622647
623648 # Check hoisting and declaration in driver
@@ -629,8 +654,17 @@ def test_hoist_mixed_variable_declarations(tmp_path, frontend, config):
629654 assert len (calls ) == 1
630655 assert calls [0 ].arguments == (
631656 '1' , 'nlon' , 'nlon' , 'nz' , '2' , 'field1(:,b)' , 'field2(:,:,b)' ,
632- 'kernel_tmp2' , 'kernel_tmp5'
657+ 'kernel_tmp2' , 'kernel_tmp5' , 'another_kernel_another_tmp'
633658 )
634659
635660 # Check that fgen works
636661 assert scheduler ['kernel_mod#kernel' ].source .to_fortran ()
662+
663+ # Check that imports were updated
664+ imports = FindNodes (ir .Import ).visit (scheduler ['kernel_mod#kernel' ].ir .spec )
665+ assert len (imports ) == 2
666+ assert 'n' in scheduler ['kernel_mod#kernel' ].ir .imported_symbols
667+
668+ imports = FindNodes (ir .Import ).visit (scheduler ['#driver' ].ir .spec )
669+ assert len (imports ) == 2
670+ assert 'n' in scheduler ['#driver' ].ir .imported_symbols
0 commit comments