Skip to content

Commit 9e80c30

Browse files
committed
Improve documentation for jnp.take & jnp.take_along_axis
1 parent 2b6bcb5 commit 9e80c30

File tree

1 file changed

+143
-32
lines changed

1 file changed

+143
-32
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 143 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5608,27 +5608,6 @@ def unpackbits(
56085608
return swapaxes(unpacked, axis, -1)
56095609

56105610

5611-
@util.implements(np.take, skip_params=['out'],
5612-
lax_description="""
5613-
By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound
5614-
index semantics can be specified via the ``mode`` parameter (see below).
5615-
""",
5616-
extra_params="""
5617-
mode : string, default="fill"
5618-
Out-of-bounds indexing mode. The default mode="fill" returns invalid values
5619-
(e.g. NaN) for out-of bounds indices (see also ``fill_value`` below).
5620-
For more discussion of mode options, see :attr:`jax.numpy.ndarray.at`.
5621-
fill_value : optional
5622-
The fill value to return for out-of-bounds slices when mode is 'fill'. Ignored
5623-
otherwise. Defaults to NaN for inexact types, the largest negative value for
5624-
signed types, the largest positive value for unsigned types, and True for booleans.
5625-
unique_indices : bool, default=False
5626-
If True, the implementation will assume that the indices are unique,
5627-
which can result in more efficient execution on some backends.
5628-
indices_are_sorted : bool, default=False
5629-
If True, the implementation will assume that the indices are sorted in
5630-
ascending order, which can lead to more efficient execution on some backends.
5631-
""")
56325611
def take(
56335612
a: ArrayLike,
56345613
indices: ArrayLike,
@@ -5639,6 +5618,78 @@ def take(
56395618
indices_are_sorted: bool = False,
56405619
fill_value: StaticScalar | None = None,
56415620
) -> Array:
5621+
"""Take elements from an array.
5622+
5623+
JAX implementation of :func:`numpy.take`, implemented in terms of
5624+
:func:`jax.lax.gather`. JAX's behavior differs from NumPy in the case
5625+
of out-of-bound indices; see the ``mode`` parameter below.
5626+
5627+
Args:
5628+
a: array from which to take values.
5629+
indices: N-dimensional array of integer indices of values to take from the array.
5630+
axis: the axis along which to take values. If not specified, the array will
5631+
be flattened before indexing is applied.
5632+
mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
5633+
``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices;
5634+
the ``fill_value`` argument gives control over this value. For more discussion
5635+
of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
5636+
fill_value: The fill value to return for out-of-bounds slices when mode is 'fill'.
5637+
Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for
5638+
signed types, the largest positive value for unsigned types, and True for booleans.
5639+
unique_indices: If True, the implementation will assume that the indices are unique,
5640+
which can result in more efficient execution on some backends. If set to True and
5641+
indices are not unique, the output is undefined.
5642+
indices_are_sorted : If True, the implementation will assume that the indices are
5643+
sorted in ascending order, which can lead to more efficient execution on some
5644+
backends. If set to True and indices are not sorted, the output is undefined.
5645+
5646+
Returns:
5647+
Array of values extracted from ``a``.
5648+
5649+
See also:
5650+
- :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
5651+
- :func:`jax.numpy.take_along_axis`: take values along an axis
5652+
5653+
Example:
5654+
>>> x = jnp.array([[1., 2., 3.],
5655+
... [4., 5., 6.]])
5656+
>>> indices = jnp.array([2, 0])
5657+
5658+
Passing no axis results in indexing into the flattened array:
5659+
5660+
>>> jnp.take(x, indices)
5661+
Array([3., 1.], dtype=float32)
5662+
>>> x.ravel()[indices] # equivalent indexing syntax
5663+
Array([3., 1.], dtype=float32)
5664+
5665+
Passing an axis results ind applying the index to every subarray along the axis:
5666+
5667+
>>> jnp.take(x, indices, axis=1)
5668+
Array([[3., 1.],
5669+
[6., 4.]], dtype=float32)
5670+
>>> x[:, indices] # equivalent indexing syntax
5671+
Array([[3., 1.],
5672+
[6., 4.]], dtype=float32)
5673+
5674+
Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
5675+
5676+
>>> jnp.take(x, indices, axis=0)
5677+
Array([[nan, nan, nan],
5678+
[ 1., 2., 3.]], dtype=float32)
5679+
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan) # equivalent indexing syntax
5680+
Array([[nan, nan, nan],
5681+
[ 1., 2., 3.]], dtype=float32)
5682+
5683+
This default out-of-bound behavior can be adjusted using the ``mode`` parameter, for
5684+
example, we can instead clip to the last valid value:
5685+
5686+
>>> jnp.take(x, indices, axis=0, mode='clip')
5687+
Array([[4., 5., 6.],
5688+
[1., 2., 3.]], dtype=float32)
5689+
>>> x.at[indices].get(mode='clip') # equivalent indexing syntax
5690+
Array([[4., 5., 6.],
5691+
[1., 2., 3.]], dtype=float32)
5692+
"""
56425693
return _take(a, indices, None if axis is None else operator.index(axis), out,
56435694
mode, unique_indices=unique_indices, indices_are_sorted=indices_are_sorted,
56445695
fill_value=fill_value)
@@ -5714,17 +5765,6 @@ def _normalize_index(index, axis_size):
57145765
return lax.select(index < 0, lax.add(index, axis_size_val), index)
57155766

57165767

5717-
TAKE_ALONG_AXIS_DOC = """
5718-
Unlike :func:`numpy.take_along_axis`, :func:`jax.numpy.take_along_axis` takes
5719-
an optional ``mode`` parameter controlling how out-of-bounds indices should be
5720-
handled. By default, out-of-bounds indices yield invalid values (e.g., ``NaN``).
5721-
See :attr:`jax.numpy.ndarray.at` for further discussion of out-of-bounds
5722-
indexing in JAX.
5723-
"""
5724-
5725-
5726-
@util.implements(np.take_along_axis, update_doc=False,
5727-
lax_description=TAKE_ALONG_AXIS_DOC)
57285768
@partial(jit, static_argnames=('axis', 'mode', 'fill_value'))
57295769
def take_along_axis(
57305770
arr: ArrayLike,
@@ -5733,6 +5773,77 @@ def take_along_axis(
57335773
mode: str | lax.GatherScatterMode | None = None,
57345774
fill_value: StaticScalar | None = None,
57355775
) -> Array:
5776+
"""Take elements from an array.
5777+
5778+
JAX implementation of :func:`numpy.take_along_axis`, implemented in
5779+
terms of :func:`jax.lax.gather`. JAX's behavior differs from NumPy
5780+
in the case of out-of-bound indices; see the ``mode`` parameter below.
5781+
5782+
Args:
5783+
a: array from which to take values.
5784+
indices: array of integer indices. If ``axis`` is ``None``, must be one-dimensional.
5785+
If ``axis`` is not None, must have ``a.ndim == indices.ndim``, and ``a`` must be
5786+
broadcast-compaible with ``indices`` along dimensions other than ``axis``.
5787+
axis: the axis along which to take values. If not specified, the array will
5788+
be flattened before indexing is applied.
5789+
mode: Out-of-bounds indexing mode, either ``"fill"`` or ``"clip"``. The default
5790+
``mode="fill"`` returns invalid values (e.g. NaN) for out-of bounds indices.
5791+
For more discussion of ``mode`` options, see :attr:`jax.numpy.ndarray.at`.
5792+
5793+
Returns:
5794+
Array of values extracted from ``a``.
5795+
5796+
See also:
5797+
- :attr:`jax.numpy.ndarray.at`: take values via indexing syntax.
5798+
- :func:`jax.numpy.take`: take the same indices along every axis slice.
5799+
5800+
Examples:
5801+
>>> x = jnp.array([[1., 2., 3.],
5802+
... [4., 5., 6.]])
5803+
>>> indices = jnp.array([[0, 2],
5804+
... [1, 0]])
5805+
>>> jnp.take_along_axis(x, indices, axis=1)
5806+
Array([[1., 3.],
5807+
[5., 4.]], dtype=float32)
5808+
>>> x[jnp.arange(2)[:, None], indices] # equivalent via indexing syntax
5809+
Array([[1., 3.],
5810+
[5., 4.]], dtype=float32)
5811+
5812+
Out-of-bound indices fill with invalid values. For float inputs, this is `NaN`:
5813+
5814+
>>> indices = jnp.array([[1, 0, 2]])
5815+
>>> jnp.take_along_axis(x, indices, axis=0)
5816+
Array([[ 4., 2., nan]], dtype=float32)
5817+
>>> x.at[indices, jnp.arange(3)].get(
5818+
... mode='fill', fill_value=jnp.nan) # equivalent via indexing syntax
5819+
Array([[ 4., 2., nan]], dtype=float32)
5820+
5821+
``take_along_axis`` is helpful for extracting values from multi-dimensional
5822+
argsorts and arg reductions. For, here we compute :func:`~jax.numpy.argsort`
5823+
indices along an axis, and use ``take_along_axis`` to construct the sorted
5824+
array:
5825+
5826+
>>> x = jnp.array([[5, 3, 4],
5827+
... [2, 7, 6]])
5828+
>>> indices = jnp.argsort(x, axis=1)
5829+
>>> indices
5830+
Array([[1, 2, 0],
5831+
[0, 2, 1]], dtype=int32)
5832+
>>> jnp.take_along_axis(x, indices, axis=1)
5833+
Array([[3, 4, 5],
5834+
[2, 6, 7]], dtype=int32)
5835+
5836+
Similarly, we can use :func:`~jax.numpy.argmin` with ``keepdims=True`` and
5837+
use ``take_along_axis`` to extract the minimum value:
5838+
5839+
>>> idx = jnp.argmin(x, axis=1, keepdims=True)
5840+
>>> idx
5841+
Array([[1],
5842+
[0]], dtype=int32)
5843+
>>> jnp.take_along_axis(x, idx, axis=1)
5844+
Array([[3],
5845+
[2]], dtype=int32)
5846+
"""
57365847
util.check_arraylike("take_along_axis", arr, indices)
57375848
a = asarray(arr)
57385849
index_dtype = dtypes.dtype(indices)

0 commit comments

Comments
 (0)