diff --git a/cubed/__init__.py b/cubed/__init__.py
index 27d0b40c..dadb2964 100644
--- a/cubed/__init__.py
+++ b/cubed/__init__.py
@@ -341,6 +341,10 @@
 
 __all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
 
+from .array_api.set_functions import isin
+
+__all__ += ["isin"]
+
 from .array_api.statistical_functions import (
     cumulative_prod,
     cumulative_sum,
diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py
index 29fa8b8b..773d47ea 100644
--- a/cubed/array_api/__init__.py
+++ b/cubed/array_api/__init__.py
@@ -272,6 +272,10 @@
 
 __all__ += ["argmax", "argmin", "count_nonzero", "searchsorted", "where"]
 
+from .set_functions import isin
+
+__all__ += ["isin"]
+
 from .statistical_functions import (
     cumulative_prod,
     cumulative_sum,
diff --git a/cubed/array_api/set_functions.py b/cubed/array_api/set_functions.py
new file mode 100644
index 00000000..431dd1fd
--- /dev/null
+++ b/cubed/array_api/set_functions.py
@@ -0,0 +1,53 @@
+from cubed.array_api.utility_functions import any as cubed_any
+from cubed.backend_array_api import namespace as nxp
+from cubed.core import blockwise
+
+
+def isin(x1, x2, /, *, invert=False):
+    """Test whether each element of ``x1`` is also present in ``x2``.
+
+    Parameters
+    ----------
+    x1 : array
+        Elements to test.
+    x2 : array
+        Values against which each element of ``x1`` is tested.
+    invert : bool, optional
+        If True, negate the result so that it is True where an element of
+        ``x1`` is *not* present in ``x2``. Default is False.
+
+    Returns
+    -------
+    array
+        Boolean array with the same shape as ``x1``.
+    """
+    # based on dask isin
+
+    # Give x2 its own axes (placed after those of x1) so that every block of
+    # x1 is compared against every block of x2.
+    x1_axes = tuple(range(x1.ndim))
+    x2_axes = tuple(i + x1.ndim for i in range(x2.ndim))
+    mapped = blockwise(
+        _isin,
+        x1_axes + x2_axes,
+        x1,
+        x1_axes,
+        x2,
+        x2_axes,
+        dtype=nxp.bool,
+        # _isin returns length-1 blocks along each of the x2 axes
+        adjust_chunks={axis: lambda _: 1 for axis in x2_axes},
+    )
+
+    # An element of x1 is in x2 if it is in any block of x2.
+    result = cubed_any(mapped, axis=x2_axes)
+    if invert:
+        result = ~result
+    return result
+
+
+def _isin(a1, a2):
+    """Block-level isin; the result gains one length-1 axis per ``a2`` dim."""
+    a1_flattened = nxp.reshape(a1, (-1,))
+    values = nxp.isin(a1_flattened, a2)
+    return nxp.reshape(values, a1.shape + (1,) * a2.ndim)
diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py
index c23b22df..dd8a260c 100644
--- a/cubed/tests/test_array_api.py
+++ b/cubed/tests/test_array_api.py
@@ -847,6 +847,36 @@ def test_where_scalars():
         xp.where(condition, 0, 1)
 
 
+# Set functions
+
+@pytest.mark.parametrize(("low", "high"), [(0, 10)])
+@pytest.mark.parametrize(
+    ("elements_shape", "elements_chunks"),
+    [((10,), (5,)), ((10,), (3,)), ((4, 5), (3, 2)), ((20, 20), (4, 5))],
+)
+@pytest.mark.parametrize(
+    ("test_shape", "test_chunks"),
+    [((10,), (5,)), ((10,), (3,)), ((4, 5), (3, 2)), ((20, 20), (4, 5))],
+)
+@pytest.mark.parametrize("invert", [True, False])
+def test_isin(
+    low, high, elements_shape, elements_chunks, test_shape, test_chunks, invert
+):
+    # based on dask test; seeded so that any failure is reproducible
+    rng = np.random.default_rng(42)
+
+    a1 = rng.integers(low, high, size=elements_shape)
+    c1 = cubed.from_array(a1, chunks=elements_chunks)
+
+    a2 = rng.integers(low, high, size=test_shape) - 5
+    c2 = cubed.from_array(a2, chunks=test_chunks)
+
+    r_a = np.isin(a1, a2, invert=invert)
+    r_c = xp.isin(c1, c2, invert=invert)
+
+    assert_array_equal(r_c, r_a)
+
+
 # Statistical functions
 
 
diff --git a/docs/api.rst b/docs/api.rst
index 6afacf4f..49a3135b 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -58,6 +58,7 @@ These are functions that have not (yet) been included in the Python Array API St
     :nosignatures:
     :toctree: generated/
 
+    isin
     nanmean
     nansum
     pad