Skip to content

Commit fcdbb2a

Browse files
committed
Improve documentation for jnp.bincount
1 parent 4461b7f commit fcdbb2a

File tree

1 file changed

+59
-9
lines changed

1 file changed

+59
-9
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,18 +1620,68 @@ def select(
16201620
return lax.select_n(*broadcast_arrays(idx, *choicelist))
16211621

16221622

1623-
@util.implements(np.bincount, lax_description="""\
1624-
Jax adds the optional `length` parameter which specifies the output length, and
1625-
defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
1626-
with non-static operands. Values larger than the specified length will be discarded.
1627-
If `length` is specified, `minlength` will be ignored.
1628-
1629-
Additionally, while ``np.bincount`` raises an error if the input array contains
1630-
negative values, ``jax.numpy.bincount`` clips negative values to zero.
1631-
""")
16321623
def bincount(x: ArrayLike, weights: ArrayLike | None = None,
16331624
minlength: int = 0, *, length: int | None = None
16341625
) -> Array:
1626+
"""Count the number of occurrences of each value in an integer array.
1627+
1628+
JAX implementation of :func:`numpy.bincount`.
1629+
1630+
For an array of positive integers ``x``, this function returns an array ``counts``
1631+
of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences
1632+
of the value ``i`` in ``x``.
1633+
1634+
The JAX version has a few differences from the NumPy version:
1635+
1636+
- In NumPy, passing an array ``x`` with negative entries will result in an error.
1637+
In JAX, negative values are clipped to zero.
1638+
- JAX adds an optional ``length`` parameter which can be used to statically specify
1639+
the length of the output array so that this function can be used with transformations
1640+
like :func:`jax.jit`. In this case, items larger than `length + 1` will be dropped.
1641+
1642+
Args:
1643+
x : N-dimensional array of positive integers
1644+
weights: optional array of weights associated with ``x``. If not specified, the
1645+
weight for each entry will be ``1``.
1646+
minlength: the minimum length of the output counts array.
1647+
length: the length of the output counts array. Must be specified statically for
1648+
``bincount`` to be used with :func:`jax.jit` and other JAX transformations.
1649+
1650+
Returns:
1651+
An array of counts or summed weights reflecting the number of occurrances of values
1652+
in ``x``.
1653+
1654+
See Also:
1655+
- :func:`jax.numpy.histogram`
1656+
- :func:`jax.numpy.digitize`
1657+
- :func:`jax.numpy.unique_counts`
1658+
1659+
Examples:
1660+
Basic bincount:
1661+
1662+
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
1663+
>>> jnp.bincount(x)
1664+
Array([0, 2, 1, 3], dtype=int32)
1665+
1666+
Weighted bincount:
1667+
1668+
>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
1669+
>>> jnp.bincount(x, weights)
1670+
Array([ 0, 3, 3, 15], dtype=int32)
1671+
1672+
Specifying a static ``length`` makes this jit-compatible:
1673+
1674+
>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
1675+
>>> jit_bincount(x, length=5)
1676+
Array([0, 2, 1, 3, 0], dtype=int32)
1677+
1678+
Any negative numbers are clipped to the first bin, and numbers beyond the
1679+
specified ``length`` are dropped:
1680+
1681+
>>> x = jnp.array([-1, -1, 1, 3, 10])
1682+
>>> jnp.bincount(x, length=5)
1683+
Array([2, 1, 0, 1, 0], dtype=int32)
1684+
"""
16351685
util.check_arraylike("bincount", x)
16361686
if not issubdtype(_dtype(x), integer):
16371687
raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")

0 commit comments

Comments
 (0)