Skip to content

Commit f5ea95a

Browse files
committed
Add properties to get chunk and shard slices
1 parent c8d8e64 commit f5ea95a

File tree

4 files changed

+97
-1
lines changed

4 files changed

+97
-1
lines changed

changes/3573.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added new ``Array.chunk_slices`` and ``Array.shard_slices`` to get slices aligned with array chunks and shards respectively.

docs/user-guide/arrays.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,26 @@ In this example a shard shape of (1000, 1000) and a chunk shape of (100, 100) is
566566
This means that `10*10` chunks are stored in each shard, and there are `10*10` shards in total.
567567
Without the `shards` argument, there would be 10,000 chunks stored as individual files.
568568

569+
## Accessing chunks and shards
570+
571+
Arrays have useful properties for accessing data aligned to chunks and shards.
572+
This can be useful for getting slices that can be used to write to shards in parallel, or read from chunks in parallel.
573+
574+
```python exec="true" session="arrays" source="above" result="ansi"
575+
a = zarr.create_array(store={}, shape=(100, 50), shards=(50, 40), chunks=(25, 20), dtype='uint8')
576+
577+
print("All shard slices:")
578+
for shard_slice in a.shard_slices:
579+
print(shard_slice)
580+
# shard_data = a[shard_slice]
581+
582+
print("All chunk slices:")
583+
for chunk_slice in a.chunk_slices:
584+
print(chunk_slice)
585+
# chunk_data = a[chunk_slice]
586+
```
587+
588+
569589
## Missing features in 3.0
570590

571591
The following features have not been ported to 3.0 yet.

src/zarr/core/array.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import warnings
55
from asyncio import gather
6-
from collections.abc import Iterable, Mapping
6+
from collections.abc import Generator, Iterable, Mapping
77
from dataclasses import dataclass, field, replace
88
from itertools import starmap
99
from logging import getLogger
@@ -1381,6 +1381,32 @@ async def example():
13811381
async def nbytes_stored(self) -> int:
13821382
return await self.store_path.store.getsize_prefix(self.store_path.path)
13831383

1384+
@property
1385+
def chunk_slices(self) -> Generator[tuple[slice, ...]]:
1386+
"""
1387+
Iterator over all chunks.
1388+
1389+
Yields
1390+
------
1391+
chunk_slice :
1392+
Slice for each chunk in this array.
1393+
"""
1394+
yield from self._iter_chunk_regions()
1395+
1396+
@property
1397+
def shard_slices(self) -> Generator[tuple[slice, ...]]:
1398+
"""
1399+
Iterator over all shards.
1400+
1401+
This can be used to loop through and index every shard of an array.
1402+
1403+
Yields
1404+
------
1405+
shard_slice :
1406+
Slice for each shard in this array.
1407+
"""
1408+
yield from self._iter_shard_regions()
1409+
13841410
def _iter_chunk_coords(
13851411
self, *, origin: Sequence[int] | None = None, selection_shape: Sequence[int] | None = None
13861412
) -> Iterator[tuple[int, ...]]:
@@ -2355,6 +2381,34 @@ def shards(self) -> tuple[int, ...] | None:
23552381
"""
23562382
return self._async_array.shards
23572383

2384+
@property
2385+
def chunk_slices(self) -> Generator[tuple[slice, ...]]:
2386+
"""
2387+
Iterator over all chunks.
2388+
2389+
This can be used to loop through and index every chunk of an array.
2390+
2391+
Yields
2392+
------
2393+
chunk_slice :
2394+
Slice for each chunk in this array.
2395+
"""
2396+
yield from self._async_array.chunk_slices
2397+
2398+
@property
2399+
def shard_slices(self) -> Generator[tuple[slice, ...]]:
2400+
"""
2401+
Iterator over all shards.
2402+
2403+
This can be used to loop through and index every shard of an array.
2404+
2405+
Yields
2406+
------
2407+
shard_slice :
2408+
Slice for each shard in this array.
2409+
"""
2410+
yield from self._async_array.shard_slices
2411+
23582412
@property
23592413
def size(self) -> int:
23602414
"""Returns the total number of elements in the array.

tests/test_array.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,3 +2153,24 @@ def test_create_array_with_data_num_gets(
21532153
# one get for the metadata and one per shard.
21542154
# Note: we don't actually need one get per shard, but this is the current behavior
21552155
assert store.counter["get"] == 1 + num_shards
2156+
2157+
2158+
@pytest.mark.parametrize("shards", [None, (4, 6)])
2159+
def test_chunk_slices(shards: None | tuple[int, ...]) -> None:
2160+
arr = zarr.create_array(store={}, shape=(4, 8), dtype="uint8", chunks=(2, 3), shards=shards)
2161+
assert list(arr.chunk_slices) == [
2162+
(slice(0, 2, 1), slice(0, 3, 1)),
2163+
(slice(0, 2, 1), slice(3, 6, 1)),
2164+
(slice(0, 2, 1), slice(6, 8, 1)),
2165+
(slice(2, 4, 1), slice(0, 3, 1)),
2166+
(slice(2, 4, 1), slice(3, 6, 1)),
2167+
(slice(2, 4, 1), slice(6, 8, 1)),
2168+
]
2169+
2170+
2171+
def test_shard_slices() -> None:
2172+
arr = zarr.create_array(store={}, shape=(4, 8), dtype="uint8", chunks=(2, 3), shards=(4, 6))
2173+
assert list(arr.shard_slices) == [
2174+
(slice(0, 4, 1), slice(0, 6, 1)),
2175+
(slice(0, 4, 1), slice(6, 8, 1)),
2176+
]

0 commit comments

Comments
 (0)