Skip to content

Commit 23a69ea

Browse files
committed
check nan in integrate
1 parent 08853fb commit 23a69ea

File tree

2 files changed

+59
-12
lines changed

2 files changed

+59
-12
lines changed

bayesflow/utils/integrate.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"]
2323

2424

25+
def _check_all_nans(state: StateDict):
26+
all_nans_flags = []
27+
for v in state.values():
28+
all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
29+
return keras.ops.all(keras.ops.stack(all_nans_flags))
30+
31+
2532
def euler_step(
2633
fn: Callable,
2734
state: StateDict,
@@ -218,7 +225,8 @@ def integrate_fixed(
218225
def body(_loop_var, _loop_state):
219226
_state, _time = _loop_state
220227
_state, _time, _, _ = step_fn(_state, _time, step_size)
221-
228+
if _check_all_nans(_state):
229+
raise RuntimeError(f"All values are NaNs in state during integration at {_time}.")
222230
return _state, _time
223231

224232
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
@@ -251,6 +259,9 @@ def body(_loop_var, _loop_state):
251259
_time = steps[_loop_var]
252260
step_size = steps[_loop_var + 1] - steps[_loop_var]
253261
_loop_state, _, _, _ = step_fn(_loop_state, _time, step_size)
262+
263+
if _check_all_nans(_loop_state):
264+
raise RuntimeError(f"All values are NaNs in state during integration at {_time}.")
254265
return _loop_state
255266

256267
state = keras.ops.fori_loop(0, len(steps) - 1, body, state)
@@ -296,10 +307,12 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
296307
step_lt_min = keras.ops.less(_step, float(min_steps))
297308
step_lt_max = keras.ops.less(_step, float(max_steps))
298309

299-
return keras.ops.logical_or(
300-
step_lt_min,
301-
keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max),
310+
all_nans = _check_all_nans(_state)
311+
312+
end_now = keras.ops.logical_or(
313+
step_lt_min, keras.ops.logical_and(keras.ops.all(time_remaining > 0), step_lt_max)
302314
)
315+
return keras.ops.logical_and(~all_nans, end_now)
303316

304317
def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
305318
# Time remaining from current point
@@ -347,6 +360,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
347360
[state, start_time, initial_step, step0, k1_0, count_not_accepted],
348361
)
349362

363+
if _check_all_nans(state):
364+
raise RuntimeError(f"All values are NaNs in state during integration at {time}.")
365+
350366
# Final step to hit stop_time exactly
351367
time_diff = stop_time - time
352368
time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
@@ -974,6 +990,9 @@ def body_fixed(_i, _loop_state):
974990
step_size_factor=step_size_factor,
975991
corrector_noise_history=corrector_noise_history,
976992
)
993+
all_nans = _check_all_nans(new_state)
994+
if all_nans:
995+
raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.")
977996
return new_state, new_time, initial_step
978997

979998
# Execute the fixed loop
@@ -1004,9 +1023,10 @@ def integrate_stochastic_adaptive(
10041023
initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state)
10051024

10061025
def cond(i, current_state, current_time, current_step, counter, last_state):
1007-
# time remaining after the next step
10081026
time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step))
1009-
return keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps))
1027+
all_nans = _check_all_nans(current_state)
1028+
end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps))
1029+
return keras.ops.logical_and(~all_nans, end_now)
10101030

10111031
def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state):
10121032
# Step Size Control
@@ -1048,9 +1068,36 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l
10481068
return _i + 1, new_state, new_time, new_step, _counter, _new_current_state
10491069

10501070
# Execute the adaptive loop
1051-
_, final_state, final_time, _, final_counter, _ = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
1071+
_, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop(
1072+
cond, body_adaptive, initial_loop_state
1073+
)
1074+
1075+
if _check_all_nans(final_state):
1076+
raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
1077+
1078+
# Final step to hit stop_time exactly
1079+
time_diff = stop_time - final_time
1080+
time_remaining = keras.ops.sign(stop_time - start_time) * time_diff
1081+
if keras.ops.all(time_remaining > 0):
1082+
noise_final = {k: z_history[k][-1] for k in final_state.keys()}
1083+
noise_extra_final = None
1084+
if len(z_extra_history) > 0:
1085+
noise_extra_final = {k: z_extra_history[k][-1] for k in final_state.keys()}
1086+
1087+
final_state, _, _ = step_fn(
1088+
state=final_state,
1089+
time=final_time,
1090+
step_size=time_diff,
1091+
last_state=final_k1,
1092+
min_step_size=min_step_size,
1093+
max_step_size=time_remaining,
1094+
noise=noise_final,
1095+
noise_aux=noise_extra_final,
1096+
use_adaptive_step_size=False,
1097+
)
1098+
final_counter = final_counter + 1
10521099

1053-
logging.debug(f"Finished integration after {final_counter} steps at {final_time}.")
1100+
logging.debug(f"Finished integration after {final_counter}.")
10541101
return final_state
10551102

10561103

@@ -1094,13 +1141,12 @@ def integrate_langevin(
10941141

10951142
def body(_i, loop_state):
10961143
current_state, current_time = loop_state
1097-
t = current_time
10981144

10991145
# score at current time
1100-
score = score_fn(t, **filter_kwargs(current_state, score_fn))
1146+
score = score_fn(current_time, **filter_kwargs(current_state, score_fn))
11011147

11021148
# noise schedule
1103-
log_snr_t = noise_schedule.get_log_snr(t=t, training=False)
1149+
log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False)
11041150
_, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
11051151

11061152
new_state: StateDict = {}
@@ -1125,6 +1171,8 @@ def body(_i, loop_state):
11251171
step_size_factor=step_size_factor,
11261172
corrector_noise_history=corrector_noise_history,
11271173
)
1174+
if _check_all_nans(new_state):
1175+
raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.")
11281176

11291177
return new_state, new_time
11301178

tests/test_utils/test_integrate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def diffusion_fn(t, x):
202202
# Start at time T with value x_T
203203
initial_state = {"x": keras.ops.ones((N,)) * x_T}
204204
steps = 200 if not use_adapt else "adaptive"
205-
206205
# Expected mean and variance at t=0 after integrating backward from t=T
207206
# For backward integration, the effective drift coefficient changes sign
208207
exp_mean = x_T * np.exp(-a * T)

0 commit comments

Comments
 (0)