@@ -5608,27 +5608,6 @@ def unpackbits(
56085608 return swapaxes (unpacked , axis , - 1 )
56095609
56105610
5611- @util .implements (np .take , skip_params = ['out' ],
5612- lax_description = """
5613- By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
5614- index semantics can be specified via the ``mode`` parameter (see below).
5615- """ ,
5616- extra_params = """
5617- mode : string, default="fill"
5618- Out-of-bounds indexing mode. The default mode="fill" returns invalid values
5619- (e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
5620- For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
5621- fill_value : optional
5622- The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
5623- otherwise. Defaults to NaN for inexact types, the largest negative value for
5624- signed types, the largest positive value for unsigned types, and True for booleans.
5625- unique_indices : bool, default=False
5626- If True, the implementation will assume that the indices are unique,
5627- which can result in more efficient execution on some backends.
5628- indices_are_sorted : bool, default=False
5629- If True, the implementation will assume that the indices are sorted in
5630- ascending order, which can lead to more efficient execution on some backends.
5631- """ )
56325611def take (
56335612 a : ArrayLike ,
56345613 indices : ArrayLike ,
@@ -5639,6 +5618,78 @@ def take(
56395618 indices_are_sorted : bool = False ,
56405619 fill_value : StaticScalar | None = None ,
56415620) -> Array :
5621+ """Take elements from an array.
5622+
5623+ JAX implementation of :func:`numpy.take`, implemented in terms of
5624+ :func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case
5625+ of out-of-bound indices; see the ``mode`` parameter below.
5626+
5627+ Args:
5628+ a: array from which to take values.
5629+ indices: N-dimensional array of integer indices of values to take from the array.
5630+ axis: the axis along which to take values. If not specified, the array will
5631+ be flattened before indexing is applied.
5632+ mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
5633+ ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices;
5634+ the ``fill_value`` argument gives control over this value. For more discussion
5635+ of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
5636+ fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'.
5637+ Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for
5638+ signed types, the largest positive value for unsigned types, and True for booleans.
5639+ unique_indices: If True, the implementation will assume that the indices are unique,
5640+ which can result in more efficient execution on some backends. If set to True and
5641+ indices are not unique, the output is undefined.
5642+ indices_are_sorted : If True, the implementation will assume that the indices are
5643+ sorted in ascending order, which can lead to more efficient execution on some
5644+ backends. If set to True and indices are not sorted, the output is undefined.
5645+
5646+ Returns:
5647+ Array of values extracted from ``a``.
5648+
5649+ See also:
5650+ - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
5651+ - :func:`jax.numpy.take_along_axis`: take values along an axis
5652+
5653+ Example:
5654+ >>> x = jnp.array([[1., 2., 3.],
5655+ ... [4., 5., 6.]])
5656+ >>> indices = jnp.array([2, 0])
5657+
5658+ Passing no axis results in indexing into the flattened array:
5659+
5660+ >>> jnp.take(x, indices)
5661+ Array([3., 1.], dtype=float32)
5662+ >>> x.ravel()[indices] # equivalent indexing syntax
5663+ Array([3., 1.], dtype=float32)
5664+
5665+ Passing an axis results ind applying the index to every subarray along the axis:
5666+
5667+ >>> jnp.take(x, indices, axis=1)
5668+ Array([[3., 1.],
5669+ [6., 4.]], dtype=float32)
5670+ >>> x[:, indices] # equivalent indexing syntax
5671+ Array([[3., 1.],
5672+ [6., 4.]], dtype=float32)
5673+
5674+ Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
5675+
5676+ >>> jnp.take(x, indices, axis=0)
5677+ Array([[nan, nan, nan],
5678+ [ 1., 2., 3.]], dtype=float32)
5679+ >>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax
5680+ Array([[nan, nan, nan],
5681+ [ 1., 2., 3.]], dtype=float32)
5682+
5683+ This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for
5684+ example, we can instead clip to the last valid value:
5685+
5686+ >>> jnp.take(x, indices, axis=0, mode='clip')
5687+ Array([[4., 5., 6.],
5688+ [1., 2., 3.]], dtype=float32)
5689+ >>> x.at[indices].get(mode='clip') # equivalent indexing syntax
5690+ Array([[4., 5., 6.],
5691+ [1., 2., 3.]], dtype=float32)
5692+ """
56425693 return _take (a , indices , None if axis is None else operator .index (axis ), out ,
56435694 mode , unique_indices = unique_indices , indices_are_sorted = indices_are_sorted ,
56445695 fill_value = fill_value )
@@ -5714,17 +5765,6 @@ def _normalize_index(index, axis_size):
57145765 return lax .select (index < 0 , lax .add (index , axis_size_val ), index )
57155766
57165767
5717- TAKE_ALONG_AXIS_DOC = """
5718- Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
5719- an optional ``mode`` parameter controlling how out-of-bounds indices should be
5720- handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
5721- See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
5722- indexing in JAX.
5723- """
5724-
5725-
5726- @util .implements (np .take_along_axis , update_doc = False ,
5727- lax_description = TAKE_ALONG_AXIS_DOC )
57285768@partial (jit , static_argnames = ('axis' , 'mode' , 'fill_value' ))
57295769def take_along_axis (
57305770 arr : ArrayLike ,
@@ -5733,6 +5773,77 @@ def take_along_axis(
57335773 mode : str | lax .GatherScatterMode | None = None ,
57345774 fill_value : StaticScalar | None = None ,
57355775) -> Array :
5776+ """Take elements from an array.
5777+
5778+ JAX implementation of :func:`numpy.take_along_axis`, implemented in
5779+ terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy
5780+ in the case of out-of-bound indices; see the ``mode`` parameter below.
5781+
5782+ Args:
5783+ a: array from which to take values.
5784+ indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional.
5785+ If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be
5786+ broadcast-compaible with ``indices`` along dimensions other than ``axis``.
5787+ axis: the axis along which to take values. If not specified, the array will
5788+ be flattened before indexing is applied.
5789+ mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
5790+ ``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices.
5791+ For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
5792+
5793+ Returns:
5794+ Array of values extracted from ``a``.
5795+
5796+ See also:
5797+ - :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
5798+ - :func:`jax.numpy.take`: take the same indices along every axis slice.
5799+
5800+ Examples:
5801+ >>> x = jnp.array([[1., 2., 3.],
5802+ ... [4., 5., 6.]])
5803+ >>> indices = jnp.array([[0, 2],
5804+ ... [1, 0]])
5805+ >>> jnp.take_along_axis(x, indices, axis=1)
5806+ Array([[1., 3.],
5807+ [5., 4.]], dtype=float32)
5808+ >>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax
5809+ Array([[1., 3.],
5810+ [5., 4.]], dtype=float32)
5811+
5812+ Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
5813+
5814+ >>> indices = jnp.array([[1, 0, 2]])
5815+ >>> jnp.take_along_axis(x, indices, axis=0)
5816+ Array([[ 4., 2., nan]], dtype=float32)
5817+ >>> x.at[indices, jnp.arange(3)].get(
5818+ ... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax
5819+ Array([[ 4., 2., nan]], dtype=float32)
5820+
5821+ ``take_along_axis`` is helpful for extracting values from multi-dimensional
5822+ argsorts and arg reductions. For, here we compute :func:`~jax.numpy.argsort`
5823+ indices along an axis, and use ``take_along_axis`` to construct the sorted
5824+ array:
5825+
5826+ >>> x = jnp.array([[5, 3, 4],
5827+ ... [2, 7, 6]])
5828+ >>> indices = jnp.argsort(x, axis=1)
5829+ >>> indices
5830+ Array([[1, 2, 0],
5831+ [0, 2, 1]], dtype=int32)
5832+ >>> jnp.take_along_axis(x, indices, axis=1)
5833+ Array([[3, 4, 5],
5834+ [2, 6, 7]], dtype=int32)
5835+
5836+ Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and
5837+ use ``take_along_axis`` to extract the minimum value:
5838+
5839+ >>> idx = jnp.argmin(x, axis=1, keepdims=True)
5840+ >>> idx
5841+ Array([[1],
5842+ [0]], dtype=int32)
5843+ >>> jnp.take_along_axis(x, idx, axis=1)
5844+ Array([[3],
5845+ [2]], dtype=int32)
5846+ """
57365847 util .check_arraylike ("take_along_axis" , arr , indices )
57375848 a = asarray (arr )
57385849 index_dtype = dtypes .dtype (indices )
0 commit comments