Skip to content

Commit 02235a0

Browse files
committed
DOC: clarify ZeroSumNormal covariance (use K for constrained-axis size) (#7904)
1 parent ebd836e commit 02235a0

File tree

1 file changed

+21
-47
lines changed

1 file changed

+21
-47
lines changed

pymc/distributions/multivariate.py

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2700,60 +2700,34 @@ class ZeroSumNormal(Distribution):
27002700
r"""
27012701
Normal distribution where one or several axes are constrained to sum to zero.
27022702
2703-
By default, the last axis is constrained to sum to zero.
2704-
See `n_zerosum_axes` kwarg for more details.
2703+
By default, the last axis is constrained to sum to zero. See the `n_zerosum_axes`
2704+
kwarg for more details.
2705+
2706+
The constrained distribution follows a multivariate Normal distribution. For the
2707+
standard 1D case with a single constrained axis of size K, the covariance is:
27052708
27062709
.. math::
27072710
2708-
\begin{align*}
2709-
ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
2710-
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
2711-
n = \text{nbr of zero-sum axes}
2712-
\end{align*}
2711+
ZSN(\sigma) = N\left(0, \sigma^2 \left(I_K - \tfrac{1}{K} J_K\right)\right)
2712+
2713+
where:
2714+
2715+
- :math:`I_K` is the :math:`K \times K` identity matrix,
2716+
- :math:`J_K` is the :math:`K \times K` matrix of ones,
2717+
- :math:`K` is the size of the constrained axis.
2718+
2719+
Using :math:`K` avoids confusion with ``n_zerosum_axes``, which counts how many
2720+
axes are constrained, not their length.
27132721
27142722
Parameters
27152723
----------
27162724
sigma : tensor_like of float
2717-
Scale parameter (sigma > 0).
2718-
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2719-
Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
2720-
n_zerosum_axes: int, defaults to 1
2721-
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2722-
Defaults to 1, i.e the rightmost axis.
2723-
dims: sequence of strings, optional
2724-
Dimension names of the distribution. Works the same as for other PyMC distributions.
2725-
Necessary if ``shape`` is not passed.
2726-
shape: tuple of integers, optional
2727-
Shape of the distribution. Works the same as for other PyMC distributions.
2728-
Necessary if ``dims`` or ``observed`` is not passed.
2729-
2730-
Warnings
2731-
--------
2732-
Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
2733-
2734-
``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
2735-
just use ``pm.Normal``.
2736-
2737-
Examples
2738-
--------
2739-
Define a `ZeroSumNormal` variable, with `sigma=1` and
2740-
`n_zerosum_axes=1` by default::
2741-
2742-
COORDS = {
2743-
"regions": ["a", "b", "c"],
2744-
"answers": ["yes", "no", "whatever", "don't understand question"],
2745-
}
2746-
with pm.Model(coords=COORDS) as m:
2747-
# the zero sum axis will be 'answers'
2748-
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2749-
2750-
with pm.Model(coords=COORDS) as m:
2751-
# the zero sum axes will be 'answers' and 'regions'
2752-
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)
2753-
2754-
with pm.Model(coords=COORDS) as m:
2755-
# the zero sum axes will be the last two
2756-
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
2725+
Scale parameter (sigma > 0). Defaults to 1.
2726+
``sigma`` cannot have length > 1 across the zero-sum axes.
2727+
n_zerosum_axes : int, defaults to 1
2728+
Number of axes along which the zero-sum constraint is enforced.
2729+
dims : sequence of strings, optional
2730+
shape : tuple of integers, optional
27572731
"""
27582732

27592733
rv_type = ZeroSumNormalRV

0 commit comments

Comments
 (0)