diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 686c6ebde..803df79c0 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -216,7 +216,30 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)) + # Use double-where trick to avoid NaN gradients at boundary conditions + # Reference: https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf + is_boundary = (value == 0.0) | (value == 1.0) + + # Mask boundary values (0 or 1) to safe value (0.5) for gradient computation + safe_value = jnp.where(is_boundary, 0.5, value) + safe_complement = jnp.where(is_boundary, 0.5, 1.0 - value) + + # Compute log_prob with safe values (gradients flow through this path) + safe_dirichlet_value = jnp.stack([safe_value, safe_complement], axis=-1) + safe_log_prob = self._dirichlet.log_prob(safe_dirichlet_value) + + # At boundaries, compute correct forward value using xlogy (handles 0*log(0)=0) + # Use stop_gradient so gradients come only from safe_log_prob + correct_value = ( + xlogy(self.concentration1 - 1.0, value) + + xlogy(self.concentration0 - 1.0, 1.0 - value) + - betaln(self.concentration1, self.concentration0) + ) + + # Apply correction at boundaries, return safe value elsewhere + return jnp.where( + is_boundary, jax.lax.stop_gradient(correct_value), safe_log_prob + ) @property def mean(self) -> ArrayLike: diff --git a/test/test_distributions.py b/test/test_distributions.py index 7edc19226..ace1839b0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample( censored_dist.log_prob(value) else: censored_dist.log_prob(value) # Should not raise + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_logprob_edge_cases(concentration1, concentration0, value): + """Test Beta distribution with concentration=1 gives finite log probability at boundary.""" + beta_dist = dist.Beta(concentration1, concentration0) + log_prob = beta_dist.log_prob(value) + + assert not jnp.isnan(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN" + ) + assert jnp.isfinite(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite" + ) + + +def test_beta_logprob_edge_case_consistency_small_values(): + """Test that edge case values are consistent with small deviation values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + # At boundary + log_prob_at_zero = beta_dist.log_prob(0.0) + log_prob_at_one = beta_dist2.log_prob(1.0) + + # Very close to boundary + small_value = 1e-10 + log_prob_small = beta_dist.log_prob(small_value) + log_prob_close_to_one = beta_dist2.log_prob(1.0 - small_value) + + # Edge case values should be close to small deviation values + assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5 + assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5 + + +def test_beta_logprob_edge_case_non_boundary_values(): + """Test that Beta with concentration=1 still works for non-boundary values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + assert jnp.isfinite(beta_dist.log_prob(0.5)) + assert jnp.isfinite(beta_dist2.log_prob(0.5)) + + +def test_beta_logprob_boundary_non_edge_cases(): + """Test that non-edge cases (concentration > 1) still give -inf at boundaries.""" + beta_dist3 = dist.Beta(2.0, 8.0) + beta_dist4 = dist.Beta(8.0, 2.0) + + assert jnp.isneginf(beta_dist3.log_prob(0.0)) + assert jnp.isneginf(beta_dist4.log_prob(1.0)) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value,grad_param,grad_value", + argvalues=[ + (1.0, 8.0, 0.0, "value", 0.0), + (8.0, 1.0, 1.0, "value", 1.0), + (1.0, 8.0, 0.0, "concentration1", 1.0), + (1.0, 8.0, 0.0, "concentration0", 8.0), + (8.0, 1.0, 1.0, "concentration1", 8.0), + (8.0, 1.0, 1.0, "concentration0", 1.0), + ], + ids=[ + "Beta(1,8) at x=0", + "Beta(8,1) at x=1", + "Beta(1,8) at concentration1=1", + "Beta(1,8) at concentration0=8", + "Beta(8,1) at concentration1=8", + "Beta(8,1) at concentration0=1", + ], +) +def test_beta_gradient_edge_cases_single_param( + concentration1, concentration0, value, grad_param, grad_value +): + """Test that gradients w.r.t. individual parameters are finite at edge cases.""" + if grad_param == "value": + + def log_prob_fn(x): + return dist.Beta(concentration1, concentration0).log_prob(x) + + grad = jax.grad(log_prob_fn)(value) + elif grad_param == "concentration1": + + def log_prob_fn(c1): + return dist.Beta(c1, concentration0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + else: # concentration0 + + def log_prob_fn(c0): + return dist.Beta(concentration1, c0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + + assert jnp.isfinite(grad), ( + f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) " + f"at x={value} should be finite" + ) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value): + """Test that all gradients are finite when computed simultaneously at edge cases.""" + + def log_prob_fn(params): + c1, c0, v = params + return dist.Beta(c1, c0).log_prob(v) + + grads = jax.grad(log_prob_fn)(jnp.array([concentration1, concentration0, value])) + assert jnp.all(jnp.isfinite(grads)), ( + f"All gradients for Beta({concentration1},{concentration0}) at x={value} " + f"should be finite" + )