@@ -1620,18 +1620,68 @@ def select(
16201620 return lax .select_n (* broadcast_arrays (idx , * choicelist ))
16211621
16221622
1623- @util .implements (np .bincount , lax_description = """\
1624- Jax adds the optional `length` parameter which specifies the output length, and
1625- defaults to ``x.max() + 1``. It must be specified for bincount to be compiled
1626- with non-static operands. Values larger than the specified length will be discarded.
1627- If `length` is specified, `minlength` will be ignored.
1628-
1629- Additionally, while ``np.bincount`` raises an error if the input array contains
1630- negative values, ``jax.numpy.bincount`` clips negative values to zero.
1631- """ )
16321623def bincount (x : ArrayLike , weights : ArrayLike | None = None ,
16331624 minlength : int = 0 , * , length : int | None = None
16341625 ) -> Array :
1626+ """Count the number of occurrences of each value in an integer array.
1627+
1628+ JAX implementation of :func:`numpy.bincount`.
1629+
1630+ For an array of positive integers ``x``, this function returns an array ``counts``
1631+ of size ``x.max() + 1``, such that ``counts[i]`` contains the number of occurrences
1632+ of the value ``i`` in ``x``.
1633+
1634+ The JAX version has a few differences from the NumPy version:
1635+
1636+ - In NumPy, passing an array ``x`` with negative entries will result in an error.
1637+ In JAX, negative values are clipped to zero.
1638+ - JAX adds an optional ``length`` parameter which can be used to statically specify
1639+ the length of the output array so that this function can be used with transformations
1640+ like :func:`jax.jit`. In this case, items larger than `length + 1` will be dropped.
1641+
1642+ Args:
1643+ x : N-dimensional array of positive integers
1644+ weights: optional array of weights associated with ``x``. If not specified, the
1645+ weight for each entry will be ``1``.
1646+ minlength: the minimum length of the output counts array.
1647+ length: the length of the output counts array. Must be specified statically for
1648+ ``bincount`` to be used with :func:`jax.jit` and other JAX transformations.
1649+
1650+ Returns:
1651+ An array of counts or summed weights reflecting the number of occurrances of values
1652+ in ``x``.
1653+
1654+ See Also:
1655+ - :func:`jax.numpy.histogram`
1656+ - :func:`jax.numpy.digitize`
1657+ - :func:`jax.numpy.unique_counts`
1658+
1659+ Examples:
1660+ Basic bincount:
1661+
1662+ >>> x = jnp.array([1, 1, 2, 3, 3, 3])
1663+ >>> jnp.bincount(x)
1664+ Array([0, 2, 1, 3], dtype=int32)
1665+
1666+ Weighted bincount:
1667+
1668+ >>> weights = jnp.array([1, 2, 3, 4, 5, 6])
1669+ >>> jnp.bincount(x, weights)
1670+ Array([ 0, 3, 3, 15], dtype=int32)
1671+
1672+ Specifying a static ``length`` makes this jit-compatible:
1673+
1674+ >>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
1675+ >>> jit_bincount(x, length=5)
1676+ Array([0, 2, 1, 3, 0], dtype=int32)
1677+
1678+ Any negative numbers are clipped to the first bin, and numbers beyond the
1679+ specified ``length`` are dropped:
1680+
1681+ >>> x = jnp.array([-1, -1, 1, 3, 10])
1682+ >>> jnp.bincount(x, length=5)
1683+ Array([2, 1, 0, 1, 0], dtype=int32)
1684+ """
16351685 util .check_arraylike ("bincount" , x )
16361686 if not issubdtype (_dtype (x ), integer ):
16371687 raise TypeError (f"x argument to bincount must have an integer type; got { _dtype (x )} " )
0 commit comments