Skip to content
6 changes: 3 additions & 3 deletions numpyro/distributions/censored.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
base_dist: DistributionT,
censored: ArrayLike = False,
*,
validate_args: Optional[bool] = None,
validate_args: bool = False,
):
# test if base_dist has an implemented cdf method
if not hasattr(base_dist, "cdf"):
Expand Down Expand Up @@ -197,7 +197,7 @@ def __init__(
base_dist: DistributionT,
censored: ArrayLike = False,
*,
validate_args: Optional[bool] = None,
validate_args: bool = False,
):
# test if base_dist has an implemented cdf method
if not hasattr(base_dist, "cdf"):
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(
left_censored: ArrayLike,
right_censored: ArrayLike,
*,
validate_args: Optional[bool] = None,
validate_args: bool = False,
):
# Optionally test that cdf actually works (in validate_args mode)
if validate_args:
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3577,7 +3577,7 @@ class Levy(Distribution):
"""

arg_constraints = {
"loc": constraints.positive,
"loc": constraints.real,
"scale": constraints.positive,
}

Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

from . import constraints

_VALIDATION_ENABLED = False
_VALIDATION_ENABLED = True


def enable_validation(is_validate: bool = True) -> None:
Expand Down Expand Up @@ -1320,7 +1320,7 @@ class Unit(Distribution):
arg_constraints = {"log_factor": constraints.real}
support = constraints.real

def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = None):
def __init__(self, log_factor: ArrayLike, *, validate_args: bool = False):
batch_shape = jnp.shape(log_factor)
event_shape = (0,) # This satisfies .size == 0.
self.log_factor = log_factor
Expand Down
2 changes: 1 addition & 1 deletion test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def __call__(self, x, state):
return x, state

# Eager initialization of the Net module outside the model
net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) # noqa: E1111
net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0))
x = dist.Normal(0, 1).expand([4, 3]).to_event(2).sample(random.PRNGKey(0))

def model():
Expand Down
17 changes: 15 additions & 2 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3693,11 +3693,11 @@ def test_vmap_validate_args():

def test_explicit_validate_args():
# Check validation passes for valid parameters.
d = dist.Normal(0, 1)
d = dist.Normal(0, 1, validate_args=False)
d.validate_args()

# Check validation fails for invalid parameters.
d = dist.Normal(0, -1)
d = dist.Normal(0, -1, validate_args=False)
with pytest.raises(ValueError, match="got invalid scale parameter"):
d.validate_args()

Expand Down Expand Up @@ -4766,3 +4766,16 @@ def log_prob_fn(params):
f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
f"should be finite"
)


def test_uniform_log_prob_outside_support():
from numpyro.distributions.distribution import enable_validation

enable_validation()

d = dist.Uniform(0, 1)
with pytest.warns(
UserWarning,
match="Out-of-support values provided to log prob method. The value argument should be within the support.",
):
d.log_prob(-0.5)
2 changes: 1 addition & 1 deletion test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def bernoulli_model():


def logistic_regression():
data = jnp.arange(10)
data = random.choice(random.PRNGKey(0), jnp.array([0, 1]), (10,))
x = numpyro.sample("x", dist.Normal(0, 1))
with numpyro.plate("N", 10, subsample_size=2):
batch = numpyro.subsample(data, 0)
Expand Down
2 changes: 1 addition & 1 deletion test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_format_shapes():

def model_test():
mean = numpyro.param("mean", jnp.zeros(len(data)))
scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1))
scale = numpyro.sample("scale", dist.LogNormal(0, 1).expand([3]).to_event(1))
scale = scale.sum()
with numpyro.plate("data", len(data), subsample_size=10) as ind:
batch = data[ind]
Expand Down