Skip to content

Commit a4a7b4c

Browse files
authored
Merge pull request #534 from Blosc/add_checkout
Add get_chunk method to lazyudf
2 parents aa8712a + 1fe0cd9 commit a4a7b4c

File tree

2 files changed

+89
-22
lines changed

2 files changed

+89
-22
lines changed

src/blosc2/lazyexpr.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import re
2222
import sys
2323
import threading
24-
from abc import ABC, abstractmethod
24+
from abc import ABC, abstractmethod, abstractproperty
2525
from dataclasses import asdict
2626
from enum import Enum
2727
from pathlib import Path
@@ -437,6 +437,41 @@ def to_cframe(self) -> bytes:
437437
"""
438438
return self.compute().to_cframe()
439439

440+
@property
@abstractmethod
def chunks(self) -> tuple[int, ...]:
    """
    Return :ref:`LazyArray` chunks.

    Abstract read-only property: concrete subclasses must override it.
    It is read by :meth:`get_chunk` to locate a chunk and to size the
    temporary output array.

    Returns
    -------
    tuple[int, ...]
        The chunk shape, one entry per dimension.
    """
    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on top of ``@abstractmethod`` is the supported spelling.
    pass
446+
447+
@property
@abstractmethod
def blocks(self) -> tuple[int, ...]:
    """
    Return :ref:`LazyArray` blocks.

    Abstract read-only property: concrete subclasses must override it.
    It is read by :meth:`get_chunk` when creating the temporary output
    array.

    Returns
    -------
    tuple[int, ...]
        The block shape, one entry per dimension.
    """
    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on top of ``@abstractmethod`` is the supported spelling.
    pass
453+
454+
def get_chunk(self, nchunk):
    """Return chunk number `nchunk` of the expression, computing only that chunk.

    Parameters
    ----------
    nchunk
        Flat index of the chunk; it is unravelled against the grid of chunks
        returned by ``get_chunks_idx(self.shape, self.chunks)``.

    Returns
    -------
    The result of ``out.schunk.get_chunk(0)``, where ``out`` is a temporary
    array of shape ``self.chunks`` holding the evaluated region (presumably
    the compressed chunk buffer -- confirm against the SChunk API).
    """
    full_shape = self.shape
    chunkshape = self.chunks
    # Position of the requested chunk inside the grid of chunks
    grid, _ = get_chunks_idx(full_shape, chunkshape)
    position = tuple(np.unravel_index(nchunk, grid))
    # Region of the array covered by this chunk, clipped at the array border
    region = []
    for dim, (pos, size) in enumerate(zip(position, chunkshape, strict=True)):
        start = pos * size
        region.append(slice(start, min(start + size, full_shape[dim])))
    region = tuple(region)
    extent = tuple(r.stop - r.start for r in region)
    # Temporary single-chunk container; creating an empty array is fast
    # NOTE(review): for border chunks only part of `out` is written below;
    # confirm what the untouched part of a `blosc2.empty` array contains.
    out = blosc2.empty(shape=self.chunks, dtype=self.dtype, chunks=self.chunks, blocks=self.blocks)
    if extent != self.chunks:
        # Border chunk: the region is smaller than a chunk, so compute it
        # separately and copy it into the leading corner of `out`
        corner = tuple(slice(0, n) for n in extent)
        out[corner] = self.compute(item=region)
    else:
        # Interior chunk: the region fills `out` exactly, evaluate in place
        self.compute(item=region, out=out)
    return out.schunk.get_chunk(0)
474+
440475

441476
def convert_inputs(inputs):
442477
if not inputs or len(inputs) == 0:
@@ -2421,27 +2456,6 @@ def __init__(self, new_op): # noqa: C901
24212456
self.operands = {"o0": value1, "o1": value2}
24222457
self.expression = f"(o0 {op} o1)"
24232458

2424-
def get_chunk(self, nchunk):
2425-
"""Get the `nchunk` of the expression, evaluating only that one."""
2426-
# Create an empty array with the chunkshape and dtype; this is fast
2427-
shape = self.shape
2428-
chunks = self.chunks
2429-
# Calculate the shape of the (chunk) slice_ (especially at the end of the array)
2430-
chunks_idx, _ = get_chunks_idx(shape, chunks)
2431-
coords = tuple(np.unravel_index(nchunk, chunks_idx))
2432-
slice_ = tuple(
2433-
slice(c * s, min((c + 1) * s, shape[i]))
2434-
for i, (c, s) in enumerate(zip(coords, chunks, strict=True))
2435-
)
2436-
loc_chunks = tuple(s.stop - s.start for s in slice_)
2437-
out = blosc2.empty(shape=self.chunks, dtype=self.dtype, chunks=self.chunks, blocks=self.blocks)
2438-
if loc_chunks == self.chunks:
2439-
self.compute(item=slice_, out=out)
2440-
else:
2441-
_slice_ = tuple(slice(0, s) for s in loc_chunks)
2442-
out[_slice_] = self.compute(item=slice_)
2443-
return out.schunk.get_chunk(0)
2444-
24452459
def update_expr(self, new_op): # noqa: C901
24462460
prev_flag = blosc2._disable_overloaded_equal
24472461
# We use a lot of the original NDArray.__eq__ as 'is', so deactivate the overloaded one
@@ -3218,6 +3232,38 @@ def info_items(self):
32183232
("dtype", self.dtype),
32193233
]
32203234

3235+
def _resolve_chunks_blocks(self):
    """Compute and cache ``_chunks`` and ``_blocks`` (and ``_shape`` if unset).

    Shared by the :attr:`chunks` and :attr:`blocks` properties, which used
    to carry identical copies of this logic.
    """
    # NOTE(review): both attributes are always (re)assigned here, so reading
    # one property overwrites a pre-existing value of the other; this matches
    # the previous behaviour -- confirm it is intended.
    shape, self._chunks, self._blocks, fast_path = validate_inputs(
        self.inputs_dict, getattr(self, "_out", None)
    )
    if not hasattr(self, "_shape"):
        self._shape = shape
    if self._shape != shape:
        # validate_inputs only works for elementwise funcs, so the returned
        # shape might be incompatible with the true output shape
        fast_path = False
    if not fast_path:
        # Not using the fast path, so we need to compute the chunks/blocks automatically
        self._chunks, self._blocks = compute_chunks_blocks(self.shape, None, None, dtype=self.dtype)


@property
def chunks(self):
    """Chunk shape of the result; computed on first access, then cached in ``_chunks``."""
    if not hasattr(self, "_chunks"):
        self._resolve_chunks_blocks()
    return self._chunks


@property
def blocks(self):
    """Block shape of the result; computed on first access, then cached in ``_blocks``."""
    if not hasattr(self, "_blocks"):
        self._resolve_chunks_blocks()
    return self._blocks
3266+
32213267
# TODO: indices and sort are repeated in LazyExpr; refactor
32223268
def indices(self, order: str | list[str] | None = None) -> blosc2.LazyArray:
32233269
if self.dtype.fields is None:

tests/ndarray/test_lazyudf.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
import blosc2
13+
from blosc2.ndarray import get_chunks_idx
1314

1415

1516
def udf1p(inputs_tuple, output, offset):
@@ -471,3 +472,23 @@ def test_save_ludf():
471472
assert isinstance(expr, blosc2.LazyUDF)
472473
res_lazyexpr = expr.compute()
473474
np.testing.assert_array_equal(res_lazyexpr[:], npc)
475+
476+
477+
# Test get_chunk method
478+
def test_get_chunk():
    """Rebuild a LazyUDF result chunk by chunk via ``get_chunk`` and compare.

    Every chunk returned by ``expr.get_chunk`` is stored into an empty array
    with the same chunks/blocks as the expression; the reassembled array must
    match the NumPy reference ``a[:] + 1``.
    """
    a = blosc2.linspace(0, 100, 100, shape=(10, 10), chunks=(3, 4), blocks=(2, 3))
    expr = blosc2.lazyudf(udf1p, (a,), dtype=a.dtype, shape=a.shape)
    nres = a[:] + 1
    chunksize = np.prod(expr.chunks) * expr.dtype.itemsize
    blocksize = np.prod(expr.blocks) * expr.dtype.itemsize
    _, nchunks = get_chunks_idx(expr.shape, expr.chunks)
    out = blosc2.empty(expr.shape, dtype=expr.dtype, chunks=expr.chunks, blocks=expr.blocks)
    for nchunk in range(nchunks):
        chunk = expr.get_chunk(nchunk)
        out.schunk.update_chunk(nchunk, chunk)
        # Bytes 4:8 and 8:12 of the chunk are read as little-endian sizes
        # (presumably the nbytes and blocksize header fields -- confirm
        # against the Blosc chunk header layout)
        chunksize_ = int.from_bytes(chunk[4:8], byteorder="little")
        blocksize_ = int.from_bytes(chunk[8:12], byteorder="little")
        # Sometimes the actual chunksize is larger than the expected one due to padding
        assert chunksize <= chunksize_
        assert blocksize == blocksize_
    np.testing.assert_allclose(out[:], nres)

0 commit comments

Comments
 (0)