Skip to content

Commit ac07af2

Browse files
committed
fix jax all nans
1 parent e585708 commit ac07af2

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

bayesflow/utils/integrate.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from bayesflow.types import Tensor
1212
from bayesflow.utils import filter_kwargs
1313
from bayesflow.utils.logging import warning
14+
from keras import backend as K
1415

1516
from . import logging
1617

@@ -22,6 +23,15 @@
2223
STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"]
2324

2425

26+
def _check_all_nans(state: StateDict):
27+
if K.backend() == "jax":
28+
return False # JAX backend does not support checks of the state variables
29+
all_nans_flags = []
30+
for v in state.values():
31+
all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
32+
return keras.ops.all(keras.ops.stack(all_nans_flags))
33+
34+
2535
def euler_step(
2636
fn: Callable,
2737
state: StateDict,
@@ -482,11 +492,6 @@ def integrate(
482492

483493

484494
############ SDE Solvers #############
485-
def _check_all_nans(state: StateDict):
486-
all_nans_flags = []
487-
for v in state.values():
488-
all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
489-
return keras.ops.all(keras.ops.stack(all_nans_flags))
490495

491496

492497
def stochastic_adaptive_step_size_controller(

0 commit comments

Comments
 (0)