Skip to content

Commit ef5b5cc

Browse files
committed
parse_expr: proper handling of derived type(s) variables includig evaluation and evaluation for arrays
1 parent 47e435d commit ef5b5cc

File tree

2 files changed

+125
-7
lines changed

2 files changed

+125
-7
lines changed

loki/expression/parser.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import re
1010
import math
1111
import pytools.lex
12+
import numpy as np
1213
from pymbolic.parser import Parser as ParserBase # , FinalizedTuple
1314
from pymbolic.mapper import Mapper
1415
import pymbolic.primitives as pmbl
1516
from pymbolic.mapper.evaluator import EvaluationMapper
1617
from pymbolic.parser import (
1718
_openpar, _closepar, _minus, FinalizedTuple, _PREC_UNARY,
18-
_PREC_TIMES, _PREC_PLUS, _times, _plus
19+
_PREC_TIMES, _PREC_PLUS, _PREC_CALL, _times, _plus
1920
)
2021
try:
2122
from fparser.two.Fortran2003 import Intrinsic_Name
@@ -161,10 +162,15 @@ def map_tuple(self, expr, *args, **kwargs):
161162
def map_list(self, expr, *args, **kwargs):
162163
return sym.LiteralList([self.rec(elem, *args, **kwargs) for elem in expr])
163164

164-
# hijack 'pymbolic.Remainder' to construct DerivedTypes ...
165165
def map_remainder(self, expr, *args, **kwargs):
166-
parent = self.rec(expr.numerator)
167-
return self.rec(expr.denominator, parent=parent)
166+
# this should never happen as '%' is overwritten to represent derived types
167+
raise NotImplementedError
168+
169+
def map_lookup(self, expr, *args, **kwargs):
170+
# construct derived type(s) variables
171+
parent = kwargs.pop('parent', None)
172+
parent = self.rec(expr.aggregate, parent=parent)
173+
return self.rec(expr.name, parent=parent)
168174

169175

170176
class LokiEvaluationMapper(EvaluationMapper):
@@ -178,6 +184,16 @@ class LokiEvaluationMapper(EvaluationMapper):
178184
Raise exception for unknown symbols/expressions (default: `False`).
179185
"""
180186

187+
@staticmethod
188+
def case_insensitive_getattr(obj, attr):
189+
"""
190+
Case-insensitive version of `getattr`.
191+
"""
192+
for elem in dir(obj):
193+
if elem.lower() == attr.lower():
194+
return getattr(obj, elem)
195+
return getattr(obj, attr)
196+
181197
def __init__(self, strict=False, **kwargs):
182198
self.strict = strict
183199
super().__init__(**kwargs)
@@ -198,6 +214,15 @@ def map_variable(self, expr):
198214
return super().map_variable(expr)
199215
return expr
200216

217+
@staticmethod
218+
def _evaluate_array(arr, dims):
219+
"""
220+
Evaluate arrays by converting to numpy array and
221+
adapting the dimensions corresponding to the different
222+
starting index.
223+
"""
224+
return np.array(arr, order='F').item(*[dim-1 for dim in dims])
225+
201226
def map_call(self, expr):
202227
if expr.function.name.lower() == 'min':
203228
return min(self.rec(par) for par in expr.parameters)
@@ -216,8 +241,46 @@ def map_call(self, expr):
216241
return math.sqrt(float([self.rec(par) for par in expr.parameters][0]))
217242
if expr.function.name.lower() == 'exp':
218243
return math.exp(float([self.rec(par) for par in expr.parameters][0]))
244+
if expr.function.name in self.context and not callable(self.context[expr.function.name]):
245+
return self._evaluate_array(self.context[expr.function.name],
246+
[self.rec(par) for par in expr.parameters])
219247
return super().map_call(expr)
220248

249+
def map_call_with_kwargs(self, expr):
250+
args = [self.rec(par) for par in expr.parameters]
251+
kwargs = {
252+
k: self.rec(v)
253+
for k, v in expr.kw_parameters.items()}
254+
kwargs = CaseInsensitiveDict(kwargs)
255+
return self.rec(expr.function)(*args, **kwargs)
256+
257+
def map_lookup(self, expr):
258+
try:
259+
if isinstance(expr.name, pmbl.Variable):
260+
name = expr.name.name
261+
return self.case_insensitive_getattr(self.rec(expr.aggregate), name)
262+
if isinstance(expr.name, pmbl.Call):
263+
name = expr.name.function.name
264+
if callable(self.case_insensitive_getattr(self.rec(expr.aggregate), name)):
265+
return self.case_insensitive_getattr(self.rec(expr.aggregate),
266+
name)(*[self.rec(par) for par in expr.name.parameters])
267+
return self._evaluate_array(self.case_insensitive_getattr(self.rec(expr.aggregate), name),
268+
[self.rec(par) for par in expr.name.parameters])
269+
if isinstance(expr.name, pmbl.CallWithKwargs):
270+
name = expr.name.function.name
271+
args = [self.rec(par) for par in expr.name.parameters]
272+
kwargs = {
273+
k: self.rec(v)
274+
for k, v in expr.name.kw_parameters.items()}
275+
kwargs = CaseInsensitiveDict(kwargs)
276+
return self.case_insensitive_getattr(self.rec(expr.aggregate), name)(*args, **kwargs)
277+
278+
except Exception as e:
279+
if self.strict:
280+
raise e
281+
return expr
282+
return expr
283+
221284

222285
class ExpressionParser(ParserBase):
223286
"""
@@ -288,6 +351,7 @@ class ExpressionParser(ParserBase):
288351
_f_string = intern("f_string")
289352
_f_openbracket = intern("openbracket")
290353
_f_closebracket = intern("closebracket")
354+
_f_derived_type = intern("dot")
291355

292356
lex_table = [
293357
(_f_true, pytools.lex.RE(r"\.true\.", re.IGNORECASE)),
@@ -307,6 +371,7 @@ class ExpressionParser(ParserBase):
307371
pytools.lex.RE(r"\'.*\'", re.IGNORECASE))),
308372
(_f_openbracket, pytools.lex.RE(r"\(/")),
309373
(_f_closebracket, pytools.lex.RE(r"/\)")),
374+
(_f_derived_type, pytools.lex.RE(r"\%")),
310375
] + ParserBase.lex_table
311376
"""
312377
Extend :any:`pymbolic.parser.Parser.lex_table` to accomodate for Fortran specifix syntax/expressions.
@@ -372,7 +437,12 @@ def parse_prefix(self, pstate):
372437
def parse_postfix(self, pstate, min_precedence, left_exp):
373438

374439
did_something = False
375-
if pstate.is_next(_times) and _PREC_TIMES > min_precedence:
440+
if pstate.is_next(self._f_derived_type) and _PREC_CALL > min_precedence:
441+
pstate.advance()
442+
right_exp = self.parse_expression(pstate, _PREC_PLUS)
443+
left_exp = pmbl.Lookup(left_exp, right_exp)
444+
did_something = True
445+
elif pstate.is_next(_times) and _PREC_TIMES > min_precedence:
376446
pstate.advance()
377447
right_exp = self.parse_expression(pstate, _PREC_PLUS)
378448
# NECESSARY to ensure correct ordering!

loki/expression/tests/test_expression.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,10 @@ def to_str(_parsed):
18911891
assert isinstance(parsed, sym.DeferredTypeSymbol)
18921892
assert isinstance(parsed.parent, sym.DeferredTypeSymbol)
18931893
assert to_str(parsed) == 'some_type%val'
1894+
parsed = parse_expr(convert_to_case('-some_type%val', mode=case))
1895+
assert isinstance(parsed, sym.Product)
1896+
assert isinstance(parsed.children[1].parent, sym.DeferredTypeSymbol)
1897+
assert to_str(parsed) == '-some_type%val'
18941898
parsed = parse_expr(convert_to_case('some_type%another_type%val', mode=case))
18951899
assert isinstance(parsed, sym.DeferredTypeSymbol)
18961900
assert isinstance(parsed.parent, sym.DeferredTypeSymbol)
@@ -2097,6 +2101,11 @@ def test_expression_parser_evaluate(case):
20972101
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
20982102
assert parsed == 1
20992103

2104+
context = {'arr': [[1, 2], [3, 4]]}
2105+
test_str = '1 + arr(1, 2)'
2106+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2107+
assert parsed == 3
2108+
21002109
context = {'a': 6}
21012110
test_str = '1 + 1 + a + some_func(a, 10)'
21022111
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
@@ -2105,15 +2114,54 @@ def test_expression_parser_evaluate(case):
21052114
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=False, context=context)
21062115
assert str(parsed).lower().replace(' ', '') == '8+some_func(6,10)'
21072116

2108-
def some_func(a, b):
2109-
return a + b
2117+
def some_func(a, b, c=None):
2118+
if c is None:
2119+
return a + b
2120+
return a + b + c
21102121

21112122
context = {'a': 6, 'some_func': some_func}
21122123
test_str = '1 + 1 + a + some_func(a, 10)'
21132124
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
21142125
assert parsed == 24
21152126

2127+
context = {'a': 6, 'some_func': some_func}
2128+
test_str = '1 + 1 + a + some_func(a, 10, c=2)'
2129+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2130+
assert parsed == 26
2131+
21162132
context = {'a': 6, 'b': 7}
21172133
test_str = '(a + b + c + 1)/(c + 1)'
21182134
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
21192135
assert str(parsed).lower().replace(' ', '') == '(13+c+1)/(c+1)'
2136+
2137+
class Foo:
2138+
val3 = 1
2139+
arr = [[1, 2], [3, 4]]
2140+
def __init__(self, _val1, _val2):
2141+
self.val1 = _val1
2142+
self.val2 = _val2
2143+
def some_func(self, a, b):
2144+
return a + b
2145+
@staticmethod
2146+
def static_func(a):
2147+
return 2*a
2148+
2149+
context = {'foo': Foo(2, 3)}
2150+
test_str = 'foo%val1 + foo%val2 + foo%val3'
2151+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case))
2152+
assert str(parsed).lower().replace(' ', '') == 'foo%val1+foo%val2+foo%val3'
2153+
with pytest.raises(Exception):
2154+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=True)
2155+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2156+
assert parsed == 6
2157+
test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func_2(3)'
2158+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
2159+
assert str(parsed).lower().replace(' ', '') == '5+foo%static_func_2(3)'
2160+
with pytest.raises(Exception):
2161+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, strict=True)
2162+
test_str = 'foo%val1 + foo%some_func(1, 2) + foo%static_func(3) + foo%arr(1, 2)'
2163+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context, strict=True)
2164+
assert parsed == 13
2165+
test_str = 'foo%val1 + foo%some_func(1, b=2) + foo%static_func(a=3) + foo%arr(1, 2)'
2166+
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context, strict=True)
2167+
assert parsed == 13

0 commit comments

Comments
 (0)