From 53cc83752c62a23df1b03cb75bbb3fcf4089015d Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 5 Nov 2025 11:05:23 +0100 Subject: [PATCH 01/14] initial fix --- numpyro/distributions/continuous.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4742082cd..e0aafba09 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2706,8 +2706,11 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + log_p = -jnp.log(self.high - self.low) + is_in_support = (value >= self.low) & (value < self.high) shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) - return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) + log_p = jnp.broadcast_to(log_p, shape) + return jnp.where(is_in_support, log_p, -jnp.inf) def cdf(self, value: ArrayLike) -> ArrayLike: cdf = (value - self.low) / (self.high - self.low) From 5aabb9c716cfcace638d1de6ef0d322c7a8d425f Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 5 Nov 2025 11:10:00 +0100 Subject: [PATCH 02/14] initial test --- test/test_distributions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index 2ff1362b6..63020a45e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4766,3 +4766,9 @@ def log_prob_fn(params): f"All gradients for Beta({concentration1},{concentration0}) at x={value} " f"should be finite" ) + + +def test_uniform_log_prob_outside_support(): + d = dist.Uniform(0, 1) + assert_allclose(d.log_prob(-0.5), -jnp.inf) + assert_allclose(d.log_prob(1.5), -jnp.inf) From fad8707a47e4a101de89a04fdad493edbef565ff Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 5 Nov 2025 11:14:33 +0100 Subject: [PATCH 03/14] add more tests --- test/test_distributions.py | 97 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index 63020a45e..0d588f552 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4772,3 +4772,100 @@ def test_uniform_log_prob_outside_support(): d = dist.Uniform(0, 1) assert_allclose(d.log_prob(-0.5), -jnp.inf) assert_allclose(d.log_prob(1.5), -jnp.inf) + + +@pytest.mark.parametrize( + "low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)] +) +def test_uniform_log_prob_boundaries(low, high): + """Test that boundary values are handled correctly.""" + d = dist.Uniform(low, high) + expected_log_prob = -jnp.log(high - low) + + # Value at lower bound (included): should have finite log prob + assert_allclose(d.log_prob(low), expected_log_prob) + + # Value just above lower bound: should have finite log prob + assert_allclose(d.log_prob(low + 1e-10), expected_log_prob) + + # Value at upper bound (excluded): should be -inf + assert_allclose(d.log_prob(high), -jnp.inf) + + # Value just below upper bound: should have finite log prob + assert_allclose(d.log_prob(high - 1e-10), expected_log_prob) + + # Value inside support: should have finite log prob + mid = (low + high) / 2.0 + assert_allclose(d.log_prob(mid), expected_log_prob) + + # Value below lower bound: should be -inf + assert_allclose(d.log_prob(low - 1.0), -jnp.inf) + + # Value above upper bound: should be -inf + assert_allclose(d.log_prob(high + 1.0), -jnp.inf) + + +@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)]) +def test_uniform_log_prob_broadcasting(batch_shape): + """Test broadcasting with different batch shapes.""" + if batch_shape == (): + low = 0.0 + high = 1.0 + else: + low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape) + high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape) + + d = dist.Uniform(low, high) + + # Test with scalar value + value = 0.5 + log_probs = d.log_prob(value) + assert log_probs.shape == batch_shape + + # Test with batched value + if batch_shape: + value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape( + batch_shape + ) + log_probs_batched = d.log_prob(value_batched) + assert log_probs_batched.shape == batch_shape + + # Check that values outside support return -inf + # Values < low should be -inf + below_low = low - 0.1 + assert_allclose(d.log_prob(below_low), -jnp.inf) + + # Values >= high should be -inf + at_high = high + assert_allclose(d.log_prob(at_high), -jnp.inf) + + +@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)]) +def test_uniform_log_prob_value_broadcasting(value_shape): + """Test broadcasting when value has different shapes.""" + d = dist.Uniform(0.0, 1.0) + + if value_shape == (): + values = 0.5 + else: + values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape) + + log_probs = d.log_prob(values) + assert log_probs.shape == value_shape + + # Check that values inside support have finite log prob + inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1) + if value_shape: + inside_values = inside_values.reshape(value_shape) + log_probs_inside = d.log_prob(inside_values) + assert jnp.all(jnp.isfinite(log_probs_inside)) + + # Check that values outside support have -inf + outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1) + if value_shape: + outside_values = outside_values.reshape(value_shape) + log_probs_outside = d.log_prob(outside_values) + # Values in [0, 1) should be finite, others should be -inf + mask_inside = (outside_values >= 0.0) & (outside_values < 1.0) + assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True)) + assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True)) From 360844d24f485d4d3e79482d1ee486516e3d437a Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 19 Nov 2025 09:52:11 +0100 Subject: [PATCH 04/14] add validation by default --- numpyro/distributions/continuous.py | 5 +---- numpyro/distributions/distribution.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index e0aafba09..4742082cd 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2706,11 +2706,8 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - log_p = -jnp.log(self.high - self.low) - is_in_support = (value >= self.low) & (value < self.high) shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) - log_p = jnp.broadcast_to(log_p, shape) - return jnp.where(is_in_support, log_p, -jnp.inf) + return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) def cdf(self, value: ArrayLike) -> ArrayLike: cdf = (value - self.low) / (self.high - self.low) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7e67edad5..4bbdb04db 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -53,7 +53,7 @@ from . import constraints -_VALIDATION_ENABLED = False +_VALIDATION_ENABLED = True def enable_validation(is_validate: bool = True) -> None: From 75f7d555a5ca57fc9f1d4835c4b2e6f3f8e78118 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Wed, 19 Nov 2025 10:11:59 +0100 Subject: [PATCH 05/14] Update numpyro/distributions/distribution.py Co-authored-by: Meesum Qazalbash --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 4bbdb04db..bd1384f4a 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -56,7 +56,7 @@ _VALIDATION_ENABLED = True -def enable_validation(is_validate: bool = True) -> None: +def enable_validation(is_validate: bool = False) -> None: """ Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution arguments and support values, etc. which is From 2c014a9df9b4fc29d571f9c9f9502b8506c95740 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sun, 30 Nov 2025 22:49:45 +0100 Subject: [PATCH 06/14] improve some validation --- numpyro/distributions/continuous.py | 2 +- numpyro/distributions/distribution.py | 2 +- test/test_distributions.py | 107 +++++++++++++++++++++++++- 3 files changed, 107 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4742082cd..254a8e60d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -3577,7 +3577,7 @@ class Levy(Distribution): """ arg_constraints = { - "loc": constraints.positive, + "loc": constraints.real, "scale": constraints.positive, } diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index bd1384f4a..173dae971 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -1320,7 +1320,7 @@ class Unit(Distribution): arg_constraints = {"log_factor": constraints.real} support = constraints.real - def __init__(self, log_factor: ArrayLike, *, validate_args: Optional[bool] = None): + def __init__(self, log_factor: ArrayLike, *, validate_args: bool = False): batch_shape = jnp.shape(log_factor) event_shape = (0,) # This satisfies .size == 0. self.log_factor = log_factor diff --git a/test/test_distributions.py b/test/test_distributions.py index 0d588f552..198f2c1a0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3694,12 +3694,12 @@ def test_vmap_validate_args(): def test_explicit_validate_args(): # Check validation passes for valid parameters. d = dist.Normal(0, 1) - d.validate_args() + d.validate_args(False) # Check validation fails for invalid parameters. d = dist.Normal(0, -1) with pytest.raises(ValueError, match="got invalid scale parameter"): - d.validate_args() + d.validate_args(False) # Check validation is skipped for strict=False and raises an error for strict=True. jitted = jax.jit( @@ -4640,6 +4640,109 @@ def test_interval_censored_validate_sample( censored_dist.log_prob(value) # Should not raise +def test_uniform_log_prob_outside_support(): + d = dist.Uniform(0, 1) + assert_allclose(d.log_prob(-0.5), -jnp.inf) + assert_allclose(d.log_prob(1.5), -jnp.inf) + + +@pytest.mark.parametrize( + "low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)] +) +def test_uniform_log_prob_boundaries(low, high): + """Test that boundary values are handled correctly.""" + d = dist.Uniform(low, high) + expected_log_prob = -jnp.log(high - low) + + # Value at lower bound (included): should have finite log prob + assert_allclose(d.log_prob(low), expected_log_prob) + + # Value just above lower bound: should have finite log prob + assert_allclose(d.log_prob(low + 1e-10), expected_log_prob) + + # Value at upper bound (excluded): should be -inf + assert_allclose(d.log_prob(high), -jnp.inf) + + # Value just below upper bound: should have finite log prob + assert_allclose(d.log_prob(high - 1e-10), expected_log_prob) + + # Value inside support: should have finite log prob + mid = (low + high) / 2.0 + assert_allclose(d.log_prob(mid), expected_log_prob) + + # Value below lower bound: should be -inf + assert_allclose(d.log_prob(low - 1.0), -jnp.inf) + + # Value above upper bound: should be -inf + assert_allclose(d.log_prob(high + 1.0), -jnp.inf) + + +@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)]) +def test_uniform_log_prob_broadcasting(batch_shape): + """Test broadcasting with different batch shapes.""" + if batch_shape == (): + low = 0.0 + high = 1.0 + else: + low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape) + high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape) + + d = dist.Uniform(low, high) + + # Test with scalar value + value = 0.5 + log_probs = d.log_prob(value) + assert log_probs.shape == batch_shape + + # Test with batched value + if batch_shape: + value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape( + batch_shape + ) + log_probs_batched = d.log_prob(value_batched) + assert log_probs_batched.shape == batch_shape + + # Check that values outside support return -inf + # Values < low should be -inf + below_low = low - 0.1 + assert_allclose(d.log_prob(below_low), -jnp.inf) + + # Values >= high should be -inf + at_high = high + assert_allclose(d.log_prob(at_high), -jnp.inf) + + +@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)]) +def test_uniform_log_prob_value_broadcasting(value_shape): + """Test broadcasting when value has different shapes.""" + d = dist.Uniform(0.0, 1.0) + + if value_shape == (): + values = 0.5 + else: + values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape) + + log_probs = d.log_prob(values) + assert log_probs.shape == value_shape + + # Check that values inside support have finite log prob + inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1) + if value_shape: + inside_values = inside_values.reshape(value_shape) + log_probs_inside = d.log_prob(inside_values) + assert jnp.all(jnp.isfinite(log_probs_inside)) + + # Check that values outside support have -inf + outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1) + if value_shape: + outside_values = outside_values.reshape(value_shape) + log_probs_outside = d.log_prob(outside_values) + # Values in [0, 1) should be finite, others should be -inf + mask_inside = (outside_values >= 0.0) & (outside_values < 1.0) + assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True)) + assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True)) + + @pytest.mark.parametrize( argnames="concentration1,concentration0,value", argvalues=[ From 6dff67169398c27f0517fc4e68951a5263b07684 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sun, 30 Nov 2025 22:52:31 +0100 Subject: [PATCH 07/14] remove bad noqa generating warnings --- test/contrib/test_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index f51d24221..d896ccda4 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -620,7 +620,7 @@ def __call__(self, x, state): return x, state # Eager initialization of the Net module outside the model - net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) # noqa: E1111 + net_module, eager_state = eqx.nn.make_with_state(Net)(key=random.PRNGKey(0)) x = dist.Normal(0, 1).expand([4, 3]).to_event(2).sample(random.PRNGKey(0)) def model(): From 01af96b33f67b5e7bcbe1297069d85ad373aaa61 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sun, 30 Nov 2025 23:18:30 +0100 Subject: [PATCH 08/14] validate to false --- numpyro/distributions/censored.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py index 5d94468ea..aa3fce60e 100644 --- a/numpyro/distributions/censored.py +++ b/numpyro/distributions/censored.py @@ -81,7 +81,7 @@ def __init__( base_dist: DistributionT, censored: ArrayLike = False, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # test if base_dist has an implemented cdf method if not hasattr(base_dist, "cdf"): @@ -197,7 +197,7 @@ def __init__( base_dist: DistributionT, censored: ArrayLike = False, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # test if base_dist has an implemented cdf method if not hasattr(base_dist, "cdf"): @@ -335,7 +335,7 @@ def __init__( left_censored: ArrayLike, right_censored: ArrayLike, *, - validate_args: Optional[bool] = None, + validate_args: bool = False, ): # Optionally test that cdf actually works (in validate_args mode) if validate_args: From 9ca88f1cb8c53e6cf67c466e61c1f6c586c43e64 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sun, 30 Nov 2025 23:30:23 +0100 Subject: [PATCH 09/14] fix support --- numpyro/distributions/continuous.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 254a8e60d..2e0d45176 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2706,8 +2706,11 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + log_p = -jnp.log(self.high - self.low) + is_in_support = (value >= self.low) & (value < self.high) shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) - return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) + log_p = jnp.broadcast_to(log_p, shape) + return jnp.where(is_in_support, log_p, -jnp.inf) def cdf(self, value: ArrayLike) -> ArrayLike: cdf = (value - self.low) / (self.high - self.low) From 2bc25c008f3c23013522dd2c940c7fef36fd27fc Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Dec 2025 11:17:07 +0100 Subject: [PATCH 10/14] rm dupplicated code --- test/test_distributions.py | 103 ------------------------------------- 1 file changed, 103 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 198f2c1a0..c058fcda5 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4640,109 +4640,6 @@ def test_interval_censored_validate_sample( censored_dist.log_prob(value) # Should not raise -def test_uniform_log_prob_outside_support(): - d = dist.Uniform(0, 1) - assert_allclose(d.log_prob(-0.5), -jnp.inf) - assert_allclose(d.log_prob(1.5), -jnp.inf) - - -@pytest.mark.parametrize( - "low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)] -) -def test_uniform_log_prob_boundaries(low, high): - """Test that boundary values are handled correctly.""" - d = dist.Uniform(low, high) - expected_log_prob = -jnp.log(high - low) - - # Value at lower bound (included): should have finite log prob - assert_allclose(d.log_prob(low), expected_log_prob) - - # Value just above lower bound: should have finite log prob - assert_allclose(d.log_prob(low + 1e-10), expected_log_prob) - - # Value at upper bound (excluded): should be -inf - assert_allclose(d.log_prob(high), -jnp.inf) - - # Value just below upper bound: should have finite log prob - assert_allclose(d.log_prob(high - 1e-10), expected_log_prob) - - # Value inside support: should have finite log prob - mid = (low + high) / 2.0 - assert_allclose(d.log_prob(mid), expected_log_prob) - - # Value below lower bound: should be -inf - assert_allclose(d.log_prob(low - 1.0), -jnp.inf) - - # Value above upper bound: should be -inf - assert_allclose(d.log_prob(high + 1.0), -jnp.inf) - - -@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)]) -def test_uniform_log_prob_broadcasting(batch_shape): - """Test broadcasting with different batch shapes.""" - if batch_shape == (): - low = 0.0 - high = 1.0 - else: - low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape) - high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape) - - d = dist.Uniform(low, high) - - # Test with scalar value - value = 0.5 - log_probs = d.log_prob(value) - assert log_probs.shape == batch_shape - - # Test with batched value - if batch_shape: - value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape( - batch_shape - ) - log_probs_batched = d.log_prob(value_batched) - assert log_probs_batched.shape == batch_shape - - # Check that values outside support return -inf - # Values < low should be -inf - below_low = low - 0.1 - assert_allclose(d.log_prob(below_low), -jnp.inf) - - # Values >= high should be -inf - at_high = high - assert_allclose(d.log_prob(at_high), -jnp.inf) - - -@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)]) -def test_uniform_log_prob_value_broadcasting(value_shape): - """Test broadcasting when value has different shapes.""" - d = dist.Uniform(0.0, 1.0) - - if value_shape == (): - values = 0.5 - else: - values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape) - - log_probs = d.log_prob(values) - assert log_probs.shape == value_shape - - # Check that values inside support have finite log prob - inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1) - if value_shape: - inside_values = inside_values.reshape(value_shape) - log_probs_inside = d.log_prob(inside_values) - assert jnp.all(jnp.isfinite(log_probs_inside)) - - # Check that values outside support have -inf - outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1) - if value_shape: - outside_values = outside_values.reshape(value_shape) - log_probs_outside = d.log_prob(outside_values) - # Values in [0, 1) should be finite, others should be -inf - mask_inside = (outside_values >= 0.0) & (outside_values < 1.0) - assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True)) - assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True)) - - @pytest.mark.parametrize( argnames="concentration1,concentration0,value", argvalues=[ From 139bc885e69d71378e09dec1a644e20ab2e73ae6 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Dec 2025 14:10:58 +0100 Subject: [PATCH 11/14] feedback 1/n --- numpyro/distributions/distribution.py | 2 +- test/test_distributions.py | 6 +++--- test/test_pickle.py | 2 +- test/test_util.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 173dae971..b8403ec7e 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -56,7 +56,7 @@ _VALIDATION_ENABLED = True -def enable_validation(is_validate: bool = False) -> None: +def enable_validation(is_validate: bool = True) -> None: """ Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g. NaN checks, validating distribution arguments and support values, etc. which is diff --git a/test/test_distributions.py b/test/test_distributions.py index c058fcda5..0505e9411 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3693,13 +3693,13 @@ def test_vmap_validate_args(): def test_explicit_validate_args(): # Check validation passes for valid parameters. - d = dist.Normal(0, 1) - d.validate_args(False) + d = dist.Normal(0, 1, validate_args=False) + d.validate_args() # Check validation fails for invalid parameters. d = dist.Normal(0, -1) with pytest.raises(ValueError, match="got invalid scale parameter"): - d.validate_args(False) + d.validate_args() # Check validation is skipped for strict=False and raises an error for strict=True. jitted = jax.jit( diff --git a/test/test_pickle.py b/test/test_pickle.py index 7ed338578..ef210defd 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -65,7 +65,7 @@ def bernoulli_model(): def logistic_regression(): - data = jnp.arange(10) + data = random.choice(random.PRNGKey(0), jnp.array([0, 1]), (10,)) x = numpyro.sample("x", dist.Normal(0, 1)) with numpyro.plate("N", 10, subsample_size=2): batch = numpyro.subsample(data, 0) diff --git a/test/test_util.py b/test/test_util.py index f351b45f3..62ead6395 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -112,7 +112,7 @@ def test_format_shapes(): def model_test(): mean = numpyro.param("mean", jnp.zeros(len(data))) - scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1)) + scale = numpyro.sample("scale", dist.LogNormal(0, 1).expand([3]).to_event(1)) scale = scale.sum() with numpyro.plate("data", len(data), subsample_size=10) as ind: batch = data[ind] From d587182ed2a95c0bb554a2ad0bb0d8b35fb78145 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Dec 2025 14:20:14 +0100 Subject: [PATCH 12/14] undo uniform changes and admend test --- numpyro/distributions/continuous.py | 5 +- test/test_distributions.py | 106 +++------------------------- 2 files changed, 9 insertions(+), 102 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 2e0d45176..254a8e60d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2706,11 +2706,8 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - log_p = -jnp.log(self.high - self.low) - is_in_support = (value >= self.low) & (value < self.high) shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape) - log_p = jnp.broadcast_to(log_p, shape) - return jnp.where(is_in_support, log_p, -jnp.inf) + return -jnp.broadcast_to(jnp.log(self.high - self.low), shape) def cdf(self, value: ArrayLike) -> ArrayLike: cdf = (value - self.low) / (self.high - self.low) diff --git a/test/test_distributions.py b/test/test_distributions.py index 0505e9411..77c79155f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4769,103 +4769,13 @@ def log_prob_fn(params): def test_uniform_log_prob_outside_support(): - d = dist.Uniform(0, 1) - assert_allclose(d.log_prob(-0.5), -jnp.inf) - assert_allclose(d.log_prob(1.5), -jnp.inf) - - -@pytest.mark.parametrize( - "low, high", [(0.0, 1.0), (-2.0, 3.0), (1.0, 5.0), (-5.0, -1.0)] -) -def test_uniform_log_prob_boundaries(low, high): - """Test that boundary values are handled correctly.""" - d = dist.Uniform(low, high) - expected_log_prob = -jnp.log(high - low) - - # Value at lower bound (included): should have finite log prob - assert_allclose(d.log_prob(low), expected_log_prob) - - # Value just above lower bound: should have finite log prob - assert_allclose(d.log_prob(low + 1e-10), expected_log_prob) + from numpyro.distributions.distribution import enable_validation - # Value at upper bound (excluded): should be -inf - assert_allclose(d.log_prob(high), -jnp.inf) + enable_validation() - # Value just below upper bound: should have finite log prob - assert_allclose(d.log_prob(high - 1e-10), expected_log_prob) - - # Value inside support: should have finite log prob - mid = (low + high) / 2.0 - assert_allclose(d.log_prob(mid), expected_log_prob) - - # Value below lower bound: should be -inf - assert_allclose(d.log_prob(low - 1.0), -jnp.inf) - - # Value above upper bound: should be -inf - assert_allclose(d.log_prob(high + 1.0), -jnp.inf) - - -@pytest.mark.parametrize("batch_shape", [(), (3,), (2, 3), (4, 2, 3)]) -def test_uniform_log_prob_broadcasting(batch_shape): - """Test broadcasting with different batch shapes.""" - if batch_shape == (): - low = 0.0 - high = 1.0 - else: - low = jnp.linspace(0.0, 1.0, np.prod(batch_shape)).reshape(batch_shape) - high = jnp.linspace(1.0, 2.0, np.prod(batch_shape)).reshape(batch_shape) - - d = dist.Uniform(low, high) - - # Test with scalar value - value = 0.5 - log_probs = d.log_prob(value) - assert log_probs.shape == batch_shape - - # Test with batched value - if batch_shape: - value_batched = jnp.linspace(-0.5, 1.5, np.prod(batch_shape)).reshape( - batch_shape - ) - log_probs_batched = d.log_prob(value_batched) - assert log_probs_batched.shape == batch_shape - - # Check that values outside support return -inf - # Values < low should be -inf - below_low = low - 0.1 - assert_allclose(d.log_prob(below_low), -jnp.inf) - - # Values >= high should be -inf - at_high = high - assert_allclose(d.log_prob(at_high), -jnp.inf) - - -@pytest.mark.parametrize("value_shape", [(), (5,), (3, 4), (2, 3, 4)]) -def test_uniform_log_prob_value_broadcasting(value_shape): - """Test broadcasting when value has different shapes.""" - d = dist.Uniform(0.0, 1.0) - - if value_shape == (): - values = 0.5 - else: - values = jnp.linspace(-0.5, 1.5, np.prod(value_shape)).reshape(value_shape) - - log_probs = d.log_prob(values) - assert log_probs.shape == value_shape - - # Check that values inside support have finite log prob - inside_values = jnp.linspace(0.1, 0.9, np.prod(value_shape) if value_shape else 1) - if value_shape: - inside_values = inside_values.reshape(value_shape) - log_probs_inside = d.log_prob(inside_values) - assert jnp.all(jnp.isfinite(log_probs_inside)) - - # Check that values outside support have -inf - outside_values = jnp.linspace(-1.0, 2.0, np.prod(value_shape) if value_shape else 1) - if value_shape: - outside_values = outside_values.reshape(value_shape) - log_probs_outside = d.log_prob(outside_values) - # Values in [0, 1) should be finite, others should be -inf - mask_inside = (outside_values >= 0.0) & (outside_values < 1.0) - assert jnp.all(jnp.where(mask_inside, jnp.isfinite(log_probs_outside), True)) - assert jnp.all(jnp.where(~mask_inside, log_probs_outside == -jnp.inf, True)) + d = dist.Uniform(0, 1) + with pytest.warns( + UserWarning, + match="Out-of-support values provided to log prob method. The value argument should be within the support.", + ): + d.log_prob(-0.5) From ce5d495d108bd9af1d8bde4f841691325f6da7e2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Dec 2025 15:00:06 +0100 Subject: [PATCH 13/14] try to fix validation tests --- test/test_distributions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 77c79155f..87cd2ace1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3694,7 +3694,6 @@ def test_vmap_validate_args(): def test_explicit_validate_args(): # Check validation passes for valid parameters. d = dist.Normal(0, 1, validate_args=False) - d.validate_args() # Check validation fails for invalid parameters. d = dist.Normal(0, -1) From de1fdefcab9e1c053af9819431dac0d2dbf4693c Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Dec 2025 15:35:33 +0100 Subject: [PATCH 14/14] fix 1/n XD --- test/test_distributions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 87cd2ace1..4c0daea3d 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3694,9 +3694,10 @@ def test_vmap_validate_args(): def test_explicit_validate_args(): # Check validation passes for valid parameters. d = dist.Normal(0, 1, validate_args=False) + d.validate_args() # Check validation fails for invalid parameters. - d = dist.Normal(0, -1) + d = dist.Normal(0, -1, validate_args=False) with pytest.raises(ValueError, match="got invalid scale parameter"): d.validate_args()