@@ -5517,17 +5517,57 @@ def argsort(
55175517 return lax .rev (indices , dimensions = [dimension ]) if descending else indices
55185518
55195519
5520- @util .implements (np .partition , lax_description = """
5521- The JAX version requires the ``kth`` argument to be a static integer rather than
5522- a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5523- you're only accessing the top or bottom k values of the output, it may be more
5524- efficient to call :func:`jax.lax.top_k` directly.
5525-
5526- The JAX version differs from the NumPy version in the treatment of NaN entries;
5527- NaNs which have the negative bit set are sorted to the beginning of the array.
5528- """ )
55295520@partial (jit , static_argnames = ['kth' , 'axis' ])
55305521def partition (a : ArrayLike , kth : int , axis : int = - 1 ) -> Array :
5522+ """Returns a partially-sorted copy of an array.
5523+
5524+ JAX implementation of :func:`numpy.partition`. The JAX version differs from
5525+ NumPy in the treatment of NaN entries: NaNs which have the negative bit set
5526+ are sorted to the beginning of the array.
5527+
5528+ Args:
5529+ a: array to be partitioned.
5530+ kth: static integer index about which to partition the array.
5531+ axis: static integer axis along which to partition the array; default is -1.
5532+
5533+ Returns:
5534+ A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries
5535+ before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries
5536+ after ``kth`` are indices of values larger than ``take(a, kth, axis)``
5537+
5538+ Note:
5539+ The JAX version requires the ``kth`` argument to be a static integer rather than
5540+ a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5541+ you're only accessing the top or bottom k values of the output, it may be more
5542+ efficient to call :func:`jax.lax.top_k` directly.
5543+
5544+ See Also:
5545+ - :func:`jax.numpy.sort`: full sort
5546+ - :func:`jax.numpy.argpartition`: indirect partial sort
5547+ - :func:`jax.lax.top_k`: directly find the top k entries
5548+ - :func:`jax.lax.approx_max_k`: compute the approximate top k entries
5549+ - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
5550+
5551+ Examples:
5552+ >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
5553+ >>> kth = 4
5554+ >>> x_partitioned = jnp.partition(x, kth)
5555+ >>> x_partitioned
5556+ Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
5557+
5558+ The result is a partially-sorted copy of the input. All values before ``kth``
5559+ are of smaller than the pivot value, and all values after ``kth`` are larger
5560+ than the pivot value:
5561+
5562+ >>> smallest_values = x_partitioned[:kth]
5563+ >>> pivot_value = x_partitioned[kth]
5564+ >>> largest_values = x_partitioned[kth + 1:]
5565+ >>> print(smallest_values, pivot_value, largest_values)
5566+ [1 2 3 3] 4 [9 8 7 6 5]
5567+
5568+ Notice that among ``smallest_values`` and ``largest_values``, the returned
5569+ order is arbitrary and implementation-dependent.
5570+ """
55315571 # TODO(jakevdp): handle NaN values like numpy.
55325572 util .check_arraylike ("partition" , a )
55335573 arr = asarray (a )
@@ -5543,17 +5583,58 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
55435583 return swapaxes (out , - 1 , axis )
55445584
55455585
5546- @util .implements (np .argpartition , lax_description = """
5547- The JAX version requires the ``kth`` argument to be a static integer rather than
5548- a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5549- you're only accessing the top or bottom k values of the output, it may be more
5550- efficient to call :func:`jax.lax.top_k` directly.
5551-
5552- The JAX version differs from the NumPy version in the treatment of NaN entries;
5553- NaNs which have the negative bit set are sorted to the beginning of the array.
5554- """ )
55555586@partial (jit , static_argnames = ['kth' , 'axis' ])
55565587def argpartition (a : ArrayLike , kth : int , axis : int = - 1 ) -> Array :
5588+ """Returns indices that partially sort an array.
5589+
5590+ JAX implementation of :func:`numpy.argpartition`. The JAX version differs from
5591+ NumPy in the treatment of NaN entries: NaNs which have the negative bit set are
5592+ sorted to the beginning of the array.
5593+
5594+ Args:
5595+ a: array to be partitioned.
5596+ kth: static integer index about which to partition the array.
5597+ axis: static integer axis along which to partition the array; default is -1.
5598+
5599+ Returns:
5600+ Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries
5601+ before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and
5602+ entries after ``kth`` are indices of values larger than ``take(a, kth, axis)``
5603+
5604+ Note:
5605+ The JAX version requires the ``kth`` argument to be a static integer rather than
5606+ a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5607+ you're only accessing the top or bottom k values of the output, it may be more
5608+ efficient to call :func:`jax.lax.top_k` directly.
5609+
5610+ See Also:
5611+ - :func:`jax.numpy.partition`: direct partial sort
5612+ - :func:`jax.numpy.argsort`: full indirect sort
5613+ - :func:`jax.lax.top_k`: directly find the top k entries
5614+ - :func:`jax.lax.approx_max_k`: compute the approximate top k entries
5615+ - :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
5616+
5617+ Examples:
5618+ >>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
5619+ >>> kth = 4
5620+ >>> idx = jnp.argpartition(x, kth)
5621+ >>> idx
5622+ Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
5623+
5624+ The result is a sequence of indices that partially sort the input. All indices
5625+ before ``kth`` are of values smaller than the pivot value, and all indices
5626+ after ``kth`` are of values larger than the pivot value:
5627+
5628+ >>> x_partitioned = x[idx]
5629+ >>> smallest_values = x_partitioned[:kth]
5630+ >>> pivot_value = x_partitioned[kth]
5631+ >>> largest_values = x_partitioned[kth + 1:]
5632+ >>> print(smallest_values, pivot_value, largest_values)
5633+ [1 2 3 3] 4 [6 8 9 7 5]
5634+
5635+ Notice that among ``smallest_values`` and ``largest_values``, the returned
5636+ order is arbitrary and implementation-dependent.
5637+ """
55575638 # TODO(jakevdp): handle NaN values like numpy.
55585639 util .check_arraylike ("partition" , a )
55595640 arr = asarray (a )
0 commit comments