Skip to content

Commit 84eea6c

Browse files
authored
Merge pull request #295 from ecmwf-ifs/naan-derived-type-inheritance
Derived-type inheritance
2 parents 9ba0b96 + 7650790 commit 84eea6c

File tree

2 files changed

+123
-1
lines changed

2 files changed

+123
-1
lines changed

loki/ir/nodes.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1529,13 +1529,31 @@ def __post_init__(self, parent=None):
15291529
def ir(self):
15301530
return self.body
15311531

1532+
@property
1533+
def parent_type(self):
1534+
if not self.extends:
1535+
return None
1536+
if not self.parent:
1537+
return BasicType.DEFERRED
1538+
return self.parent.symbol_attrs[self.extends].dtype.typedef
1539+
15321540
@property
15331541
def declarations(self):
1534-
return tuple(
1542+
decls = tuple(
15351543
c for c in as_tuple(self.body)
15361544
if isinstance(c, (VariableDeclaration, ProcedureDeclaration))
15371545
)
15381546

1547+
# Inherit non-overriden symbols from parent type
1548+
if (parent_type := self.parent_type) and parent_type is not BasicType.DEFERRED:
1549+
local_symbols = [s for decl in decls for s in decl.symbols]
1550+
for decl in parent_type.declarations:
1551+
decl_symbols = tuple(s.clone(scope=self) for s in decl.symbols if s not in local_symbols)
1552+
if decl_symbols:
1553+
decls += (decl.clone(symbols=decl_symbols),)
1554+
1555+
return decls
1556+
15391557
@property
15401558
def comments(self):
15411559
return tuple(c for c in as_tuple(self.body) if isinstance(c, Comment))

loki/tests/test_derived_types.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# granted to it by virtue of its status as an intergovernmental organisation
66
# nor does it submit to any jurisdiction.
77

8+
# pylint: disable=too-many-lines
89
from sys import getrecursionlimit
910
from inspect import stack
1011

@@ -1415,3 +1416,106 @@ def test_derived_types_abstract_deferred_procedure(frontend):
14151416

14161417
assert typedef.imported_symbols == ()
14171418
assert not typedef.imported_symbol_map
1419+
1420+
1421+
@pytest.mark.parametrize('frontend', available_frontends())
1422+
def test_derived_type_symbol_inheritance(frontend):
1423+
fcode = """
1424+
module some_mod
1425+
implicit none
1426+
type :: base_type
1427+
integer :: memberA
1428+
real :: memberB
1429+
contains
1430+
procedure :: init => init_base_type
1431+
procedure :: final => final_base_type
1432+
procedure :: copy
1433+
end type base_type
1434+
1435+
type, extends(base_type) :: extended_type
1436+
integer :: memberC
1437+
contains
1438+
procedure :: init => init_extended_type
1439+
procedure :: final => final_extended_type
1440+
procedure :: do_something
1441+
end type extended_type
1442+
1443+
type, extends(extended_type) :: extended_extended_type
1444+
integer :: memberD
1445+
contains
1446+
procedure :: init => init_extended_extended_type
1447+
procedure :: final => final_extended_extended_type
1448+
procedure :: do_something => do_something_else
1449+
end type extended_extended_type
1450+
1451+
contains
1452+
1453+
subroutine init_base_type(self)
1454+
class(base_type) :: self
1455+
end subroutine init_base_type
1456+
subroutine final_base_type(self)
1457+
class(base_type) :: self
1458+
end subroutine final_base_type
1459+
subroutine copy(self)
1460+
class(base_type) :: self
1461+
end subroutine copy
1462+
1463+
subroutine init_extended_type(self)
1464+
class(extended_type) :: self
1465+
end subroutine init_extended_type
1466+
subroutine final_extended_type(self)
1467+
class(extended_type) :: self
1468+
end subroutine final_extended_type
1469+
subroutine do_something(self)
1470+
class(extended_type) :: self
1471+
end subroutine do_something
1472+
1473+
subroutine init_extended_extended_type(self)
1474+
class(extended_extended_type) :: self
1475+
end subroutine init_extended_extended_type
1476+
subroutine final_extended_extended_type(self)
1477+
class(extended_extended_type) :: self
1478+
end subroutine final_extended_extended_type
1479+
subroutine do_something_else(self)
1480+
class(extended_extended_type) :: self
1481+
end subroutine do_something_else
1482+
end module some_mod
1483+
""".strip()
1484+
1485+
module = Module.from_source(fcode, frontend=frontend)
1486+
1487+
base_type = module['base_type']
1488+
extended_type = module['extended_type']
1489+
extended_extended_type = module['extended_extended_type']
1490+
1491+
assert base_type.variables == ('memberA', 'memberB', 'init', 'final', 'copy')
1492+
assert base_type.variables[2].type.bind_names[0] == 'init_base_type'
1493+
assert base_type.variables[3].type.bind_names[0] == 'final_base_type'
1494+
assert not base_type.variables[4].type.bind_names
1495+
assert all(s.scope is base_type for d in base_type.declarations for s in d.symbols)
1496+
assert base_type.imported_symbols == ()
1497+
assert not base_type.imported_symbol_map
1498+
1499+
assert extended_type.variables == ('memberC', 'init', 'final', 'do_something', 'memberA', 'memberB', 'copy')
1500+
assert extended_type.variables[1].type.bind_names[0] == 'init_extended_type'
1501+
assert extended_type.variables[2].type.bind_names[0] == 'final_extended_type'
1502+
assert not extended_type.variables[3].type.bind_names
1503+
assert not extended_type.variables[6].type.bind_names
1504+
assert all(s.scope is extended_type for d in extended_type.declarations for s in d.symbols)
1505+
assert extended_type.imported_symbols == ()
1506+
assert not extended_type.imported_symbol_map
1507+
#check for non-empty declarations
1508+
assert all(decl.symbols for decl in extended_type.declarations)
1509+
1510+
1511+
assert extended_extended_type.variables == ('memberD', 'init', 'final', 'do_something', 'memberC',
1512+
'memberA', 'memberB', 'copy')
1513+
assert extended_extended_type.variables[1].type.bind_names[0] == 'init_extended_extended_type'
1514+
assert extended_extended_type.variables[2].type.bind_names[0] == 'final_extended_extended_type'
1515+
assert extended_extended_type.variables[3].type.bind_names[0] == 'do_something_else'
1516+
assert not extended_extended_type.variables[7].type.bind_names
1517+
assert all(s.scope is extended_extended_type for d in extended_extended_type.declarations for s in d.symbols)
1518+
assert extended_extended_type.imported_symbols == ()
1519+
assert not extended_extended_type.imported_symbol_map
1520+
#check for non-empty declarations
1521+
assert all(decl.symbols for decl in extended_extended_type.declarations)

0 commit comments

Comments
 (0)