Skip to content

Commit f90a463

Browse files
authored
Merge pull request #506 from Blosc/fix_out
Solve issue #503
2 parents 923c462 + 1c7c900 commit f90a463

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

src/blosc2/lazyexpr.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2762,10 +2762,11 @@ def _compute_expr(self, item, kwargs): # noqa : C901
27622762
where_x = self._where_args["_where_x"]
27632763
where_y = self._where_args["_where_y"]
27642764
return np.where(lazy_expr, where_x, where_y)[key]
2765-
if hasattr(self, "_output"):
2765+
out = kwargs.get("_output", None)
2766+
if out is not None:
27662767
# This is not exactly optimized, but it works for now
2767-
self._output[:] = lazy_expr[key]
2768-
return self._output
2768+
out[:] = lazy_expr[key]
2769+
return out
27692770
arr = lazy_expr[key]
27702771
if builtins.sum(mask) > 0:
27712772
# Correct shape to adjust to NumPy convention
@@ -2834,11 +2835,11 @@ def sort(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
28342835

28352836
def compute(self, item=(), **kwargs) -> blosc2.NDArray:
28362837
# When NumPy ufuncs are called, the user may add an `out` parameter to kwargs
2837-
if "out" in kwargs:
2838+
if "out" in kwargs: # use provided out preferentially
28382839
kwargs["_output"] = kwargs.pop("out")
2839-
self._output = kwargs["_output"]
2840-
if hasattr(self, "_output"):
2840+
elif hasattr(self, "_output"):
28412841
kwargs["_output"] = self._output
2842+
28422843
if "ne_args" in kwargs:
28432844
kwargs["_ne_args"] = kwargs.pop("ne_args")
28442845
if hasattr(self, "_ne_args"):

tests/ndarray/test_lazyexpr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,11 +1745,21 @@ def test_lazylinalg():
17451745

17461746
# Test for issue #503 (LazyArray.compute() should honor out param)
17471747
def test_lazyexpr_compute_out():
1748+
# check reductions
17481749
a = blosc2.ones(10)
17491750
out = blosc2.zeros(1)
17501751
lexpr = blosc2.lazyexpr("sum(a)")
17511752
assert lexpr.compute(out=out) is out
17521753
assert out[0] == 10
1754+
assert lexpr.compute() is not out
1755+
1756+
# check normal expression
1757+
a = blosc2.ones(10)
1758+
out = blosc2.zeros(10)
1759+
lexpr = blosc2.lazyexpr("sin(a)")
1760+
assert lexpr.compute(out=out) is out
1761+
assert out[0] == np.sin(1)
1762+
assert lexpr.compute() is not out
17531763

17541764

17551765
def test_lazyexpr_2args():

0 commit comments

Comments
 (0)