Skip to content

Commit fb875e9

Browse files
committed
Add SimpleProxy to setitem
1 parent ac33315 commit fb875e9

File tree

3 files changed

+23
-14
lines changed

3 files changed

+23
-14
lines changed

src/blosc2/ndarray.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4261,6 +4261,9 @@ def __setitem__(
42614261
key_, mask = process_key(key, self.shape) # internally handles key an integer
42624262
if hasattr(value, "shape") and value.shape == ():
42634263
value = value.item()
4264+
value = (
4265+
value if np.isscalar(value) else blosc2.as_simpleproxy(value)
4266+
) # convert to SimpleProxy for e.g. JAX, Tensorflow, PyTorch
42644267

42654268
if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing
42664269
_slice = ndindex.ndindex(key_).expand(
@@ -4284,20 +4287,17 @@ def __setitem__(
42844287
return self._get_set_nonunit_steps((start, stop, step, mask), value=value)
42854288

42864289
shape = [sp - st for sp, st in zip(stop, start, strict=False)]
4287-
if isinstance(value, NDArray):
4288-
value = value[...] # convert to numpy
4289-
if np.isscalar(value):
4290+
if isinstance(value, blosc2.Operand): # handles SimpleProxy, NDArray, LazyExpr etc.
4291+
value = value[()] # convert to numpy
4292+
if np.isscalar(value) or value.shape == ():
42904293
value = np.full(shape, value, dtype=self.dtype)
4291-
elif isinstance(value, np.ndarray): # handles decompressed NDArray too
4292-
if value.dtype != self.dtype:
4293-
try:
4294-
value = value.astype(self.dtype)
4295-
except ComplexWarning:
4296-
# numexpr type inference can lead to unnecessary type promotions
4297-
# when using complex functions (e.g. conj) with real arrays
4298-
value = value.real.astype(self.dtype)
4299-
if value.shape == ():
4300-
value = np.full(shape, value, dtype=self.dtype)
4294+
if isinstance(value, np.ndarray) and value.dtype != self.dtype: # handles decompressed NDArray too
4295+
try:
4296+
value = value.astype(self.dtype)
4297+
except ComplexWarning:
4298+
# numexpr type inference can lead to unnecessary type promotions
4299+
# when using complex functions (e.g. conj) with real arrays
4300+
value = value.real.astype(self.dtype)
43014301

43024302
return super().set_slice((start, stop), value)
43034303

src/blosc2/proxy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ class SimpleProxy(blosc2.Operand):
612612

613613
def __init__(self, src, chunks: tuple | None = None, blocks: tuple | None = None):
614614
if not hasattr(src, "shape") or not hasattr(src, "dtype"):
615-
# If the source is not a NumPy array, convert it to one
615+
# If the source is not an array, convert it to NumPy
616616
src = np.asarray(src)
617617
if not hasattr(src, "__getitem__"):
618618
raise TypeError("The source must have a __getitem__ method")

tests/ndarray/test_setitem.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
import pytest
11+
import torch
1112

1213
import blosc2
1314

@@ -44,6 +45,14 @@ def test_setitem(shape, chunks, blocks, slices, dtype):
4445
nparray[slices] = val
4546
np.testing.assert_almost_equal(a[...], nparray)
4647

48+
# Object called via SimpleProxy
49+
slice_shape = a[slices].shape
50+
dtype_ = {np.float32: torch.float32, np.int32: torch.int32, np.float64: torch.float64}[dtype]
51+
val = torch.ones(slice_shape, dtype=dtype_)
52+
a[slices] = val
53+
nparray[slices] = val
54+
np.testing.assert_almost_equal(a[...], nparray)
55+
4756
# blosc2.NDArray
4857
if np.prod(slice_shape) == 1 or len(slice_shape) != len(blocks):
4958
chunks = None

0 commit comments

Comments
 (0)