Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2705,55 +2705,55 @@ class ZeroSumNormal(Distribution):

.. math::

\begin{align*}
ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
n = \text{nbr of zero-sum axes}
\end{align*}
\begin{align*}
ZSN(\sigma) = N \Big( 0, \sigma^2 (I_K - \tfrac{1}{K}J_K) \Big) \\
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
K = \text{size (length) of the constrained axis}
\end{align*}

Parameters
----------
sigma : tensor_like of float
Scale parameter (sigma > 0).
Copy link
Member

@ricardoV94 ricardoV94 Dec 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should not remove any of this previous stuff

It's actually the standard deviation of the underlying, unconstrained Normal distribution.
Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
Scale parameter (sigma > 0).
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
Defaults to 1 if not specified. ``sigma`` cannot have length > 1 across the zero-sum axes.
n_zerosum_axes: int, defaults to 1
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
Defaults to 1, i.e the rightmost axis.
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
Defaults to 1, i.e the rightmost axis.
dims: sequence of strings, optional
Dimension names of the distribution. Works the same as for other PyMC distributions.
Necessary if ``shape`` is not passed.
Dimension names of the distribution. Works the same as for other PyMC distributions.
Necessary if ``shape`` is not passed.
shape: tuple of integers, optional
Shape of the distribution. Works the same as for other PyMC distributions.
Necessary if ``dims`` or ``observed`` is not passed.
Shape of the distribution. Works the same as for other PyMC distributions.
Necessary if ``dims`` or ``observed`` is not passed.

Warnings
--------
Currently, ``sigma``cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.
Currently, ``sigma`` cannot have length > 1 across the zero-sum axes to ensure the zero-sum constraint.

``n_zerosum_axes`` has to be > 0. If you want the behavior of ``n_zerosum_axes = 0``,
just use ``pm.Normal``.

Examples
--------
Define a `ZeroSumNormal` variable, with `sigma=1` and
`n_zerosum_axes=1` by default::

COORDS = {
"regions": ["a", "b", "c"],
"answers": ["yes", "no", "whatever", "don't understand question"],
}
with pm.Model(coords=COORDS) as m:
# the zero sum axis will be 'answers'
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))

with pm.Model(coords=COORDS) as m:
# the zero sum axes will be 'answers' and 'regions'
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)

with pm.Model(coords=COORDS) as m:
# the zero sum axes will be the last two
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
`n_zerosum_axes=1` by default::

COORDS = {
"regions": ["a", "b", "c"],
"answers": ["yes", "no", "whatever", "don't understand question"],
}
with pm.Model(coords=COORDS) as m:
# the zero sum axis will be 'answers'
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))

with pm.Model(coords=COORDS) as m:
# the zero sum axes will be 'answers' and 'regions'
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), n_zerosum_axes=2)

with pm.Model(coords=COORDS) as m:
# the zero sum axes will be the last two
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), n_zerosum_axes=2)
"""

rv_type = ZeroSumNormalRV
Expand Down