Skip to content

Commit 6803d77

Browse files
committed
Improve docstrings for unravel_index & ravel_multi_index
1 parent de28ee6 commit 6803d77

File tree

1 file changed

+95
-7
lines changed

1 file changed

+95
-7
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,9 +1043,62 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
10431043
return reshape(a, (size(a),), order)
10441044

10451045

1046-
@util.implements(np.ravel_multi_index)
10471046
def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
10481047
mode: str = 'raise', order: str = 'C') -> Array:
1048+
"""Convert multi-dimensional indices into flat indices.
1049+
1050+
JAX implementation of :func:`numpy.ravel_multi_index`
1051+
1052+
Args:
1053+
multi_index: sequence of integer arrays containing indices in each dimension.
1054+
dims: sequence of integer sizes; must have ``len(dims) == len(multi_index)``
1055+
mode: how to handle out-of bound indices. Options are
1056+
1057+
- ``"raise"`` (default): raise a ValueError. This mode is incompatible
1058+
with :func:`~jax.jit` or other JAX transformations.
1059+
- ``"clip"``: clip out-of-bound indices to valid range.
1060+
- ``"wrap"``: wrap out-of-bound indices to valid range.
1061+
1062+
order: ``"C"`` (default) or ``"F"``, specify whether to assume C-style
1063+
row-major order or Fortran-style column-major order.
1064+
1065+
Returns:
1066+
array of flattened indices
1067+
1068+
See also:
1069+
:func:`jax.numpy.unravel_index`: inverse of this function.
1070+
1071+
Example:
1072+
Define a 2-dimensional array and a sequence of indices of even values:
1073+
1074+
>>> x = jnp.array([[2., 3., 4.],
1075+
... [5., 6., 7.]])
1076+
>>> indices = jnp.where(x % 2 == 0)
1077+
>>> indices
1078+
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
1079+
>>> x[indices]
1080+
Array([2., 4., 6.], dtype=float32)
1081+
1082+
Compute the flattened indices:
1083+
1084+
>>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
1085+
>>> indices_flat
1086+
Array([0, 2, 4], dtype=int32)
1087+
1088+
These flattened indices can be used to extract the same values from the
1089+
flattened ``x`` array:
1090+
1091+
>>> x_flat = x.ravel()
1092+
>>> x_flat
1093+
Array([2., 3., 4., 5., 6., 7.], dtype=float32)
1094+
>>> x_flat[indices_flat]
1095+
Array([2., 4., 6.], dtype=float32)
1096+
1097+
The original indices can be recovered with :func:`~jax.numpy.unravel_index`:
1098+
1099+
>>> jnp.unravel_index(indices_flat, x.shape)
1100+
(Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
1101+
"""
10491102
assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
10501103
dims = tuple(core.concrete_or_error(operator.index, d, "in `dims` argument of ravel_multi_index().") for d in dims)
10511104
util.check_arraylike("ravel_multi_index", *multi_index)
@@ -1081,13 +1134,48 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
10811134
return result
10821135

10831136

1084-
_UNRAVEL_INDEX_DOC = """\
1085-
Unlike numpy's implementation of unravel_index, negative indices are accepted
1086-
and out-of-bounds indices are clipped into the valid range.
1087-
"""
1088-
1089-
@util.implements(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
10901137
def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]:
1138+
"""Convert flat indices into multi-dimensional indices.
1139+
1140+
JAX implementation of :func:`numpy.unravel_index`. The JAX version differs in
1141+
its treatment of out-of-bound indices: unlike NumPy, negative indices are
1142+
supported, and out-of-bound indices are clipped to the nearest valid value.
1143+
1144+
Args:
1145+
indices: integer array of flat indices
1146+
shape: shape of multidimensional array to index into
1147+
1148+
Returns:
1149+
Tuple of unraveled indices
1150+
1151+
See also:
1152+
:func:`jax.numpy.ravel_multi_index`: Inverse of this function.
1153+
1154+
Examples:
1155+
Start with a 1D array values and indices:
1156+
1157+
>>> x = jnp.array([2., 3., 4., 5., 6., 7.])
1158+
>>> indices = jnp.array([1, 3, 5])
1159+
>>> print(x[indices])
1160+
[3. 5. 7.]
1161+
1162+
Now if ``x`` is reshaped, ``unravel_indices`` can be used to convert
1163+
the flat indices into a tuple of indices that access the same entries:
1164+
1165+
>>> shape = (2, 3)
1166+
>>> x_2D = x.reshape(shape)
1167+
>>> indices_2D = jnp.unravel_index(indices, shape)
1168+
>>> indices_2D
1169+
(Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
1170+
>>> print(x_2D[indices_2D])
1171+
[3. 5. 7.]
1172+
1173+
The inverse function, ``ravel_multi_index``, can be used to obtain the
1174+
original indices:
1175+
1176+
>>> jnp.ravel_multi_index(indices_2D, shape)
1177+
Array([1, 3, 5], dtype=int32)
1178+
"""
10911179
util.check_arraylike("unravel_index", indices)
10921180
indices_arr = asarray(indices)
10931181
# Note: we do not convert shape to an array, because it may be passed as a

0 commit comments

Comments
 (0)