Skip to content

Commit 4c9d44b

Browse files
committed
fix jax all nans
1 parent ac07af2 commit 4c9d44b

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

bayesflow/utils/integrate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from bayesflow.types import Tensor
1212
from bayesflow.utils import filter_kwargs
1313
from bayesflow.utils.logging import warning
14-
from keras import backend as K
1514

1615
from . import logging
1716

@@ -24,8 +23,6 @@
2423

2524

2625
def _check_all_nans(state: StateDict):
27-
if K.backend() == "jax":
28-
return False # JAX backend does not support checks of the state variables
2926
all_nans_flags = []
3027
for v in state.values():
3128
all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
@@ -376,7 +373,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
376373

377374
# Step counter: increment only on accepted steps
378375
updated_step = _step + keras.ops.where(accepted, 1.0, 0.0)
379-
_count_not_accepted = _count_not_accepted + 1 if not accepted else _count_not_accepted
376+
_count_not_accepted = _count_not_accepted + keras.ops.where(accepted, 1.0, 0.0)
380377

381378
# For the next iteration, always use the new suggested step size
382379
return updated_state, updated_time, new_step_size, updated_step, updated_k1, _count_not_accepted

0 commit comments

Comments
 (0)