99import re
1010import math
1111import pytools .lex
12+ import numpy as np
1213from pymbolic .parser import Parser as ParserBase # , FinalizedTuple
1314from pymbolic .mapper import Mapper
1415import pymbolic .primitives as pmbl
1516from pymbolic .mapper .evaluator import EvaluationMapper
1617from pymbolic .parser import (
1718 _openpar , _closepar , _minus , FinalizedTuple , _PREC_UNARY ,
18- _PREC_TIMES , _PREC_PLUS , _times , _plus
19+ _PREC_TIMES , _PREC_PLUS , _PREC_CALL , _times , _plus
1920)
2021try :
2122 from fparser .two .Fortran2003 import Intrinsic_Name
@@ -109,14 +110,19 @@ def map_meta_symbol(self, expr, *args, **kwargs):
109110 map_array = map_meta_symbol
110111
111112 def map_slice (self , expr , * args , ** kwargs ):
112- return sym .RangeIndex (tuple (self .rec (child , * args , ** kwargs ) for child in expr .children ))
113+ children = tuple (self .rec (child , * args , ** kwargs ) if child is not None else child for child in expr .children )
114+ if len (children ) == 1 and children [0 ] is None :
115+ # this corresponds to ':' (sym.RangeIndex((None, None)))
116+ children = (None , None )
117+ return sym .RangeIndex (children )
113118
114119 map_range = map_slice
115120 map_range_index = map_slice
116121 map_loop_range = map_slice
117122
118123 def map_variable (self , expr , * args , ** kwargs ):
119- return sym .Variable (name = expr .name )
124+ parent = kwargs .pop ('parent' , None )
125+ return sym .Variable (name = expr .name , parent = parent )
120126
121127 def map_algebraic_leaf (self , expr , * args , ** kwargs ):
122128 if str (expr ).isnumeric ():
@@ -127,7 +133,12 @@ def map_algebraic_leaf(self, expr, *args, **kwargs):
127133 if expr .function .name .upper () in FORTRAN_INTRINSIC_PROCEDURES :
128134 return sym .InlineCall (function = sym .Variable (name = expr .function .name ),
129135 parameters = tuple (self .rec (param , * args , ** kwargs ) for param in expr .parameters ))
130- return sym .Variable (name = expr .function .name ,
136+ parent = kwargs .pop ('parent' , None )
137+ dimensions = tuple (self .rec (param , * args , ** kwargs ) for param in expr .parameters )
138+ if not dimensions :
139+ return sym .InlineCall (function = sym .Variable (name = expr .function .name , parent = parent ),
140+ parameters = dimensions )
141+ return sym .Variable (name = expr .function .name , parent = parent ,
131142 dimensions = tuple (self .rec (param , * args , ** kwargs ) for param in expr .parameters ))
132143 try :
133144 return self .map_variable (expr , * args , ** kwargs )
@@ -151,6 +162,16 @@ def map_tuple(self, expr, *args, **kwargs):
151162 def map_list (self , expr , * args , ** kwargs ):
152163 return sym .LiteralList ([self .rec (elem , * args , ** kwargs ) for elem in expr ])
153164
165+ def map_remainder (self , expr , * args , ** kwargs ):
166+ # this should never happen as '%' is overwritten to represent derived types
167+ raise NotImplementedError
168+
169+ def map_lookup (self , expr , * args , ** kwargs ):
170+ # construct derived type(s) variables
171+ parent = kwargs .pop ('parent' , None )
172+ parent = self .rec (expr .aggregate , parent = parent )
173+ return self .rec (expr .name , parent = parent )
174+
154175
155176class LokiEvaluationMapper (EvaluationMapper ):
156177 """
@@ -163,6 +184,16 @@ class LokiEvaluationMapper(EvaluationMapper):
163184 Raise exception for unknown symbols/expressions (default: `False`).
164185 """
165186
187+ @staticmethod
188+ def case_insensitive_getattr (obj , attr ):
189+ """
190+ Case-insensitive version of `getattr`.
191+ """
192+ for elem in dir (obj ):
193+ if elem .lower () == attr .lower ():
194+ return getattr (obj , elem )
195+ return getattr (obj , attr )
196+
166197 def __init__ (self , strict = False , ** kwargs ):
167198 self .strict = strict
168199 super ().__init__ (** kwargs )
@@ -183,6 +214,15 @@ def map_variable(self, expr):
183214 return super ().map_variable (expr )
184215 return expr
185216
217+ @staticmethod
218+ def _evaluate_array (arr , dims ):
219+ """
220+ Evaluate arrays by converting to numpy array and
221+ adapting the dimensions corresponding to the different
222+ starting index.
223+ """
224+ return np .array (arr , order = 'F' ).item (* [dim - 1 for dim in dims ])
225+
186226 def map_call (self , expr ):
187227 if expr .function .name .lower () == 'min' :
188228 return min (self .rec (par ) for par in expr .parameters )
@@ -201,8 +241,55 @@ def map_call(self, expr):
201241 return math .sqrt (float ([self .rec (par ) for par in expr .parameters ][0 ]))
202242 if expr .function .name .lower () == 'exp' :
203243 return math .exp (float ([self .rec (par ) for par in expr .parameters ][0 ]))
244+ if expr .function .name in self .context and not callable (self .context [expr .function .name ]):
245+ return self ._evaluate_array (self .context [expr .function .name ],
246+ [self .rec (par ) for par in expr .parameters ])
204247 return super ().map_call (expr )
205248
249+ def map_call_with_kwargs (self , expr ):
250+ args = [self .rec (par ) for par in expr .parameters ]
251+ kwargs = {
252+ k : self .rec (v )
253+ for k , v in expr .kw_parameters .items ()}
254+ kwargs = CaseInsensitiveDict (kwargs )
255+ return self .rec (expr .function )(* args , ** kwargs )
256+
257+ def map_lookup (self , expr ):
258+
259+ def rec_lookup (expr , obj , name ):
260+ return expr .name , self .case_insensitive_getattr (obj , name )
261+
262+ try :
263+ current_expr = expr
264+ obj = self .rec (expr .aggregate )
265+ while isinstance (current_expr .name , pmbl .Lookup ):
266+ current_expr , obj = rec_lookup (current_expr , obj , current_expr .name .aggregate .name )
267+ if isinstance (current_expr .name , pmbl .Variable ):
268+ _ , obj = rec_lookup (current_expr , obj , current_expr .name .name )
269+ return obj
270+ if isinstance (current_expr .name , pmbl .Call ):
271+ name = current_expr .name .function .name
272+ _ , obj = rec_lookup (current_expr , obj , name )
273+ if callable (obj ):
274+ return obj (* [self .rec (par ) for par in current_expr .name .parameters ])
275+ return self ._evaluate_array (obj , [self .rec (par ) for par in current_expr .name .parameters ])
276+ if isinstance (current_expr .name , pmbl .CallWithKwargs ):
277+ name = current_expr .name .function .name
278+ _ , obj = rec_lookup (current_expr , obj , name )
279+ args = [self .rec (par ) for par in current_expr .name .parameters ]
280+ kwargs = CaseInsensitiveDict (
281+ (k , self .rec (v ))
282+ for k , v in current_expr .name .kw_parameters .items ()
283+ )
284+ return obj (* args , ** kwargs )
285+ except Exception as e :
286+ if self .strict :
287+ raise e
288+ return expr
289+ if self .strict :
290+ raise NotImplementedError
291+ return expr
292+
206293
207294class ExpressionParser (ParserBase ):
208295 """
@@ -273,6 +360,7 @@ class ExpressionParser(ParserBase):
273360 _f_string = intern ("f_string" )
274361 _f_openbracket = intern ("openbracket" )
275362 _f_closebracket = intern ("closebracket" )
363+ _f_derived_type = intern ("dot" )
276364
277365 lex_table = [
278366 (_f_true , pytools .lex .RE (r"\.true\." , re .IGNORECASE )),
@@ -292,6 +380,7 @@ class ExpressionParser(ParserBase):
292380 pytools .lex .RE (r"\'.*\'" , re .IGNORECASE ))),
293381 (_f_openbracket , pytools .lex .RE (r"\(/" )),
294382 (_f_closebracket , pytools .lex .RE (r"/\)" )),
383+ (_f_derived_type , pytools .lex .RE (r"\%" )),
295384 ] + ParserBase .lex_table
296385 """
297386 Extend :any:`pymbolic.parser.Parser.lex_table` to accomodate for Fortran specifix syntax/expressions.
@@ -357,7 +446,12 @@ def parse_prefix(self, pstate):
357446 def parse_postfix (self , pstate , min_precedence , left_exp ):
358447
359448 did_something = False
360- if pstate .is_next (_times ) and _PREC_TIMES > min_precedence :
449+ if pstate .is_next (self ._f_derived_type ) and _PREC_CALL > min_precedence :
450+ pstate .advance ()
451+ right_exp = self .parse_expression (pstate , _PREC_PLUS )
452+ left_exp = pmbl .Lookup (left_exp , right_exp )
453+ did_something = True
454+ elif pstate .is_next (_times ) and _PREC_TIMES > min_precedence :
361455 pstate .advance ()
362456 right_exp = self .parse_expression (pstate , _PREC_PLUS )
363457 # NECESSARY to ensure correct ordering!
@@ -433,6 +527,15 @@ def __call__(self, expr_str, scope=None, evaluate=False, strict=False, context=N
433527 ir = PymbolicMapper ()(result )
434528 return AttachScopes ().visit (ir , scope = scope or Scope ())
435529
530+ def parse_float (self , s ):
531+ """
532+ Parse float literals.
533+
534+ Do not cast to float via 'float()' in order to keep the original
535+ notation, e.g., do not convert 1E-3 to 0.003.
536+ """
537+ return sym .FloatLiteral (value = s .replace ("d" , "e" ).replace ("D" , "e" ))
538+
436539 def parse_f_float (self , s ):
437540 """
438541 Parse "Fortran-style" float literals.
@@ -441,7 +544,7 @@ def parse_f_float(self, s):
441544 """
442545 stripped = s .split ('_' , 1 )
443546 if len (stripped ) == 2 :
444- return sym .Literal (value = self .parse_float (stripped [0 ]), kind = sym .Variable (name = stripped [1 ].lower ()))
547+ return sym .FloatLiteral (value = self .parse_float (stripped [0 ]), kind = sym .Variable (name = stripped [1 ].lower ()))
445548 return self .parse_float (stripped [0 ])
446549
447550 def parse_f_int (self , s ):
0 commit comments