@@ -1043,9 +1043,62 @@ def ravel(a: ArrayLike, order: str = "C") -> Array:
10431043 return reshape (a , (size (a ),), order )
10441044
10451045
1046- @util .implements (np .ravel_multi_index )
10471046def ravel_multi_index (multi_index : Sequence [ArrayLike ], dims : Sequence [int ],
10481047 mode : str = 'raise' , order : str = 'C' ) -> Array :
1048+ """Convert multi-dimensional indices into flat indices.
1049+
1050+ JAX implementation of :func:`numpy.ravel_multi_index`
1051+
1052+ Args:
1053+ multi_index: sequence of integer arrays containing indices in each dimension.
1054+ dims: sequence of integer sizes; must have ``len(dims) == len(multi_index)``
1055+ mode: how to handle out-of bound indices. Options are
1056+
1057+ - ``"raise"`` (default): raise a ValueError. This mode is incompatible
1058+ with :func:`~jax.jit` or other JAX transformations.
1059+ - ``"clip"``: clip out-of-bound indices to valid range.
1060+ - ``"wrap"``: wrap out-of-bound indices to valid range.
1061+
1062+ order: ``"C"`` (default) or ``"F"``, specify whether to assume C-style
1063+ row-major order or Fortran-style column-major order.
1064+
1065+ Returns:
1066+ array of flattened indices
1067+
1068+ See also:
1069+ :func:`jax.numpy.unravel_index`: inverse of this function.
1070+
1071+ Example:
1072+ Define a 2-dimensional array and a sequence of indices of even values:
1073+
1074+ >>> x = jnp.array([[2., 3., 4.],
1075+ ... [5., 6., 7.]])
1076+ >>> indices = jnp.where(x % 2 == 0)
1077+ >>> indices
1078+ (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
1079+ >>> x[indices]
1080+ Array([2., 4., 6.], dtype=float32)
1081+
1082+ Compute the flattened indices:
1083+
1084+ >>> indices_flat = jnp.ravel_multi_index(indices, x.shape)
1085+ >>> indices_flat
1086+ Array([0, 2, 4], dtype=int32)
1087+
1088+ These flattened indices can be used to extract the same values from the
1089+ flattened ``x`` array:
1090+
1091+ >>> x_flat = x.ravel()
1092+ >>> x_flat
1093+ Array([2., 3., 4., 5., 6., 7.], dtype=float32)
1094+ >>> x_flat[indices_flat]
1095+ Array([2., 4., 6.], dtype=float32)
1096+
1097+ The original indices can be recovered with :func:`~jax.numpy.unravel_index`:
1098+
1099+ >>> jnp.unravel_index(indices_flat, x.shape)
1100+ (Array([0, 0, 1], dtype=int32), Array([0, 2, 1], dtype=int32))
1101+ """
10491102 assert len (multi_index ) == len (dims ), f"len(multi_index)={ len (multi_index )} != len(dims)={ len (dims )} "
10501103 dims = tuple (core .concrete_or_error (operator .index , d , "in `dims` argument of ravel_multi_index()." ) for d in dims )
10511104 util .check_arraylike ("ravel_multi_index" , * multi_index )
@@ -1081,13 +1134,48 @@ def ravel_multi_index(multi_index: Sequence[ArrayLike], dims: Sequence[int],
10811134 return result
10821135
10831136
1084- _UNRAVEL_INDEX_DOC = """\
1085- Unlike numpy's implementation of unravel_index, negative indices are accepted
1086- and out-of-bounds indices are clipped into the valid range.
1087- """
1088-
1089- @util .implements (np .unravel_index , lax_description = _UNRAVEL_INDEX_DOC )
10901137def unravel_index (indices : ArrayLike , shape : Shape ) -> tuple [Array , ...]:
1138+ """Convert flat indices into multi-dimensional indices.
1139+
1140+ JAX implementation of :func:`numpy.unravel_index`. The JAX version differs in
1141+ its treatment of out-of-bound indices: unlike NumPy, negative indices are
1142+ supported, and out-of-bound indices are clipped to the nearest valid value.
1143+
1144+ Args:
1145+ indices: integer array of flat indices
1146+ shape: shape of multidimensional array to index into
1147+
1148+ Returns:
1149+ Tuple of unraveled indices
1150+
1151+ See also:
1152+ :func:`jax.numpy.ravel_multi_index`: Inverse of this function.
1153+
1154+ Examples:
1155+ Start with a 1D array values and indices:
1156+
1157+ >>> x = jnp.array([2., 3., 4., 5., 6., 7.])
1158+ >>> indices = jnp.array([1, 3, 5])
1159+ >>> print(x[indices])
1160+ [3. 5. 7.]
1161+
1162+ Now if ``x`` is reshaped, ``unravel_indices`` can be used to convert
1163+ the flat indices into a tuple of indices that access the same entries:
1164+
1165+ >>> shape = (2, 3)
1166+ >>> x_2D = x.reshape(shape)
1167+ >>> indices_2D = jnp.unravel_index(indices, shape)
1168+ >>> indices_2D
1169+ (Array([0, 1, 1], dtype=int32), Array([1, 0, 2], dtype=int32))
1170+ >>> print(x_2D[indices_2D])
1171+ [3. 5. 7.]
1172+
1173+ The inverse function, ``ravel_multi_index``, can be used to obtain the
1174+ original indices:
1175+
1176+ >>> jnp.ravel_multi_index(indices_2D, shape)
1177+ Array([1, 3, 5], dtype=int32)
1178+ """
10911179 util .check_arraylike ("unravel_index" , indices )
10921180 indices_arr = asarray (indices )
10931181 # Note: we do not convert shape to an array, because it may be passed as a
0 commit comments