Skip to content

Commit 868b36a

Browse files
committed
Fix where handling for string lexpr
1 parent dbca827 commit 868b36a

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

src/blosc2/lazyexpr.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3002,26 +3002,12 @@ def _new_expr(cls, expression, operands, guess, out=None, where=None, ne_args=No
30023002
expression_, operands_ = conserve_functions(
30033003
_expression, _operands, new_expr.operands | local_vars
30043004
)
3005-
# if new_expr has where_args, must have come from where(...) - or possibly where(where(..
3006-
# since 5*where, where + ... are evaluated eagerly
3007-
if hasattr(new_expr, "_where_args"):
3008-
st = expression_.find("where(") + len(
3009-
"where("
3010-
) # expr always begins where( - should have st = 6 always
3011-
finalexpr = ""
3012-
counter = 0
3013-
for char in expression_[st:]: # get rid of external where(...)
3014-
finalexpr += char
3015-
counter += 1 * (char == "(") - 1 * (char == ")")
3016-
if counter == 0 and char == ",":
3017-
break
3018-
expression_ = finalexpr[:-1] # remove trailing comma
30193005
else:
3020-
new_expr = cls(None)
30213006
# An immediate evaluation happened
30223007
# (e.g. all operands are numpy arrays or constructors)
30233008
# or passed "a", "a[:10]", 'sum(a)'
30243009
expression_, operands_ = conserve_functions(_expression, _operands, local_vars)
3010+
new_expr = cls(None)
30253011
new_expr.expression = f"({expression_})" # force parenthesis
30263012
new_expr.operands = operands_
30273013
new_expr.expression_tosave = expression

tests/ndarray/test_lazyexpr_fields.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,17 @@ def test_where(array_fixture):
220220
res = expr.where(0, 1).compute()
221221
nres = ne_evaluate("where(na1**2 + na2**2 > 2 * na1 * na2 + 1, 0, 1)")
222222
np.testing.assert_allclose(res[:], nres)
223+
223224
# Test with getitem
224225
sl = slice(100)
225226
res = expr.where(0, 1)[sl]
226227
np.testing.assert_allclose(res, nres[sl])
227228

229+
# Test with string
230+
res = blosc2.evaluate("where(a1**2 + a2**2 > 2 * a1 * a2 + 1, a1 + 5, a2)")
231+
nres = ne_evaluate("where(na1**2 + na2**2 > 2 * na1 * na2 + 1, na1 + 5, na2)")
232+
np.testing.assert_allclose(res, nres)
233+
228234

229235
# Test expressions with where() and string comps
230236
def test_lazy_where(array_fixture):

0 commit comments

Comments
 (0)