Skip to content

Commit a60479e

Browse files
committed
parse_expr: allow for nested derived types
1 parent ef5b5cc commit a60479e

File tree

2 files changed

+64
-23
lines changed

2 files changed

+64
-23
lines changed

loki/expression/parser.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def map_slice(self, expr, *args, **kwargs):
122122

123123
def map_variable(self, expr, *args, **kwargs):
124124
parent = kwargs.pop('parent', None)
125-
return sym.Variable(name=expr.name, parent=parent) # , **kwargs)
125+
return sym.Variable(name=expr.name, parent=parent)
126126

127127
def map_algebraic_leaf(self, expr, *args, **kwargs):
128128
if str(expr).isnumeric():
@@ -255,30 +255,39 @@ def map_call_with_kwargs(self, expr):
255255
return self.rec(expr.function)(*args, **kwargs)
256256

257257
def map_lookup(self, expr):
258-
try:
259-
if isinstance(expr.name, pmbl.Variable):
260-
name = expr.name.name
261-
return self.case_insensitive_getattr(self.rec(expr.aggregate), name)
262-
if isinstance(expr.name, pmbl.Call):
263-
name = expr.name.function.name
264-
if callable(self.case_insensitive_getattr(self.rec(expr.aggregate), name)):
265-
return self.case_insensitive_getattr(self.rec(expr.aggregate),
266-
name)(*[self.rec(par) for par in expr.name.parameters])
267-
return self._evaluate_array(self.case_insensitive_getattr(self.rec(expr.aggregate), name),
268-
[self.rec(par) for par in expr.name.parameters])
269-
if isinstance(expr.name, pmbl.CallWithKwargs):
270-
name = expr.name.function.name
271-
args = [self.rec(par) for par in expr.name.parameters]
272-
kwargs = {
273-
k: self.rec(v)
274-
for k, v in expr.name.kw_parameters.items()}
275-
kwargs = CaseInsensitiveDict(kwargs)
276-
return self.case_insensitive_getattr(self.rec(expr.aggregate), name)(*args, **kwargs)
277258

259+
def rec_lookup(expr, obj, name):
260+
return expr.name, self.case_insensitive_getattr(obj, name)
261+
262+
try:
263+
current_expr = expr
264+
obj = self.rec(expr.aggregate)
265+
while isinstance(current_expr.name, pmbl.Lookup):
266+
current_expr, obj = rec_lookup(current_expr, obj, current_expr.name.aggregate.name)
267+
if isinstance(current_expr.name, pmbl.Variable):
268+
_, obj = rec_lookup(current_expr, obj, current_expr.name.name)
269+
return obj
270+
if isinstance(current_expr.name, pmbl.Call):
271+
name = current_expr.name.function.name
272+
_, obj = rec_lookup(current_expr, obj, name)
273+
if callable(obj):
274+
return obj(*[self.rec(par) for par in current_expr.name.parameters])
275+
return self._evaluate_array(obj, [self.rec(par) for par in current_expr.name.parameters])
276+
if isinstance(current_expr.name, pmbl.CallWithKwargs):
277+
name = current_expr.name.function.name
278+
_, obj = rec_lookup(current_expr, obj, name)
279+
args = [self.rec(par) for par in current_expr.name.parameters]
280+
kwargs = CaseInsensitiveDict(
281+
(k, self.rec(v))
282+
for k, v in current_expr.name.kw_parameters.items()
283+
)
284+
return obj(*args, **kwargs)
278285
except Exception as e:
279286
if self.strict:
280287
raise e
281288
return expr
289+
if self.strict:
290+
raise NotImplementedError
282291
return expr
283292

284293

loki/expression/tests/test_expression.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import numpy as np
1414

1515
import pymbolic.primitives as pmbl
16+
import pymbolic.mapper as pmbl_mapper
1617

1718
from loki import (
1819
Sourcefile, Subroutine, Module, Scope, BasicType,
@@ -2109,7 +2110,7 @@ def test_expression_parser_evaluate(case):
21092110
context = {'a': 6}
21102111
test_str = '1 + 1 + a + some_func(a, 10)'
21112112
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2112-
with pytest.raises(Exception):
2113+
with pytest.raises(pmbl_mapper.evaluator.UnknownVariableError):
21132114
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=True, context=context)
21142115
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=False, context=context)
21152116
assert str(parsed).lower().replace(' ', '') == '8+some_func(6,10)'
@@ -2134,7 +2135,23 @@ def some_func(a, b, c=None):
21342135
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
21352136
assert str(parsed).lower().replace(' ', '') == '(13+c+1)/(c+1)'
21362137

2138+
class BarBarBar:
2139+
val_barbarbar = 5
2140+
2141+
class BarBar:
2142+
barbarbar = BarBarBar()
2143+
val_barbar = -3
2144+
def barbar_func(self, a):
2145+
return a - 1
2146+
2147+
class Bar:
2148+
barbar = BarBar()
2149+
val_bar = 5
2150+
def bar_func(self, a):
2151+
return a**2
2152+
21372153
class Foo:
2154+
bar = Bar() # pylint: disable=disallowed-name
21382155
val3 = 1
21392156
arr = [[1, 2], [3, 4]]
21402157
def __init__(self, _val1, _val2):
@@ -2150,18 +2167,33 @@ def static_func(a):
21502167
test_str = 'foo%val1 + foo%val2 + foo%val3'
21512168
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case))
21522169
assert str(parsed).lower().replace(' ', '') == 'foo%val1+foo%val2+foo%val3'
2153-
with pytest.raises(Exception):
2170+
with pytest.raises(pmbl_mapper.evaluator.UnknownVariableError):
21542171
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=True)
21552172
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
21562173
assert parsed == 6
21572174
test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func_2(3)'
21582175
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
21592176
assert str(parsed).lower().replace(' ', '') == '5+foo%static_func_2(3)'
2160-
with pytest.raises(Exception):
2177+
with pytest.raises(pmbl_mapper.evaluator.UnknownVariableError):
21612178
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=True)
21622179
test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func(3) + foo%arr(1, 2)'
21632180
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context, strict=True)
21642181
assert parsed == 13
21652182
test_str = 'foo%val1 + foo%some_func(1, b=2) + foo%static_func(a=3) + foo%arr(1, 2)'
21662183
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context, strict=True)
21672184
assert parsed == 13
2185+
test_str = 'foo%bar%val_bar + 1'
2186+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2187+
assert parsed == 6
2188+
test_str = 'foo%bar%bar_func(2) + 1'
2189+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2190+
assert parsed == 5
2191+
test_str = 'foo%bar%barbar%val_barbar + 1'
2192+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2193+
assert parsed == -2
2194+
test_str = 'foo%bar%barbar%barbar_func(0) + 1'
2195+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2196+
assert parsed == 0
2197+
test_str = 'foo%bar%barbar%barbarbar%val_barbarbar + 1'
2198+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2199+
assert parsed == 6

0 commit comments

Comments
 (0)