2626from cubed .spec import spec_from_config
2727from cubed .storage .backend import open_backend_array
2828from cubed .types import T_RegularChunks , T_Shape
29- from cubed .utils import (
30- _concatenate2 ,
31- array_memory ,
32- array_size ,
33- get_item ,
34- offset_to_block_id ,
35- to_chunksize ,
36- )
29+ from cubed .utils import _concatenate2 , array_memory , array_size , get_item
30+ from cubed .utils import numblocks as compute_numblocks
31+ from cubed .utils import offset_to_block_id , to_chunksize
3732from cubed .vendor .dask .array .core import normalize_chunks
3833from cubed .vendor .dask .array .utils import validate_axis
3934from cubed .vendor .dask .blockwise import broadcast_dimensions , lol_product
@@ -342,6 +337,77 @@ def general_blockwise(
342337 target_paths = None ,
343338 extra_func_kwargs = None ,
344339 ** kwargs ,
340+ ) -> Union ["Array" , Tuple ["Array" , ...]]:
341+ if has_keyword (func , "block_id" ):
342+ from cubed .array_api .creation_functions import offsets_virtual_array
343+
344+ # Create an array of index offsets with the same chunk structure as the args,
345+ # which we convert to block ids (chunk coordinates) later.
346+ array0 = arrays [0 ]
347+ # note that the primitive general_blockwise checks that every entry of chunkss has the same numblocks, so chunkss[0] is representative
348+ numblocks = compute_numblocks (chunkss [0 ])
349+ offsets = offsets_virtual_array (numblocks , array0 .spec )
350+ new_arrays = arrays + (offsets ,)
351+
352+ def key_function_with_offset (key_function ):
353+ def wrap (out_key ):
354+ out_coords = out_key [1 :]
355+ offset_in_key = ((offsets .name ,) + out_coords ,)
356+ return key_function (out_key ) + offset_in_key
357+
358+ return wrap
359+
360+ def func_with_block_id (func ):
361+ def wrap (* a , ** kw ):
362+ offset = int (a [- 1 ]) # convert from 0-d array
363+ block_id = offset_to_block_id (offset , numblocks )
364+ return func (* a [:- 1 ], block_id = block_id , ** kw )
365+
366+ return wrap
367+
368+ num_input_blocks = kwargs .pop ("num_input_blocks" , None )
369+ if num_input_blocks is not None :
370+ num_input_blocks = num_input_blocks + (1 ,) # for offsets array
371+
372+ return _general_blockwise (
373+ func_with_block_id (func ),
374+ key_function_with_offset (key_function ),
375+ * new_arrays ,
376+ shapes = shapes ,
377+ dtypes = dtypes ,
378+ chunkss = chunkss ,
379+ target_stores = target_stores ,
380+ target_paths = target_paths ,
381+ extra_func_kwargs = extra_func_kwargs ,
382+ num_input_blocks = num_input_blocks ,
383+ ** kwargs ,
384+ )
385+
386+ return _general_blockwise (
387+ func ,
388+ key_function ,
389+ * arrays ,
390+ shapes = shapes ,
391+ dtypes = dtypes ,
392+ chunkss = chunkss ,
393+ target_stores = target_stores ,
394+ target_paths = target_paths ,
395+ extra_func_kwargs = extra_func_kwargs ,
396+ ** kwargs ,
397+ )
398+
399+
400+ def _general_blockwise (
401+ func ,
402+ key_function ,
403+ * arrays ,
404+ shapes ,
405+ dtypes ,
406+ chunkss ,
407+ target_stores = None ,
408+ target_paths = None ,
409+ extra_func_kwargs = None ,
410+ ** kwargs ,
345411) -> Union ["Array" , Tuple ["Array" , ...]]:
346412 assert len (arrays ) > 0
347413
@@ -504,12 +570,6 @@ def merged_chunk_len_for_indexer(ia, c):
504570 if _is_chunk_aligned_selection (idx ):
505571 # use general_blockwise, which allows more opportunities for optimization than map_direct
506572
507- from cubed .array_api .creation_functions import offsets_virtual_array
508-
509- # general_blockwise doesn't support block_id, so emulate it ourselves
510- numblocks = tuple (map (len , target_chunks ))
511- offsets = offsets_virtual_array (numblocks , x .spec )
512-
513573 def key_function (out_key ):
514574 out_coords = out_key [1 :]
515575
@@ -521,24 +581,17 @@ def key_function(out_key):
521581 in_sel , x .zarray_maybe_lazy .shape , x .zarray_maybe_lazy .chunks
522582 )
523583
524- offset_in_key = ((offsets .name ,) + out_coords ,)
525- return (
526- tuple ((x .name ,) + chunk_coords for (chunk_coords , _ , _ ) in indexer )
527- + offset_in_key
584+ return tuple (
585+ (x .name ,) + chunk_coords for (chunk_coords , _ , _ ) in indexer
528586 )
529587
530- # since selection is chunk-aligned, we know that we only read one block of x
531- num_input_blocks = (1 , 1 ) # x, offsets
532-
533588 out = general_blockwise (
534589 _assemble_index_chunk ,
535590 key_function ,
536591 x ,
537- offsets ,
538592 shapes = [shape ],
539593 dtypes = [x .dtype ],
540594 chunkss = [target_chunks ],
541- num_input_blocks = num_input_blocks ,
542595 target_chunks = target_chunks ,
543596 selection = selection ,
544597 in_shape = x .shape ,
@@ -622,14 +675,8 @@ def _assemble_index_chunk(
622675 selection = None ,
623676 in_shape = None ,
624677 in_chunksize = None ,
678+ block_id = None ,
625679):
626- # last array contains the offset for the block_id
627- offset = int (arrs [- 1 ]) # convert from 0-d array
628- numblocks = tuple (map (len , target_chunks ))
629- block_id = offset_to_block_id (offset , numblocks )
630-
631- arrs = arrs [:- 1 ] # drop offset array
632-
633680 # compute the selection on x required to get the relevant chunk for out_coords
634681 out_coords = block_id
635682 in_sel = _target_chunk_selection (target_chunks , out_coords , selection )
0 commit comments