@@ -2700,60 +2700,34 @@ class ZeroSumNormal(Distribution):
27002700 r"""
27012701 Normal distribution where one or several axes are constrained to sum to zero.
27022702
2703- By default, the last axis is constrained to sum to zero.
2704- See `n_zerosum_axes` kwarg for more details.
2703+ By default, the last axis is constrained to sum to zero. See the `n_zerosum_axes`
2704+ kwarg for more details.
2705+
2706+ The constrained distribution follows a multivariate Normal distribution. For the
2707+ standard 1D case with a single constrained axis of size K, the covariance is:
27052708
27062709 .. math::
27072710
2708- \begin{align*}
2709- ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
2710- \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
2711- n = \text{nbr of zero-sum axes}
2712- \end{align*}
2711+ ZSN(\sigma) = N\left(0, \sigma^2 \left(I_K - \tfrac{1}{K} J_K\right)\right)
2712+
2713+ where:
2714+
2715+ - :math:`I_K` is the :math:`K \times K` identity matrix,
2716+ - :math:`J_K` is the :math:`K \times K` matrix of ones,
2717+ - :math:`K` is the size of the constrained axis.
2718+
2719+ Using :math:`K` avoids confusion with ``n_zerosum_axes``, which counts how many
2720+ axes are constrained, not their length.
27132721
27142722 Parameters
27152723 ----------
27162724 sigma : tensor_like of float
2717- Scale parameter (sigma > 0).
2718- It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2719- Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
2720- n_zerosum_axes: int, defaults to 1
2721- Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2722- Defaults to 1, i.e the rightmost axis.
2723- dims: sequence of strings, optional
2724- Dimension names of the distribution. Works the same as for other PyMC distributions.
2725- Necessary if ``shape`` is not passed.
2726- shape: tuple of integers, optional
2727- Shape of the distribution. Works the same as for other PyMC distributions.
2728- Necessary if ``dims`` or ``observed`` is not passed.
2729-
2730- Warnings
2731- --------
2732- Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
2733-
2734- ``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
2735- just use ``pm.Normal``.
2736-
2737- Examples
2738- --------
2739- Define a `ZeroSumNormal` variable, with `sigma=1` and
2740- `n_zerosum_axes=1` by default::
2741-
2742- COORDS = {
2743- "regions": ["a", "b", "c"],
2744- "answers": ["yes", "no", "whatever", "don't understand question"],
2745- }
2746- with pm.Model(coords=COORDS) as m:
2747- # the zero sum axis will be 'answers'
2748- v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2749-
2750- with pm.Model(coords=COORDS) as m:
2751- # the zero sum axes will be 'answers' and 'regions'
2752- v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)
2753-
2754- with pm.Model(coords=COORDS) as m:
2755- # the zero sum axes will be the last two
2756- v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
2725+ Scale parameter (sigma > 0). Defaults to 1.
2726+ ``sigma`` cannot have length > 1 across the zero-sum axes.
2727+ n_zerosum_axes : int, defaults to 1
2728+ Number of axes along which the zero-sum constraint is enforced.
2729+ dims : sequence of strings, optional
2730+ shape : tuple of integers, optional
27572731 """
27582732
27592733 rv_type = ZeroSumNormalRV
0 commit comments