Skip to content

Commit 0a5b2e1

Browse files
author
jax authors
committed
Merge pull request #21476 from jakevdp:partition-doc
PiperOrigin-RevId: 638271378
2 parents a07c781 + 1c5319d commit 0a5b2e1

File tree

1 file changed

+99
-18
lines changed

1 file changed

+99
-18
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 99 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5517,17 +5517,57 @@ def argsort(
55175517
return lax.rev(indices, dimensions=[dimension]) if descending else indices
55185518

55195519

5520-
@util.implements(np.partition, lax_description="""
5521-
The JAX version requires the ``kth`` argument to be a static integer rather than
5522-
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5523-
you're only accessing the top or bottom k values of the output, it may be more
5524-
efficient to call :func:`jax.lax.top_k` directly.
5525-
5526-
The JAX version differs from the NumPy version in the treatment of NaN entries;
5527-
NaNs which have the negative bit set are sorted to the beginning of the array.
5528-
""")
55295520
@partial(jit, static_argnames=['kth', 'axis'])
55305521
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
5522+
"""Returns a partially-sorted copy of an array.
5523+
5524+
JAX implementation of :func:`numpy.partition`. The JAX version differs from
5525+
NumPy in the treatment of NaN entries: NaNs which have the negative bit set
5526+
are sorted to the beginning of the array.
5527+
5528+
Args:
5529+
a: array to be partitioned.
5530+
kth: static integer index about which to partition the array.
5531+
axis: static integer axis along which to partition the array; default is -1.
5532+
5533+
Returns:
5534+
A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries
5535+
before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries
5536+
after ``kth`` are indices of values larger than ``take(a, kth, axis)``
5537+
5538+
Note:
5539+
The JAX version requires the ``kth`` argument to be a static integer rather than
5540+
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5541+
you're only accessing the top or bottom k values of the output, it may be more
5542+
efficient to call :func:`jax.lax.top_k` directly.
5543+
5544+
See Also:
5545+
- :func:`jax.numpy.sort`: full sort
5546+
- :func:`jax.numpy.argpartition`: indirect partial sort
5547+
- :func:`jax.lax.top_k`: directly find the top k entries
5548+
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
5549+
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
5550+
5551+
Examples:
5552+
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
5553+
>>> kth = 4
5554+
>>> x_partitioned = jnp.partition(x, kth)
5555+
>>> x_partitioned
5556+
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
5557+
5558+
The result is a partially-sorted copy of the input. All values before ``kth``
5559+
are of smaller than the pivot value, and all values after ``kth`` are larger
5560+
than the pivot value:
5561+
5562+
>>> smallest_values = x_partitioned[:kth]
5563+
>>> pivot_value = x_partitioned[kth]
5564+
>>> largest_values = x_partitioned[kth + 1:]
5565+
>>> print(smallest_values, pivot_value, largest_values)
5566+
[1 2 3 3] 4 [9 8 7 6 5]
5567+
5568+
Notice that among ``smallest_values`` and ``largest_values``, the returned
5569+
order is arbitrary and implementation-dependent.
5570+
"""
55315571
# TODO(jakevdp): handle NaN values like numpy.
55325572
util.check_arraylike("partition", a)
55335573
arr = asarray(a)
@@ -5543,17 +5583,58 @@ def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
55435583
return swapaxes(out, -1, axis)
55445584

55455585

5546-
@util.implements(np.argpartition, lax_description="""
5547-
The JAX version requires the ``kth`` argument to be a static integer rather than
5548-
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5549-
you're only accessing the top or bottom k values of the output, it may be more
5550-
efficient to call :func:`jax.lax.top_k` directly.
5551-
5552-
The JAX version differs from the NumPy version in the treatment of NaN entries;
5553-
NaNs which have the negative bit set are sorted to the beginning of the array.
5554-
""")
55555586
@partial(jit, static_argnames=['kth', 'axis'])
55565587
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
5588+
"""Returns indices that partially sort an array.
5589+
5590+
JAX implementation of :func:`numpy.argpartition`. The JAX version differs from
5591+
NumPy in the treatment of NaN entries: NaNs which have the negative bit set are
5592+
sorted to the beginning of the array.
5593+
5594+
Args:
5595+
a: array to be partitioned.
5596+
kth: static integer index about which to partition the array.
5597+
axis: static integer axis along which to partition the array; default is -1.
5598+
5599+
Returns:
5600+
Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries
5601+
before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and
5602+
entries after ``kth`` are indices of values larger than ``take(a, kth, axis)``
5603+
5604+
Note:
5605+
The JAX version requires the ``kth`` argument to be a static integer rather than
5606+
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
5607+
you're only accessing the top or bottom k values of the output, it may be more
5608+
efficient to call :func:`jax.lax.top_k` directly.
5609+
5610+
See Also:
5611+
- :func:`jax.numpy.partition`: direct partial sort
5612+
- :func:`jax.numpy.argsort`: full indirect sort
5613+
- :func:`jax.lax.top_k`: directly find the top k entries
5614+
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
5615+
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
5616+
5617+
Examples:
5618+
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
5619+
>>> kth = 4
5620+
>>> idx = jnp.argpartition(x, kth)
5621+
>>> idx
5622+
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
5623+
5624+
The result is a sequence of indices that partially sort the input. All indices
5625+
before ``kth`` are of values smaller than the pivot value, and all indices
5626+
after ``kth`` are of values larger than the pivot value:
5627+
5628+
>>> x_partitioned = x[idx]
5629+
>>> smallest_values = x_partitioned[:kth]
5630+
>>> pivot_value = x_partitioned[kth]
5631+
>>> largest_values = x_partitioned[kth + 1:]
5632+
>>> print(smallest_values, pivot_value, largest_values)
5633+
[1 2 3 3] 4 [6 8 9 7 5]
5634+
5635+
Notice that among ``smallest_values`` and ``largest_values``, the returned
5636+
order is arbitrary and implementation-dependent.
5637+
"""
55575638
# TODO(jakevdp): handle NaN values like numpy.
55585639
util.check_arraylike("partition", a)
55595640
arr = asarray(a)

0 commit comments

Comments
 (0)