@@ -225,6 +225,8 @@ class ReduceOp(Enum):
225225 MIN = np .minimum
226226 ANY = np .any
227227 ALL = np .all
228+ ARGMAX = np .argmax
229+ ARGMIN = np .argmin
228230
229231
230232class LazyArrayEnum (Enum ):
@@ -1704,6 +1706,8 @@ def infer_reduction_dtype(dtype, operation):
17041706 return dtype
17051707 elif operation in {ReduceOp .ANY , ReduceOp .ALL }:
17061708 return np .bool_
1709+ elif operation in {ReduceOp .ARGMAX , ReduceOp .ARGMIN }:
1710+ return np .int64
17071711 else :
17081712 raise ValueError (f"Unsupported operation: { operation } " )
17091713
@@ -1758,6 +1762,7 @@ def reduce_slices( # noqa: C901
17581762 The resulting output array.
17591763 """
17601764 out = kwargs .pop ("_output" , None )
1765+ res_out_ = None # temporary required to store max/min for argmax/argmin
17611766 ne_args : dict = kwargs .pop ("_ne_args" , {})
17621767 if ne_args is None :
17631768 ne_args = {}
@@ -1787,12 +1792,15 @@ def reduce_slices( # noqa: C901
17871792 # after slicing, we reduce to calculate shape of output
17881793 if axis is None :
17891794 axis = tuple (range (len (shape_slice )))
1790- elif not isinstance (axis , tuple ):
1795+ elif np . isscalar (axis ):
17911796 axis = (axis ,)
1792- axis = np . array ([ a if a >= 0 else a + len (shape_slice ) for a in axis ] )
1797+ axis = tuple ( a if a >= 0 else a + len (shape_slice ) for a in axis )
17931798 if np .any (mask_slice ):
1794- axis = tuple (axis + np .cumsum (mask_slice )[axis ]) # axis now refers to new shape with dummy dims
1795- reduce_args ["axis" ] = axis
1799+ add_idx = np .cumsum (mask_slice )
1800+ axis = tuple (a + add_idx [a ] for a in axis ) # axis now refers to new shape with dummy dims
1801+ if reduce_args ["axis" ] is not None :
1802+ # conserve as integer if was not tuple originally
1803+ reduce_args ["axis" ] = axis [0 ] if np .isscalar (reduce_args ["axis" ]) else axis
17961804 if keepdims :
17971805 reduced_shape = tuple (1 if i in axis else s for i , s in enumerate (shape_slice ))
17981806 else :
@@ -1868,15 +1876,16 @@ def reduce_slices( # noqa: C901
18681876 cslice = step_handler (cslice , _slice )
18691877 chunks_ = tuple (s .stop - s .start for s in cslice )
18701878 unit_steps = np .all ([s .step == 1 for s in cslice ])
1879+ # Starts for slice
1880+ starts = [s .start if s .start is not None else 0 for s in cslice ]
18711881 if _slice == () and fast_path and unit_steps :
18721882 # Fast path
18731883 full_chunk = chunks_ == chunks
18741884 fill_chunk_operands (
18751885 operands , cslice , chunks_ , full_chunk , aligned , nchunk , iter_disk , chunk_operands , reduc = True
18761886 )
18771887 else :
1878- # Get the starts and stops for the slice
1879- starts = [s .start if s .start is not None else 0 for s in cslice ]
1888+ # Get the stops for the slice
18801889 stops = [s .stop if s .stop is not None else sh for s , sh in zip (cslice , chunks_ , strict = True )]
18811890 # Get the slice of each operand
18821891 for key , value in operands .items ():
@@ -1952,36 +1961,81 @@ def reduce_slices( # noqa: C901
19521961 result = np .any (result , ** reduce_args )
19531962 elif reduce_op == ReduceOp .ALL :
19541963 result = np .all (result , ** reduce_args )
1964+ elif reduce_op == ReduceOp .ARGMAX or reduce_op == ReduceOp .ARGMIN :
1965+ # offset for start of slice
1966+ slice_ref = (
1967+ starts
1968+ if _slice == ()
1969+ else [
1970+ (s - sl .start - np .sign (sl .step )) // sl .step + 1
1971+ for s , sl in zip (starts , _slice , strict = True )
1972+ ]
1973+ )
1974+ result_idx = (
1975+ np .argmin (result , ** reduce_args )
1976+ if reduce_op == ReduceOp .ARGMIN
1977+ else np .argmax (result , ** reduce_args )
1978+ )
1979+ if reduce_args ["axis" ] is None : # indexing into flattened array
1980+ result = result [np .unravel_index (result_idx , shape = result .shape )]
1981+ idx_within_cslice = np .unravel_index (result_idx , shape = chunks_ )
1982+ result_idx = np .ravel_multi_index (
1983+ tuple (o + i for o , i in zip (slice_ref , idx_within_cslice , strict = True )), shape_slice
1984+ )
1985+ else : # axis is an integer
1986+ result = np .take_along_axis (
1987+ result ,
1988+ np .expand_dims (result_idx , axis = reduce_args ["axis" ]) if not keepdims else result_idx ,
1989+ axis = reduce_args ["axis" ],
1990+ )
1991+ result = result if keepdims else result .squeeze (axis = reduce_args ["axis" ])
1992+ result_idx += slice_ref [reduce_args ["axis" ]]
19551993 else :
19561994 result = reduce_op .value .reduce (result , ** reduce_args )
19571995
19581996 if not out_init :
1959- if out is None :
1960- out = convert_none_out (result .dtype , reduce_op , reduced_shape )
1997+ out_ , res_out_ = convert_none_out (result .dtype , reduce_op , reduced_shape )
1998+ if out is not None :
1999+ out [:] = out_
2000+ del out_
19612001 else :
1962- out2 = convert_none_out (result .dtype , reduce_op , reduced_shape )
1963- out [:] = out2
1964- del out2
2002+ out = out_
19652003 out_init = True
19662004
19672005 # Update the output array with the result
19682006 if reduce_op == ReduceOp .ANY :
19692007 out [reduced_slice ] += result
19702008 elif reduce_op == ReduceOp .ALL :
19712009 out [reduced_slice ] *= result
2010+ elif res_out_ is not None : # i.e. ReduceOp.ARGMAX or ReduceOp.ARGMIN
2011+ # need lowest index for which optimum attained
2012+ cond = (res_out_ [reduced_slice ] == result ) & (result_idx < out [reduced_slice ])
2013+ if reduce_op == ReduceOp .ARGMAX :
2014+ cond |= res_out_ [reduced_slice ] < result
2015+ else : # ARGMIN
2016+ cond |= res_out_ [reduced_slice ] > result
2017+ if reduced_slice == ():
2018+ out = np .where (cond , result_idx , out [reduced_slice ])
2019+ res_out_ = np .where (cond , result , res_out_ [reduced_slice ])
2020+ else :
2021+ out [reduced_slice ] = np .where (cond , result_idx , out [reduced_slice ])
2022+ res_out_ [reduced_slice ] = np .where (cond , result , res_out_ [reduced_slice ])
19722023 else :
19732024 if reduced_slice == ():
19742025 out = reduce_op .value (out , result )
19752026 else :
19762027 out [reduced_slice ] = reduce_op .value (out [reduced_slice ], result )
19772028
2029+ # No longer need res_out_
2030+ del res_out_
2031+
19782032 if out is None :
1979- if reduce_op in (ReduceOp .MIN , ReduceOp .MAX ):
1980- raise ValueError ("zero-size array in min/max reduction operation is not supported" )
2033+ if reduce_op in (ReduceOp .MIN , ReduceOp .MAX , ReduceOp . ARGMIN , ReduceOp . ARGMAX ):
2034+ raise ValueError ("zero-size array in (arg-) min/max reduction operation is not supported" )
19812035 if dtype is None :
19822036 # We have no hint here, so choose a default dtype
19832037 dtype = np .float64
1984- out = convert_none_out (dtype , reduce_op , reduced_shape )
2038+ out , _ = convert_none_out (dtype , reduce_op , reduced_shape )
19852039
19862040 final_mask = tuple (np .where (mask_slice )[0 ])
19872041 if np .any (mask_slice ): # remove dummy dims
@@ -2013,7 +2067,19 @@ def convert_none_out(dtype, reduce_op, reduced_shape):
20132067 out = np .zeros (reduced_shape , dtype = np .bool_ )
20142068 elif reduce_op == ReduceOp .ALL :
20152069 out = np .ones (reduced_shape , dtype = np .bool_ )
2016- return out
2070+ elif reduce_op == ReduceOp .ARGMIN :
2071+ if np .issubdtype (dtype , np .integer ):
2072+ res_out_ = np .iinfo (dtype ).max * np .ones (reduced_shape , dtype = dtype )
2073+ else :
2074+ res_out_ = np .inf * np .ones (reduced_shape , dtype = dtype )
2075+ out = (np .zeros (reduced_shape , dtype = blosc2 .DEFAULT_INDEX ), res_out_ )
2076+ elif reduce_op == ReduceOp .ARGMAX :
2077+ if np .issubdtype (dtype , np .integer ):
2078+ res_out_ = np .iinfo (dtype ).min * np .ones (reduced_shape , dtype = dtype )
2079+ else :
2080+ res_out_ = - np .inf * np .ones (reduced_shape , dtype = dtype )
2081+ out = (np .zeros (reduced_shape , dtype = blosc2 .DEFAULT_INDEX ), res_out_ )
2082+ return out if isinstance (out , tuple ) else (out , None )
20172083
20182084
20192085def chunked_eval ( # noqa: C901
@@ -2707,6 +2773,22 @@ def all(self, axis=None, keepdims=False, **kwargs):
27072773 }
27082774 return self .compute (_reduce_args = reduce_args , ** kwargs )
27092775
2776+ def argmax (self , axis = None , keepdims = False , ** kwargs ):
2777+ reduce_args = {
2778+ "op" : ReduceOp .ARGMAX ,
2779+ "axis" : axis ,
2780+ "keepdims" : keepdims ,
2781+ }
2782+ return self .compute (_reduce_args = reduce_args , ** kwargs )
2783+
2784+ def argmin (self , axis = None , keepdims = False , ** kwargs ):
2785+ reduce_args = {
2786+ "op" : ReduceOp .ARGMIN ,
2787+ "axis" : axis ,
2788+ "keepdims" : keepdims ,
2789+ }
2790+ return self .compute (_reduce_args = reduce_args , ** kwargs )
2791+
27102792 def _eval_constructor (self , expression , constructor , operands ):
27112793 """Evaluate a constructor function inside a string expression."""
27122794
0 commit comments