Skip to content

Commit 78d9f45

Browse files
committed
PR fixes to hoist/inline trafos
1 parent 9836069 commit 78d9f45

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

loki/transformations/hoist_variables.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,9 @@ def transform_subroutine(self, routine, **kwargs):
142142
variables = self.find_variables(routine)
143143
item.trafo_data[self._key]["to_hoist"] = variables
144144
dims = flatten([getattr(v, 'shape', []) for v in variables])
145+
import_map = routine.import_map
145146
item.trafo_data[self._key]["imported_sizes"] = [(d.type.module, d) for d in dims
146-
if str(d) in routine.import_map]
147+
if str(d) in import_map]
147148
item.trafo_data[self._key]["hoist_variables"] = [var.clone(name=f'{routine.name}_{var.name}')
148149
for var in variables]
149150
else:
@@ -281,8 +282,9 @@ def transform_subroutine(self, routine, **kwargs):
281282

282283
# Add imports used to define hoisted
283284
missing_imports_map = defaultdict(set)
285+
import_map = routine.import_map
284286
for module, var in item.trafo_data[self._key]["imported_sizes"]:
285-
if not var.name in routine.import_map:
287+
if not var.name in import_map:
286288
missing_imports_map[module] |= {var}
287289

288290
if missing_imports_map:

loki/transformations/inline.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,9 @@ def map_inline_call(self, expr, *args, **kwargs):
199199

200200
def resolve_sequence_association_for_inlined_calls(routine, inline_internals, inline_marked):
201201
"""
202-
Resolve sequence association in calls to all member procedures (if `inline_internals = True`)
203-
or in calls to procedures that have been marked with an inline pragma (if `inline_marked = True`).
204-
If both `inline_internals` and `inline_marked` are `False`, no processing is done.
202+
Resolve sequence association in calls to all member procedures (if ``inline_internals = True``)
203+
or in calls to procedures that have been marked with an inline pragma (if ``inline_marked = True``).
204+
If both ``inline_internals`` and ``inline_marked`` are ``False``, no processing is done.
205205
"""
206206
call_map = {}
207207
with pragmas_attached(routine, node_type=CallStatement):
@@ -636,10 +636,9 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
636636
intf_symbols = routine.interface_symbols
637637
for callee in call_sets.keys():
638638
for intf in callee.interfaces:
639-
for b in intf.body:
640-
s = getattr(b, 'procedure_symbol', None)
639+
for s in intf.symbols:
641640
if not s in intf_symbols:
642-
new_intfs += [b,]
641+
new_intfs += [s.type.dtype.procedure,]
643642

644643
if new_intfs:
645644
routine.spec.append(Interface(body=as_tuple(new_intfs)))

0 commit comments

Comments
 (0)