diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py index 5d94468ea..aa3fce60e 100644 --- a/numpyro/distributions/censored.py +++ b/numpyro/distributions/censored.py @@ -81,7 +81,7 @@ def __init__( base_dist: DistributionT, censored: ArrayLike = False, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # test if base_dist has an implemented cdf method if not hasattr(base_dist, "cdf"): @@ -197,7 +197,7 @@ def __init__( base_dist: DistributionT, censored: ArrayLike = False, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # test if base_dist has an implemented cdf method if not hasattr(base_dist, "cdf"): @@ -335,7 +335,7 @@ def __init__( left_censored: ArrayLike, right_censored: ArrayLike, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # Optionally test that cdf actually works (in validate_args mode) if validate_args: diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4742082cd..254a8e60d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -3577,7 +3577,7 @@ class Levy(Distribution): """ arg_constraints = { - "loc": constraints.positive, + "loc": constraints.real, "scale": constraints.positive, } diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7e67edad5..b8403ec7e 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -53,7 +53,7 @@ from . import constraints -_VALIDATION_ENABLED = False +_VALIDATION_ENABLED = True def enable_validation(is_validate: bool = True) -> None: @@ -1320,7 +1320,7 @@ class Unit(Distribution): arg_constraints = {"log_factor": constraints.real} support = constraints.real - def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, log_factor: ArrayLike, *, validate_args: bool = False): batch_shape = jnp.shape(log_factor) event_shape = (0,) # This satisfies .size == 0. self.log_factor = log_factor diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index f51d24221..d896ccda4 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -620,7 +620,7 @@ def __call__(self, x, state): return x, state # Eager initialization of the Net module outside the model - net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) # noqa: E1111 + net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) x = dist.Normal(0, 1).expand([4, 3]).to_event(2).sample(random.PRNGKey(0)) def model(): diff --git a/test/test_distributions.py b/test/test_distributions.py index 2ff1362b6..4c0daea3d 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3693,11 +3693,11 @@ def test_vmap_validate_args(): def test_explicit_validate_args(): # Check validation passes for valid parameters. - d = dist.Normal(0, 1) + d = dist.Normal(0, 1, validate_args=False) d.validate_args() # Check validation fails for invalid parameters. - d = dist.Normal(0, -1) + d = dist.Normal(0, -1, validate_args=False) with pytest.raises(ValueError, match="got invalid scale parameter"): d.validate_args() @@ -4766,3 +4766,16 @@ def log_prob_fn(params): f"All gradients for Beta({concentration1},{concentration0}) at x={value} " f"should be finite" ) + + +def test_uniform_log_prob_outside_support(): + from numpyro.distributions.distribution import enable_validation + + enable_validation() + + d = dist.Uniform(0, 1) + with pytest.warns( + UserWarning, + match="Out-of-support values provided to log prob method. The value argument should be within the support.", + ): + d.log_prob(-0.5) diff --git a/test/test_pickle.py b/test/test_pickle.py index 7ed338578..ef210defd 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -65,7 +65,7 @@ def bernoulli_model(): def logistic_regression(): - data = jnp.arange(10) + data = random.choice(random.PRNGKey(0), jnp.array([0, 1]), (10,)) x = numpyro.sample("x", dist.Normal(0, 1)) with numpyro.plate("N", 10, subsample_size=2): batch = numpyro.subsample(data, 0) diff --git a/test/test_util.py b/test/test_util.py index f351b45f3..62ead6395 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -112,7 +112,7 @@ def test_format_shapes(): def model_test(): mean = numpyro.param("mean", jnp.zeros(len(data))) - scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1)) + scale = numpyro.sample("scale", dist.LogNormal(0, 1).expand([3]).to_event(1)) scale = scale.sum() with numpyro.plate("data", len(data), subsample_size=10) as ind: batch = data[ind]