@@ -171,6 +171,17 @@ def ne_evaluate(expression, local_dict=None, **kwargs):
171171relational_ops = ["==" , "!=" , "<" , "<=" , ">" , ">=" ]
172172logical_ops = ["&" , "|" , "^" , "~" ]
173173not_complex_ops = ["maximum" , "minimum" , "<" , "<=" , ">" , ">=" ]
174+ funcs_2args = (
175+ "arctan2" ,
176+ "contains" ,
177+ "pow" ,
178+ "power" ,
179+ "nextafter" ,
180+ "copysign" ,
181+ "hypot" ,
182+ "maximum" ,
183+ "minimum" ,
184+ )
174185
175186
176187def get_expr_globals (expression ):
@@ -2239,23 +2250,22 @@ def __init__(self, new_op): # noqa: C901
22392250 dtype_ = check_dtype (op , value1 , value2 ) # perform some checks
22402251 if value2 is None :
22412252 if isinstance (value1 , LazyExpr ):
2242- self .expression = f"{ op } ({ value1 .expression } )"
2253+ self .expression = value1 . expression if op is None else f"{ op } ({ value1 .expression } )"
22432254 self .operands = value1 .operands
22442255 else :
22452256 self .operands = {"o0" : value1 }
22462257 self .expression = "o0" if op is None else f"{ op } (o0)"
22472258 return
2248- elif op in (
2249- "arctan2" ,
2250- "contains" ,
2251- "pow" ,
2252- "power" ,
2253- "nextafter" ,
2254- "copysign" ,
2255- "hypot" ,
2256- "maximum" ,
2257- "minimum" ,
2258- ):
2259+ elif isinstance (value1 , LazyExpr ) or isinstance (value2 , LazyExpr ):
2260+ if isinstance (value1 , LazyExpr ):
2261+ newexpr = value1 .update_expr (new_op )
2262+ else :
2263+ newexpr = value2 .update_expr (new_op )
2264+ self .expression = newexpr .expression
2265+ self .operands = newexpr .operands
2266+ self ._dtype = newexpr .dtype
2267+ return
2268+ elif op in funcs_2args :
22592269 if np .isscalar (value1 ) and np .isscalar (value2 ):
22602270 self .expression = f"{ op } ({ value1 } , { value2 } )"
22612271 elif np .isscalar (value2 ):
@@ -2268,15 +2278,6 @@ def __init__(self, new_op): # noqa: C901
22682278 self .operands = {"o0" : value1 , "o1" : value2 }
22692279 self .expression = f"{ op } (o0, o1)"
22702280 return
2271- elif isinstance (value1 , LazyExpr ) or isinstance (value2 , LazyExpr ):
2272- if isinstance (value1 , LazyExpr ):
2273- newexpr = value1 .update_expr (new_op )
2274- else :
2275- newexpr = value2 .update_expr (new_op )
2276- self .expression = newexpr .expression
2277- self .operands = newexpr .operands
2278- self ._dtype = newexpr .dtype
2279- return
22802281
22812282 self ._dtype = dtype_
22822283 if np .isscalar (value1 ) and np .isscalar (value2 ):
@@ -2351,51 +2352,64 @@ def update_expr(self, new_op): # noqa: C901
23512352 if not isinstance (value1 , LazyExpr ) and not isinstance (value2 , LazyExpr ):
23522353 # We converted some of the operands to NDArray (where() handling above)
23532354 new_operands = {"o0" : value1 , "o1" : value2 }
2354- expression = f"(o0 { op } o1)"
2355+ expression = "op(o0, o1)" if op in funcs_2args else f"(o0 { op } o1)"
23552356 return self ._new_expr (expression , new_operands , guess = False , out = None , where = None )
23562357 elif isinstance (value1 , LazyExpr ) and isinstance (value2 , LazyExpr ):
23572358 # Expression fusion
23582359 # Fuse operands in expressions and detect duplicates
23592360 new_operands , dup_op = fuse_operands (value1 .operands , value2 .operands )
23602361 # Take expression 2 and rebase the operands while removing duplicates
23612362 new_expr = fuse_expressions (value2 .expression , len (value1 .operands ), dup_op )
2362- expression = f"({ value1 .expression } { op } { new_expr } )"
2363- self .operands = value1 .operands
2363+ expression = (
2364+ f"{ op } ({ value1 .expression } , { new_expr } )"
2365+ if op in funcs_2args
2366+ else f"({ value1 .expression } { op } { new_expr } )"
2367+ )
2368+ def_operands = value1 .operands
23642369 elif isinstance (value1 , LazyExpr ):
2365- if op == "~" :
2366- expression = f"({ op } { value1 .expression } )"
2367- elif np .isscalar (value2 ):
2368- expression = f"({ value1 .expression } { op } { value2 } )"
2370+ if np .isscalar (value2 ):
2371+ v2 = value2
23692372 elif hasattr (value2 , "shape" ) and value2 .shape == ():
2370- expression = f"( { value1 . expression } { op } { value2 [()]} )"
2373+ v2 = value2 [()]
23712374 else :
23722375 operand_to_key = {id (v ): k for k , v in value1 .operands .items ()}
23732376 try :
2374- op_name = operand_to_key [id (value2 )]
2377+ v2 = operand_to_key [id (value2 )]
23752378 except KeyError :
2376- op_name = f"o{ len (value1 .operands )} "
2377- new_operands = {op_name : value2 }
2378- expression = f"({ value1 .expression } { op } { op_name } )"
2379- self .operands = value1 .operands
2379+ v2 = f"o{ len (value1 .operands )} "
2380+ new_operands = {v2 : value2 }
2381+ if op == "~" :
2382+ expression = f"({ op } { value1 .expression } )"
2383+ else :
2384+ expression = (
2385+ f"{ op } ({ value1 .expression } , { v2 } )"
2386+ if op in funcs_2args
2387+ else f"({ value1 .expression } { op } { v2 } )"
2388+ )
2389+ def_operands = value1 .operands
23802390 else :
23812391 if np .isscalar (value1 ):
2382- expression = f"( { value1 } { op } { value2 . expression } )"
2392+ v1 = value1
23832393 elif hasattr (value1 , "shape" ) and value1 .shape == ():
2384- expression = f"( { value1 [()]} { op } { value2 . expression } )"
2394+ v1 = value1 [()]
23852395 else :
23862396 operand_to_key = {id (v ): k for k , v in value2 .operands .items ()}
23872397 try :
2388- op_name = operand_to_key [id (value1 )]
2398+ v1 = operand_to_key [id (value1 )]
23892399 except KeyError :
2390- op_name = f"o{ len (value2 .operands )} "
2391- new_operands = {op_name : value1 }
2392- if op == "[]" : # syntactic sugar for slicing
2393- expression = f"({ op_name } [{ value2 .expression } ])"
2394- else :
2395- expression = f"({ op_name } { op } { value2 .expression } )"
2396- self .operands = value2 .operands
2400+ v1 = f"o{ len (value2 .operands )} "
2401+ new_operands = {v1 : value1 }
2402+ if op == "[]" : # syntactic sugar for slicing
2403+ expression = f"({ v1 } [{ value2 .expression } ])"
2404+ else :
2405+ expression = (
2406+ f"{ op } ({ v1 } , { value2 .expression } )"
2407+ if op in funcs_2args
2408+ else f"({ v1 } { op } { value2 .expression } )"
2409+ )
2410+ def_operands = value2 .operands
23972411 # Return a new expression
2398- operands = self . operands | new_operands
2412+ operands = def_operands | new_operands
23992413 expr = self ._new_expr (expression , operands , guess = False , out = None , where = None )
24002414 expr ._dtype = dtype_ # override dtype with preserved dtype
24012415 return expr
0 commit comments