Commit e585708: fix density computation
Parent: 5a1a3fa

3 files changed: +49, -36 lines

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 12 additions & 0 deletions

@@ -413,6 +413,12 @@ def _forward(
             raise ValueError("Stochastic methods are not supported for forward integration.")

         if density:
+            if integrate_kwargs["steps"] == "adaptive":
+                logging.warning(
+                    "Using adaptive integration for density estimation can lead to "
+                    "problems with autodiff. Switching to 200 fixed steps instead."
+                )
+                integrate_kwargs["steps"] = 200

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)

@@ -461,6 +467,12 @@ def _inverse(
         if density:
             if integrate_kwargs["method"] in STOCHASTIC_METHODS:
                 raise ValueError("Stochastic methods are not supported for density computation.")
+            if integrate_kwargs["steps"] == "adaptive":
+                logging.warning(
+                    "Using adaptive integration for density estimation can lead to "
+                    "problems with autodiff. Switching to 200 fixed steps instead."
+                )
+                integrate_kwargs["steps"] = 200

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
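
The added guard replaces the "adaptive" step setting with a fixed step count before the trace-based density integration, since differentiating through an adaptive-step solve is what causes the autodiff problems mentioned in the warning. A minimal standalone sketch of the same idea, assuming a hypothetical helper name and a copy-instead-of-mutate policy (neither is part of this commit):

import logging

def resolve_density_steps(integrate_kwargs: dict, fallback_steps: int = 200) -> dict:
    """Return integration kwargs that are safe for trace-based density estimation."""
    integrate_kwargs = dict(integrate_kwargs)  # work on a copy, leave the caller's dict alone
    if integrate_kwargs.get("steps") == "adaptive":
        logging.warning(
            "Using adaptive integration for density estimation can lead to "
            "problems with autodiff. Switching to %d fixed steps instead.",
            fallback_steps,
        )
        integrate_kwargs["steps"] = fallback_steps
    return integrate_kwargs

# Example: resolve_density_steps({"method": "rk45", "steps": "adaptive"})
# -> {"method": "rk45", "steps": 200}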

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 19 additions & 4 deletions

@@ -1,3 +1,4 @@
+import logging
 from collections.abc import Sequence

 import keras

@@ -236,14 +237,21 @@ def f(x):
     def _forward(
         self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
+        integrate_kwargs = self.integrate_kwargs | kwargs
         if density:
+            if integrate_kwargs["steps"] == "adaptive":
+                logging.warning(
+                    "Using adaptive integration for density estimation can lead to "
+                    "problems with autodiff. Switching to 200 fixed steps instead."
+                )
+                integrate_kwargs["steps"] = 200

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
-            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

             z = state["xz"]
             log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

@@ -254,7 +262,7 @@ def deltas(time, xz):
                 return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

             state = {"xz": x}
-            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

             z = state["xz"]

@@ -263,14 +271,21 @@ def deltas(time, xz):
     def _inverse(
         self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
+        integrate_kwargs = self.integrate_kwargs | kwargs
         if density:
+            if integrate_kwargs["steps"] == "adaptive":
+                logging.warning(
+                    "Using adaptive integration for density estimation can lead to "
+                    "problems with autodiff. Switching to 200 fixed steps instead."
+                )
+                integrate_kwargs["steps"] = 200

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
-            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

             x = state["xz"]
             log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

@@ -281,7 +296,7 @@ def deltas(time, xz):
                 return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

             state = {"xz": z}
-            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

             x = state["xz"]
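
The refactor hoists the dict union `self.integrate_kwargs | kwargs` to the top of `_forward` and `_inverse`, so the per-call overrides and the adaptive-steps guard both act on the single merged dict that is later unpacked into `integrate(...)`. A toy illustration of the merge semantics (the default values shown are assumptions, not FlowMatching's actual defaults):

defaults = {"method": "rk45", "steps": "adaptive"}  # stand-in for self.integrate_kwargs
overrides = {"steps": 100}                          # stand-in for the **kwargs of one call

integrate_kwargs = defaults | overrides             # right-hand operand wins on conflicts
assert integrate_kwargs == {"method": "rk45", "steps": 100}
assert defaults["steps"] == "adaptive"              # the configured defaults are left untouched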

bayesflow/utils/integrate.py

Lines changed: 18 additions & 32 deletions

@@ -22,13 +22,6 @@
 STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"]


-def _check_all_nans(state: StateDict):
-    all_nans_flags = []
-    for v in state.values():
-        all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
-    return keras.ops.all(keras.ops.stack(all_nans_flags))
-
-
 def euler_step(
     fn: Callable,
     state: StateDict,

@@ -243,22 +236,17 @@ def integrate_fixed(
     step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)
     step_size = (stop_time - start_time) / steps

-    def cond(_loop_var, _loop_state, _loop_time):
-        all_nans = _check_all_nans(_loop_state)
-        end_now = keras.ops.less(_loop_var, steps)
-        return keras.ops.logical_and(~all_nans, end_now)
-
-    def body(_loop_var, _loop_state, _loop_time):
-        _loop_state, _loop_time, _, _ = step_fn(_loop_state, _loop_time, step_size)
-        return _loop_var + 1, _loop_state, _loop_time
+    def body(_loop_var, _loop_state):
+        _state, _time = _loop_state
+        _state, _time, _, _ = step_fn(_state, _time, step_size)
+        return _state, _time

-    _, state, _ = keras.ops.while_loop(
-        cond,
+    state, _ = keras.ops.fori_loop(
+        0,
+        steps,
         body,
-        [0, state, start_time],
+        (state, start_time),
     )
-    if _check_all_nans(state):
-        raise RuntimeError("All values are NaNs in state during integration.")
     return state

@@ -283,25 +271,18 @@ def integrate_scheduled(
     step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=False)

-    def cond(_loop_var, _loop_state):
-        all_nans = _check_all_nans(_loop_state)
-        end_now = keras.ops.less(_loop_var, len(steps) - 1)
-        return keras.ops.logical_and(~all_nans, end_now)
-
     def body(_loop_var, _loop_state):
         _time = steps[_loop_var]
         step_size = steps[_loop_var + 1] - steps[_loop_var]
         _loop_state, _, _, _ = step_fn(_loop_state, _time, step_size)
-        return _loop_var + 1, _loop_state
+        return _loop_state

-    _, state = keras.ops.while_loop(
-        cond,
+    state = keras.ops.fori_loop(
+        0,
+        keras.ops.shape(steps)[0] - 1,
         body,
-        [0, state],
+        state,
     )
-
-    if _check_all_nans(state):
-        raise RuntimeError("All values are NaNs in state during integration.")
     return state

@@ -501,6 +482,11 @@ def integrate(


 ############ SDE Solvers #############
+def _check_all_nans(state: StateDict):
+    all_nans_flags = []
+    for v in state.values():
+        all_nans_flags.append(keras.ops.all(keras.ops.isnan(v)))
+    return keras.ops.all(keras.ops.stack(all_nans_flags))


 def stochastic_adaptive_step_size_controller(
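
With the all-NaN early exit removed, `integrate_fixed` and `integrate_scheduled` now run a fixed number of iterations through `keras.ops.fori_loop`, whose body receives the loop index plus a carry and returns the next carry; `_check_all_nans` survives only for the SDE solvers further down the file. A toy example of the loop primitive in isolation (not BayesFlow code; the update inside the body is a placeholder):

import keras

def body(i, carry):
    state, time = carry
    state = state + 0.1 * time  # placeholder for one integrator step
    time = time + 0.1
    return state, time

state, time = keras.ops.fori_loop(
    0,   # lower bound (inclusive)
    10,  # upper bound (exclusive), i.e. the number of steps
    body,
    (keras.ops.zeros(()), keras.ops.convert_to_tensor(0.0)),
)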
