Skip to content

Commit 82ad1da

Browse files
author
jax authors
committed
Merge pull request #21193 from sh0416:main
PiperOrigin-RevId: 637938498
2 parents 83871d3 + 818e7d9 commit 82ad1da

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

jax/_src/scipy/special.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def rel_entr(
672672
safe_q = jnp.where(both_gt_zero_mask, q, 1)
673673
log_val = lax.sub(_xlogx(safe_p), xlogy(safe_p, safe_q))
674674
result = jnp.where(
675-
both_gt_zero_mask, log_val, jnp.where(one_zero_mask, q, jnp.inf)
675+
both_gt_zero_mask, log_val, jnp.where(one_zero_mask, zero, jnp.inf)
676676
)
677677
return result
678678

tests/lax_scipy_special_functions_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,15 @@ def testNdtriExtremeValues(self):
228228
self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol)
229229
self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol)
230230

231+
def testRelEntrExtremeValues(self):
232+
# Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
233+
dtype = jax.numpy.zeros(0).dtype # default float dtype.
234+
args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype),
235+
np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)]
236+
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
237+
self._CheckAgainstNumpy(osp_special.rel_entr, lsp_special.rel_entr, args_maker, rtol=rtol)
238+
self._CompileAndCheck(lsp_special.rel_entr, args_maker, rtol=rtol)
239+
231240

232241
if __name__ == "__main__":
233242
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)