Skip to content

Commit ad27606

Browse files
committed
make loop jax compatible
1 parent be78470 commit ad27606

File tree

1 file changed

+64
-45
lines changed

1 file changed

+64
-45
lines changed

bayesflow/utils/integrate.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -224,17 +224,22 @@ def integrate_fixed(
224224
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
225225
step_size = (stop_time - start_time) / steps
226226

227-
time = start_time
228-
229-
def body(_loop_var, _loop_state):
230-
_state, _time = _loop_state
231-
_state, _time, _, _ = step_fn(_state, _time, step_size)
232-
if _check_all_nans(_state):
233-
raise RuntimeError(f"All values are NaNs in state during integration at {_time}.")
234-
return _state, _time
227+
def cond(_loop_var, _loop_state, _loop_time):
228+
all_nans = _check_all_nans(_loop_state)
229+
end_now = keras.ops.less(_loop_var, steps)
230+
return keras.ops.logical_and(~all_nans, end_now)
235231

236-
state, time = keras.ops.fori_loop(0, steps, body, (state, time))
232+
def body(_loop_var, _loop_state, _loop_time):
233+
_loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size)
234+
return _loop_var + 1, _loop_state, _loop_time
237235

236+
_, state, _ = keras.ops.while_loop(
237+
cond,
238+
body,
239+
[0, state, start_time],
240+
)
241+
if _check_all_nans(state):
242+
raise RuntimeError("All values are NaNs in state during integration.")
238243
return state
239244

240245

@@ -259,16 +264,25 @@ def integrate_scheduled(
259264

260265
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
261266

267+
def cond(_loop_var, _loop_state):
268+
all_nans = _check_all_nans(_loop_state)
269+
end_now = keras.ops.less(_loop_var, len(steps) - 1)
270+
return keras.ops.logical_and(~all_nans, end_now)
271+
262272
def body(_loop_var, _loop_state):
263273
_time = steps[_loop_var]
264274
step_size = steps[_loop_var + 1] - steps[_loop_var]
265275
_loop_state, _, _, _ = step_fn(_loop_state, _time, step_size)
276+
return _loop_var + 1, _loop_state
266277

267-
if _check_all_nans(_loop_state):
268-
raise RuntimeError(f"All values are NaNs in state during integration at {_time}.")
269-
return _loop_state
278+
_, state = keras.ops.while_loop(
279+
cond,
280+
body,
281+
[0, state],
282+
)
270283

271-
state = keras.ops.fori_loop(0, len(steps) - 1, body, state)
284+
if _check_all_nans(state):
285+
raise RuntimeError("All values are NaNs in state during integration.")
272286
return state
273287

274288

@@ -635,7 +649,7 @@ def two_step_adaptive_step(
635649
min_step_size=min_step_size,
636650
max_step_size=max_step_size,
637651
noise=noise,
638-
use_adaptive_step_size=True,
652+
use_adaptive_step_size=False,
639653
)
640654

641655
# Compute drift and diffusion at new state, but update from old state
@@ -957,9 +971,12 @@ def integrate_stochastic_fixed(
957971
"""
958972
initial_step = (stop_time - start_time) / float(steps)
959973

960-
def body_fixed(_i, _loop_state):
961-
_current_state, _current_time, _current_step = _loop_state
974+
def cond(_loop_var, _loop_state, _loop_time, _loop_step):
975+
all_nans = _check_all_nans(_loop_state)
976+
end_now = keras.ops.less(_loop_var, steps)
977+
return keras.ops.logical_and(~all_nans, end_now)
962978

979+
def body(_i, _current_state, _current_time, _current_step):
963980
# Determine step size: either the constant size or the remainder to reach stop_time
964981
remaining = keras.ops.abs(stop_time - _current_time)
965982
sign = keras.ops.sign(_current_step)
@@ -994,13 +1011,16 @@ def body_fixed(_i, _loop_state):
9941011
step_size_factor=step_size_factor,
9951012
corrector_noise_history=corrector_noise_history,
9961013
)
997-
all_nans = _check_all_nans(new_state)
998-
if all_nans:
999-
raise RuntimeError(f"All values are NaNs in state during integration at {_current_time}.")
1000-
return new_state, new_time, initial_step
1014+
return _i + 1, new_state, new_time, initial_step
1015+
1016+
_, final_state, final_time, _ = keras.ops.while_loop(
1017+
cond,
1018+
body,
1019+
[0, state, start_time, initial_step],
1020+
)
1021+
if _check_all_nans(final_state):
1022+
raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
10011023

1002-
# Execute the fixed loop
1003-
final_state, final_time, _ = keras.ops.fori_loop(0, steps, body_fixed, (state, start_time, initial_step))
10041024
return final_state
10051025

10061026

@@ -1024,22 +1044,21 @@ def integrate_stochastic_adaptive(
10241044
"""
10251045
Performs adaptive-step SDE integration.
10261046
"""
1027-
initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, 0, state)
1047+
initial_loop_state = (keras.ops.zeros((), dtype="int32"), state, start_time, initial_step, state)
10281048

1029-
def cond(i, current_state, current_time, current_step, counter, last_state):
1049+
def cond(i, current_state, current_time, current_step, last_state):
10301050
time_remaining = keras.ops.sign(stop_time - start_time) * (stop_time - (current_time + current_step))
10311051
all_nans = _check_all_nans(current_state)
10321052
end_now = keras.ops.logical_and(keras.ops.all(time_remaining > 0), keras.ops.less(i, max_steps))
10331053
return keras.ops.logical_and(~all_nans, end_now)
10341054

1035-
def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _last_state):
1055+
def body_adaptive(_i, _current_state, _current_time, _current_step, _last_state):
10361056
# Step Size Control
10371057
remaining = keras.ops.abs(stop_time - _current_time)
10381058
sign = keras.ops.sign(_current_step)
10391059
# Ensure the next step does not overshoot the stop_time
10401060
dt_mag = keras.ops.minimum(keras.ops.abs(_current_step), remaining)
10411061
dt = sign * dt_mag
1042-
_counter += 1
10431062

10441063
_noise_i = {k: z_history[k][_i] for k in _current_state.keys()}
10451064
_noise_extra_i = None
@@ -1069,12 +1088,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l
10691088
corrector_noise_history=corrector_noise_history,
10701089
)
10711090

1072-
return _i + 1, new_state, new_time, new_step, _counter, _new_current_state
1091+
return _i + 1, new_state, new_time, new_step, _new_current_state
10731092

10741093
# Execute the adaptive loop
1075-
_, final_state, final_time, _, final_counter, final_k1 = keras.ops.while_loop(
1076-
cond, body_adaptive, initial_loop_state
1077-
)
1094+
final_counter, final_state, final_time, _, final_k1 = keras.ops.while_loop(cond, body_adaptive, initial_loop_state)
10781095

10791096
if _check_all_nans(final_state):
10801097
raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
@@ -1143,27 +1160,30 @@ def integrate_langevin(
11431160
dt = (stop_time - start_time) / float(steps)
11441161
effective_factor = step_size_factor * 100 / np.sqrt(steps)
11451162

1146-
def body(_i, loop_state):
1147-
current_state, current_time = loop_state
1163+
def cond(_loop_var, _loop_state, _loop_time):
1164+
all_nans = _check_all_nans(_loop_state)
1165+
end_now = keras.ops.less(_loop_var, steps)
1166+
return keras.ops.logical_and(~all_nans, end_now)
11481167

1168+
def body(_i, _loop_state, _loop_time):
11491169
# score at current time
1150-
score = score_fn(current_time, **filter_kwargs(current_state, score_fn))
1170+
score = score_fn(_loop_time, **filter_kwargs(_loop_state, score_fn))
11511171

11521172
# noise schedule
1153-
log_snr_t = noise_schedule.get_log_snr(t=current_time, training=False)
1173+
log_snr_t = noise_schedule.get_log_snr(t=_loop_time, training=False)
11541174
_, sigma_t = noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
11551175

11561176
new_state: StateDict = {}
1157-
for k in current_state.keys():
1177+
for k in _loop_state.keys():
11581178
s_k = score.get(k, None)
11591179
if s_k is None:
1160-
new_state[k] = current_state[k]
1180+
new_state[k] = _loop_state[k]
11611181
continue
11621182

11631183
e = effective_factor * sigma_t**2
1164-
new_state[k] = current_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i]
1184+
new_state[k] = _loop_state[k] + e * s_k + keras.ops.sqrt(2.0 * e) * z_history[k][_i]
11651185

1166-
new_time = current_time + dt
1186+
new_time = _loop_time + dt
11671187

11681188
new_state = _apply_corrector(
11691189
new_state=new_state,
@@ -1175,17 +1195,16 @@ def body(_i, loop_state):
11751195
step_size_factor=step_size_factor,
11761196
corrector_noise_history=corrector_noise_history,
11771197
)
1178-
if _check_all_nans(new_state):
1179-
raise RuntimeError(f"All values are NaNs in state during integration at {current_time}.")
11801198

1181-
return new_state, new_time
1199+
return _i + 1, new_state, new_time
11821200

1183-
final_state, _ = keras.ops.fori_loop(
1184-
0,
1185-
steps,
1201+
_, final_state, final_time = keras.ops.while_loop(
1202+
cond,
11861203
body,
1187-
(state, start_time),
1204+
(0, state, start_time),
11881205
)
1206+
if _check_all_nans(final_state):
1207+
raise RuntimeError(f"All values are NaNs in state during integration at {final_time}.")
11891208
return final_state
11901209

11911210

0 commit comments

Comments
 (0)