Skip to content

Commit 5673795

Browse files
authored
Merge pull request #292 from ecmwf-ifs/nams_process_dimension_pragmas_parse_expr
Improve `parse_expr` and use in `process_dimension_pragmas`
2 parents 84eea6c + 2b8c5cb commit 5673795

File tree

10 files changed

+395
-51
lines changed

10 files changed

+395
-51
lines changed

loki/expression/expr_visitors.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
"""
1212
from pymbolic.primitives import Expression
1313

14-
from loki.ir import Node, Visitor, Transformer
14+
from loki.ir.nodes import Node
15+
from loki.ir.visitor import Visitor
16+
from loki.ir.transformer import Transformer
1517
from loki.tools import flatten, as_tuple
1618
from loki.expression.mappers import (
1719
SubstituteExpressionsMapper, ExpressionRetriever, AttachScopesMapper

loki/expression/parser.py

Lines changed: 109 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
import re
1010
import math
1111
import pytools.lex
12+
import numpy as np
1213
from pymbolic.parser import Parser as ParserBase # , FinalizedTuple
1314
from pymbolic.mapper import Mapper
1415
import pymbolic.primitives as pmbl
1516
from pymbolic.mapper.evaluator import EvaluationMapper
1617
from pymbolic.parser import (
1718
_openpar, _closepar, _minus, FinalizedTuple, _PREC_UNARY,
18-
_PREC_TIMES, _PREC_PLUS, _times, _plus
19+
_PREC_TIMES, _PREC_PLUS, _PREC_CALL, _times, _plus
1920
)
2021
try:
2122
from fparser.two.Fortran2003 import Intrinsic_Name
@@ -109,14 +110,19 @@ def map_meta_symbol(self, expr, *args, **kwargs):
109110
map_array = map_meta_symbol
110111

111112
def map_slice(self, expr, *args, **kwargs):
112-
return sym.RangeIndex(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
113+
children = tuple(self.rec(child, *args, **kwargs) if child is not None else child for child in expr.children)
114+
if len(children) == 1 and children[0] is None:
115+
# this corresponds to ':' (sym.RangeIndex((None, None)))
116+
children = (None, None)
117+
return sym.RangeIndex(children)
113118

114119
map_range = map_slice
115120
map_range_index = map_slice
116121
map_loop_range = map_slice
117122

118123
def map_variable(self, expr, *args, **kwargs):
119-
return sym.Variable(name=expr.name)
124+
parent = kwargs.pop('parent', None)
125+
return sym.Variable(name=expr.name, parent=parent)
120126

121127
def map_algebraic_leaf(self, expr, *args, **kwargs):
122128
if str(expr).isnumeric():
@@ -127,7 +133,12 @@ def map_algebraic_leaf(self, expr, *args, **kwargs):
127133
if expr.function.name.upper() in FORTRAN_INTRINSIC_PROCEDURES:
128134
return sym.InlineCall(function=sym.Variable(name=expr.function.name),
129135
parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
130-
return sym.Variable(name=expr.function.name,
136+
parent = kwargs.pop('parent', None)
137+
dimensions = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)
138+
if not dimensions:
139+
return sym.InlineCall(function=sym.Variable(name=expr.function.name, parent=parent),
140+
parameters=dimensions)
141+
return sym.Variable(name=expr.function.name, parent=parent,
131142
dimensions=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
132143
try:
133144
return self.map_variable(expr, *args, **kwargs)
@@ -151,6 +162,16 @@ def map_tuple(self, expr, *args, **kwargs):
151162
def map_list(self, expr, *args, **kwargs):
152163
return sym.LiteralList([self.rec(elem, *args, **kwargs) for elem in expr])
153164

165+
def map_remainder(self, expr, *args, **kwargs):
166+
# this should never happen as '%' is overwritten to represent derived types
167+
raise NotImplementedError
168+
169+
def map_lookup(self, expr, *args, **kwargs):
170+
# construct derived type(s) variables
171+
parent = kwargs.pop('parent', None)
172+
parent = self.rec(expr.aggregate, parent=parent)
173+
return self.rec(expr.name, parent=parent)
174+
154175

155176
class LokiEvaluationMapper(EvaluationMapper):
156177
"""
@@ -163,6 +184,16 @@ class LokiEvaluationMapper(EvaluationMapper):
163184
Raise exception for unknown symbols/expressions (default: `False`).
164185
"""
165186

187+
@staticmethod
188+
def case_insensitive_getattr(obj, attr):
189+
"""
190+
Case-insensitive version of `getattr`.
191+
"""
192+
for elem in dir(obj):
193+
if elem.lower() == attr.lower():
194+
return getattr(obj, elem)
195+
return getattr(obj, attr)
196+
166197
def __init__(self, strict=False, **kwargs):
167198
self.strict = strict
168199
super().__init__(**kwargs)
@@ -183,6 +214,15 @@ def map_variable(self, expr):
183214
return super().map_variable(expr)
184215
return expr
185216

217+
@staticmethod
218+
def _evaluate_array(arr, dims):
219+
"""
220+
Evaluate arrays by converting to numpy array and
221+
adapting the dimensions corresponding to the different
222+
starting index.
223+
"""
224+
return np.array(arr, order='F').item(*[dim-1 for dim in dims])
225+
186226
def map_call(self, expr):
187227
if expr.function.name.lower() == 'min':
188228
return min(self.rec(par) for par in expr.parameters)
@@ -201,8 +241,55 @@ def map_call(self, expr):
201241
return math.sqrt(float([self.rec(par) for par in expr.parameters][0]))
202242
if expr.function.name.lower() == 'exp':
203243
return math.exp(float([self.rec(par) for par in expr.parameters][0]))
244+
if expr.function.name in self.context and not callable(self.context[expr.function.name]):
245+
return self._evaluate_array(self.context[expr.function.name],
246+
[self.rec(par) for par in expr.parameters])
204247
return super().map_call(expr)
205248

249+
def map_call_with_kwargs(self, expr):
250+
args = [self.rec(par) for par in expr.parameters]
251+
kwargs = {
252+
k: self.rec(v)
253+
for k, v in expr.kw_parameters.items()}
254+
kwargs = CaseInsensitiveDict(kwargs)
255+
return self.rec(expr.function)(*args, **kwargs)
256+
257+
def map_lookup(self, expr):
258+
259+
def rec_lookup(expr, obj, name):
260+
return expr.name, self.case_insensitive_getattr(obj, name)
261+
262+
try:
263+
current_expr = expr
264+
obj = self.rec(expr.aggregate)
265+
while isinstance(current_expr.name, pmbl.Lookup):
266+
current_expr, obj = rec_lookup(current_expr, obj, current_expr.name.aggregate.name)
267+
if isinstance(current_expr.name, pmbl.Variable):
268+
_, obj = rec_lookup(current_expr, obj, current_expr.name.name)
269+
return obj
270+
if isinstance(current_expr.name, pmbl.Call):
271+
name = current_expr.name.function.name
272+
_, obj = rec_lookup(current_expr, obj, name)
273+
if callable(obj):
274+
return obj(*[self.rec(par) for par in current_expr.name.parameters])
275+
return self._evaluate_array(obj, [self.rec(par) for par in current_expr.name.parameters])
276+
if isinstance(current_expr.name, pmbl.CallWithKwargs):
277+
name = current_expr.name.function.name
278+
_, obj = rec_lookup(current_expr, obj, name)
279+
args = [self.rec(par) for par in current_expr.name.parameters]
280+
kwargs = CaseInsensitiveDict(
281+
(k, self.rec(v))
282+
for k, v in current_expr.name.kw_parameters.items()
283+
)
284+
return obj(*args, **kwargs)
285+
except Exception as e:
286+
if self.strict:
287+
raise e
288+
return expr
289+
if self.strict:
290+
raise NotImplementedError
291+
return expr
292+
206293

207294
class ExpressionParser(ParserBase):
208295
"""
@@ -273,6 +360,7 @@ class ExpressionParser(ParserBase):
273360
_f_string = intern("f_string")
274361
_f_openbracket = intern("openbracket")
275362
_f_closebracket = intern("closebracket")
363+
_f_derived_type = intern("dot")
276364

277365
lex_table = [
278366
(_f_true, pytools.lex.RE(r"\.true\.", re.IGNORECASE)),
@@ -292,6 +380,7 @@ class ExpressionParser(ParserBase):
292380
pytools.lex.RE(r"\'.*\'", re.IGNORECASE))),
293381
(_f_openbracket, pytools.lex.RE(r"\(/")),
294382
(_f_closebracket, pytools.lex.RE(r"/\)")),
383+
(_f_derived_type, pytools.lex.RE(r"\%")),
295384
] + ParserBase.lex_table
296385
"""
297386
Extend :any:`pymbolic.parser.Parser.lex_table` to accomodate for Fortran specifix syntax/expressions.
@@ -357,7 +446,12 @@ def parse_prefix(self, pstate):
357446
def parse_postfix(self, pstate, min_precedence, left_exp):
358447

359448
did_something = False
360-
if pstate.is_next(_times) and _PREC_TIMES > min_precedence:
449+
if pstate.is_next(self._f_derived_type) and _PREC_CALL > min_precedence:
450+
pstate.advance()
451+
right_exp = self.parse_expression(pstate, _PREC_PLUS)
452+
left_exp = pmbl.Lookup(left_exp, right_exp)
453+
did_something = True
454+
elif pstate.is_next(_times) and _PREC_TIMES > min_precedence:
361455
pstate.advance()
362456
right_exp = self.parse_expression(pstate, _PREC_PLUS)
363457
# NECESSARY to ensure correct ordering!
@@ -433,6 +527,15 @@ def __call__(self, expr_str, scope=None, evaluate=False, strict=False, context=N
433527
ir = PymbolicMapper()(result)
434528
return AttachScopes().visit(ir, scope=scope or Scope())
435529

530+
def parse_float(self, s):
531+
"""
532+
Parse float literals.
533+
534+
Do not cast to float via 'float()' in order to keep the original
535+
notation, e.g., do not convert 1E-3 to 0.003.
536+
"""
537+
return sym.FloatLiteral(value=s.replace("d", "e").replace("D", "e"))
538+
436539
def parse_f_float(self, s):
437540
"""
438541
Parse "Fortran-style" float literals.
@@ -441,7 +544,7 @@ def parse_f_float(self, s):
441544
"""
442545
stripped = s.split('_', 1)
443546
if len(stripped) == 2:
444-
return sym.Literal(value=self.parse_float(stripped[0]), kind=sym.Variable(name=stripped[1].lower()))
547+
return sym.FloatLiteral(value=self.parse_float(stripped[0]), kind=sym.Variable(name=stripped[1].lower()))
445548
return self.parse_float(stripped[0])
446549

447550
def parse_f_int(self, s):

0 commit comments

Comments
 (0)