Commit 9814077

Add get_chunk for lazyudf

1 parent 75cc0db

File tree

  src/blosc2/lazyexpr.py
  tests/ndarray/test_lazyudf.py

2 files changed: 89 additions & 22 deletions
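
In practice, the commit lets a LazyUDF report its chunks and blocks and evaluate a single compressed chunk on demand, without computing the whole expression. A minimal usage sketch, modeled on the new test below; the add_one UDF is illustrative (it mirrors the test's udf1p), while lazyudf, linspace, chunks, blocks and get_chunk come from the diff:

import blosc2

def add_one(inputs_tuple, output, offset):
    # Element-wise UDF: fill the requested output region from the first input
    output[:] = inputs_tuple[0] + 1

a = blosc2.linspace(0, 100, 100, shape=(10, 10), chunks=(3, 4), blocks=(2, 3))
expr = blosc2.lazyudf(add_one, (a,), dtype=a.dtype, shape=a.shape)
print(expr.chunks, expr.blocks)  # chunk/block shapes are now exposed on the lazy object
chunk0 = expr.get_chunk(0)       # compressed bytes for the first chunk only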

src/blosc2/lazyexpr.py

Lines changed: 68 additions & 22 deletions
@@ -21,7 +21,7 @@
 import re
 import sys
 import threading
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
 from dataclasses import asdict
 from enum import Enum
 from pathlib import Path
@@ -431,6 +431,41 @@ def to_cframe(self) -> bytes:
         """
         return self.compute().to_cframe()
 
+    @abstractproperty
+    def chunks(self) -> tuple[int]:
+        """
+        Return :ref:`LazyArray` chunks.
+        """
+        pass
+
+    @abstractproperty
+    def blocks(self) -> tuple[int]:
+        """
+        Return :ref:`LazyArray` blocks.
+        """
+        pass
+
+    def get_chunk(self, nchunk):
+        """Get the `nchunk` of the expression, evaluating only that one."""
+        # Create an empty array with the chunkshape and dtype; this is fast
+        shape = self.shape
+        chunks = self.chunks
+        # Calculate the shape of the (chunk) slice_ (especially at the end of the array)
+        chunks_idx, _ = get_chunks_idx(shape, chunks)
+        coords = tuple(np.unravel_index(nchunk, chunks_idx))
+        slice_ = tuple(
+            slice(c * s, min((c + 1) * s, shape[i]))
+            for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
+        )
+        loc_chunks = tuple(s.stop - s.start for s in slice_)
+        out = blosc2.empty(shape=self.chunks, dtype=self.dtype, chunks=self.chunks, blocks=self.blocks)
+        if loc_chunks == self.chunks:
+            self.compute(item=slice_, out=out)
+        else:
+            _slice_ = tuple(slice(0, s) for s in loc_chunks)
+            out[_slice_] = self.compute(item=slice_)
+        return out.schunk.get_chunk(0)
+
 
 def convert_inputs(inputs):
     if not inputs or len(inputs) == 0:
@@ -2415,27 +2450,6 @@ def __init__(self, new_op):  # noqa: C901
         self.operands = {"o0": value1, "o1": value2}
         self.expression = f"(o0 {op} o1)"
 
-    def get_chunk(self, nchunk):
-        """Get the `nchunk` of the expression, evaluating only that one."""
-        # Create an empty array with the chunkshape and dtype; this is fast
-        shape = self.shape
-        chunks = self.chunks
-        # Calculate the shape of the (chunk) slice_ (especially at the end of the array)
-        chunks_idx, _ = get_chunks_idx(shape, chunks)
-        coords = tuple(np.unravel_index(nchunk, chunks_idx))
-        slice_ = tuple(
-            slice(c * s, min((c + 1) * s, shape[i]))
-            for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
-        )
-        loc_chunks = tuple(s.stop - s.start for s in slice_)
-        out = blosc2.empty(shape=self.chunks, dtype=self.dtype, chunks=self.chunks, blocks=self.blocks)
-        if loc_chunks == self.chunks:
-            self.compute(item=slice_, out=out)
-        else:
-            _slice_ = tuple(slice(0, s) for s in loc_chunks)
-            out[_slice_] = self.compute(item=slice_)
-        return out.schunk.get_chunk(0)
-
     def update_expr(self, new_op):  # noqa: C901
         prev_flag = blosc2._disable_overloaded_equal
         # We use a lot of the original NDArray.__eq__ as 'is', so deactivate the overloaded one
@@ -3212,6 +3226,38 @@ def info_items(self):
             ("dtype", self.dtype),
         ]
 
+    @property
+    def chunks(self):
+        if hasattr(self, "_chunks"):
+            return self._chunks
+        shape, self._chunks, self._blocks, fast_path = validate_inputs(
+            self.inputs_dict, getattr(self, "_out", None)
+        )
+        if not hasattr(self, "_shape"):
+            self._shape = shape
+        if self._shape != shape:  # validate inputs only works for elementwise funcs so returned shape might
+            fast_path = False  # be incompatible with true output shape
+        if not fast_path:
+            # Not using the fast path, so we need to compute the chunks/blocks automatically
+            self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
+        return self._chunks
+
+    @property
+    def blocks(self):
+        if hasattr(self, "_blocks"):
+            return self._blocks
+        shape, self._chunks, self._blocks, fast_path = validate_inputs(
+            self.inputs_dict, getattr(self, "_out", None)
+        )
+        if not hasattr(self, "_shape"):
+            self._shape = shape
+        if self._shape != shape:  # validate inputs only works for elementwise funcs so returned shape might
+            fast_path = False  # be incompatible with true output shape
+        if not fast_path:
+            # Not using the fast path, so we need to compute the chunks/blocks automatically
+            self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)
+        return self._blocks
+
     # TODO: indices and sort are repeated in LazyExpr; refactor
     def indices(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
         if self.dtype.fields is None:
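
The heart of the relocated get_chunk is mapping a flat chunk index to an n-dimensional slice that is clipped at the array edges. A standalone sketch of that arithmetic, independent of blosc2 (chunk_slice is an illustrative helper and the ceil division stands in for get_chunks_idx):

import math

import numpy as np

def chunk_slice(shape, chunks, nchunk):
    # Number of chunks along each dimension (ceil division), as get_chunks_idx computes
    chunks_idx = tuple(math.ceil(s / c) for s, c in zip(shape, chunks, strict=True))
    # Turn the flat chunk index into per-dimension chunk coordinates
    coords = tuple(np.unravel_index(nchunk, chunks_idx))
    # Clip the last chunk in each dimension to the array boundary
    return tuple(
        slice(c * s, min((c + 1) * s, shape[i]))
        for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
    )

print(chunk_slice((10, 10), (3, 4), 0))   # first chunk covers [0:3, 0:4]
print(chunk_slice((10, 10), (3, 4), 11))  # last chunk, clipped to [9:10, 8:10]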

tests/ndarray/test_lazyudf.py

Lines changed: 21 additions & 0 deletions
@@ -10,6 +10,7 @@
 import pytest
 
 import blosc2
+from blosc2.ndarray import get_chunks_idx
 
 
 def udf1p(inputs_tuple, output, offset):
@@ -471,3 +472,23 @@ def test_save_ludf():
     assert isinstance(expr, blosc2.LazyUDF)
     res_lazyexpr = expr.compute()
     np.testing.assert_array_equal(res_lazyexpr[:], npc)
+
+
+# Test get_chunk method
+def test_get_chunk():
+    a = blosc2.linspace(0, 100, 100, shape=(10, 10), chunks=(3, 4), blocks=(2, 3))
+    expr = blosc2.lazyudf(udf1p, (a,), dtype=a.dtype, shape=a.shape)
+    nres = a[:] + 1
+    chunksize = np.prod(expr.chunks) * expr.dtype.itemsize
+    blocksize = np.prod(expr.blocks) * expr.dtype.itemsize
+    _, nchunks = get_chunks_idx(expr.shape, expr.chunks)
+    out = blosc2.empty(expr.shape, dtype=expr.dtype, chunks=expr.chunks, blocks=expr.blocks)
+    for nchunk in range(nchunks):
+        chunk = expr.get_chunk(nchunk)
+        out.schunk.update_chunk(nchunk, chunk)
+        chunksize_ = int.from_bytes(chunk[4:8], byteorder="little")
+        blocksize_ = int.from_bytes(chunk[8:12], byteorder="little")
+        # Sometimes the actual chunksize is smaller than the expected chunks due to padding
+        assert chunksize <= chunksize_
+        assert blocksize == blocksize_
+    np.testing.assert_allclose(out[:], nres)
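
The header peeking in this test relies on the 16-byte Blosc chunk header, where bytes 4:8 store the uncompressed chunk size (nbytes) and bytes 8:12 the block size, both little-endian; those are the same offsets the asserts above decode. A small sketch of that decoding outside pytest (chunk_header_sizes is an illustrative helper, not library API):

import blosc2

def chunk_header_sizes(chunk: bytes) -> tuple[int, int]:
    # Bytes 4:8 -> uncompressed size (nbytes), bytes 8:12 -> blocksize, little-endian
    nbytes = int.from_bytes(chunk[4:8], byteorder="little")
    blocksize = int.from_bytes(chunk[8:12], byteorder="little")
    return nbytes, blocksize

a = blosc2.linspace(0, 100, 100, shape=(10, 10), chunks=(3, 4), blocks=(2, 3))
nbytes, blocksize = chunk_header_sizes(a.schunk.get_chunk(0))
# With float64 data and a (3, 4) chunkshape, nbytes should come out around
# 3 * 4 * 8 = 96 bytes (edge chunks are typically padded to the full chunkshape).
print(nbytes, blocksize)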
