99import re
1010import math
1111import pytools .lex
12+ import numpy as np
1213from pymbolic .parser import Parser as ParserBase # , FinalizedTuple
1314from pymbolic .mapper import Mapper
1415import pymbolic .primitives as pmbl
1516from pymbolic .mapper .evaluator import EvaluationMapper
1617from pymbolic .parser import (
1718 _openpar , _closepar , _minus , FinalizedTuple , _PREC_UNARY ,
18- _PREC_TIMES , _PREC_PLUS , _times , _plus
19+ _PREC_TIMES , _PREC_PLUS , _PREC_CALL , _times , _plus
1920)
2021try :
2122 from fparser .two .Fortran2003 import Intrinsic_Name
@@ -161,10 +162,15 @@ def map_tuple(self, expr, *args, **kwargs):
161162 def map_list (self , expr , * args , ** kwargs ):
162163 return sym .LiteralList ([self .rec (elem , * args , ** kwargs ) for elem in expr ])
163164
164- # hijack 'pymbolic.Remainder' to construct DerivedTypes ...
165165 def map_remainder (self , expr , * args , ** kwargs ):
166- parent = self .rec (expr .numerator )
167- return self .rec (expr .denominator , parent = parent )
166+ # this should never happen as '%' is overwritten to represent derived types
167+ raise NotImplementedError
168+
169+ def map_lookup (self , expr , * args , ** kwargs ):
170+ # construct derived type(s) variables
171+ parent = kwargs .pop ('parent' , None )
172+ parent = self .rec (expr .aggregate , parent = parent )
173+ return self .rec (expr .name , parent = parent )
168174
169175
170176class LokiEvaluationMapper (EvaluationMapper ):
@@ -178,6 +184,16 @@ class LokiEvaluationMapper(EvaluationMapper):
178184 Raise exception for unknown symbols/expressions (default: `False`).
179185 """
180186
187+ @staticmethod
188+ def case_insensitive_getattr (obj , attr ):
189+ """
190+ Case-insensitive version of `getattr`.
191+ """
192+ for elem in dir (obj ):
193+ if elem .lower () == attr .lower ():
194+ return getattr (obj , elem )
195+ return getattr (obj , attr )
196+
181197 def __init__ (self , strict = False , ** kwargs ):
182198 self .strict = strict
183199 super ().__init__ (** kwargs )
@@ -198,6 +214,15 @@ def map_variable(self, expr):
198214 return super ().map_variable (expr )
199215 return expr
200216
217+ @staticmethod
218+ def _evaluate_array (arr , dims ):
219+ """
220+ Evaluate arrays by converting to numpy array and
221+ adapting the dimensions corresponding to the different
222+ starting index.
223+ """
224+ return np .array (arr , order = 'F' ).item (* [dim - 1 for dim in dims ])
225+
201226 def map_call (self , expr ):
202227 if expr .function .name .lower () == 'min' :
203228 return min (self .rec (par ) for par in expr .parameters )
@@ -216,8 +241,46 @@ def map_call(self, expr):
216241 return math .sqrt (float ([self .rec (par ) for par in expr .parameters ][0 ]))
217242 if expr .function .name .lower () == 'exp' :
218243 return math .exp (float ([self .rec (par ) for par in expr .parameters ][0 ]))
244+ if expr .function .name in self .context and not callable (self .context [expr .function .name ]):
245+ return self ._evaluate_array (self .context [expr .function .name ],
246+ [self .rec (par ) for par in expr .parameters ])
219247 return super ().map_call (expr )
220248
249+ def map_call_with_kwargs (self , expr ):
250+ args = [self .rec (par ) for par in expr .parameters ]
251+ kwargs = {
252+ k : self .rec (v )
253+ for k , v in expr .kw_parameters .items ()}
254+ kwargs = CaseInsensitiveDict (kwargs )
255+ return self .rec (expr .function )(* args , ** kwargs )
256+
257+ def map_lookup (self , expr ):
258+ try :
259+ if isinstance (expr .name , pmbl .Variable ):
260+ name = expr .name .name
261+ return self .case_insensitive_getattr (self .rec (expr .aggregate ), name )
262+ if isinstance (expr .name , pmbl .Call ):
263+ name = expr .name .function .name
264+ if callable (self .case_insensitive_getattr (self .rec (expr .aggregate ), name )):
265+ return self .case_insensitive_getattr (self .rec (expr .aggregate ),
266+ name )(* [self .rec (par ) for par in expr .name .parameters ])
267+ return self ._evaluate_array (self .case_insensitive_getattr (self .rec (expr .aggregate ), name ),
268+ [self .rec (par ) for par in expr .name .parameters ])
269+ if isinstance (expr .name , pmbl .CallWithKwargs ):
270+ name = expr .name .function .name
271+ args = [self .rec (par ) for par in expr .name .parameters ]
272+ kwargs = {
273+ k : self .rec (v )
274+ for k , v in expr .name .kw_parameters .items ()}
275+ kwargs = CaseInsensitiveDict (kwargs )
276+ return self .case_insensitive_getattr (self .rec (expr .aggregate ), name )(* args , ** kwargs )
277+
278+ except Exception as e :
279+ if self .strict :
280+ raise e
281+ return expr
282+ return expr
283+
221284
222285class ExpressionParser (ParserBase ):
223286 """
@@ -288,6 +351,7 @@ class ExpressionParser(ParserBase):
288351 _f_string = intern ("f_string" )
289352 _f_openbracket = intern ("openbracket" )
290353 _f_closebracket = intern ("closebracket" )
354+ _f_derived_type = intern ("dot" )
291355
292356 lex_table = [
293357 (_f_true , pytools .lex .RE (r"\.true\." , re .IGNORECASE )),
@@ -307,6 +371,7 @@ class ExpressionParser(ParserBase):
307371 pytools .lex .RE (r"\'.*\'" , re .IGNORECASE ))),
308372 (_f_openbracket , pytools .lex .RE (r"\(/" )),
309373 (_f_closebracket , pytools .lex .RE (r"/\)" )),
374+ (_f_derived_type , pytools .lex .RE (r"\%" )),
310375 ] + ParserBase .lex_table
311376 """
312377 Extend :any:`pymbolic.parser.Parser.lex_table` to accomodate for Fortran specifix syntax/expressions.
@@ -372,7 +437,12 @@ def parse_prefix(self, pstate):
372437 def parse_postfix (self , pstate , min_precedence , left_exp ):
373438
374439 did_something = False
375- if pstate .is_next (_times ) and _PREC_TIMES > min_precedence :
440+ if pstate .is_next (self ._f_derived_type ) and _PREC_CALL > min_precedence :
441+ pstate .advance ()
442+ right_exp = self .parse_expression (pstate , _PREC_PLUS )
443+ left_exp = pmbl .Lookup (left_exp , right_exp )
444+ did_something = True
445+ elif pstate .is_next (_times ) and _PREC_TIMES > min_precedence :
376446 pstate .advance ()
377447 right_exp = self .parse_expression (pstate , _PREC_PLUS )
378448 # NECESSARY to ensure correct ordering!
0 commit comments