Skip to content

Commit 1fe2c60

Browse files
committed
improved defaults
1 parent dd021bb commit 1fe2c60

File tree

2 files changed

+53
-44
lines changed

2 files changed

+53
-44
lines changed

bayesflow/utils/integrate.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
DETERMINISTIC_METHODS = ["euler", "rk45", "tsit5"]
22-
STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "langevin", "fast_adaptive"]
22+
STOCHASTIC_METHODS = ["euler_maruyama", "sea", "shark", "two_step_adaptive", "langevin"]
2323

2424

2525
def euler_step(
@@ -509,7 +509,6 @@ def euler_maruyama_step(
509509
use_adaptive_step_size: bool = False,
510510
min_step_size: float = -float("inf"),
511511
max_step_size: float = float("inf"),
512-
adaptive_factor: float = 0.01,
513512
**kwargs,
514513
) -> Union[Tuple[StateDict, ArrayLike, ArrayLike], Tuple[StateDict, ArrayLike, ArrayLike, StateDict]]:
515514
"""
@@ -525,7 +524,6 @@ def euler_maruyama_step(
525524
use_adaptive_step_size: Whether to use adaptive step sizing.
526525
min_step_size: Minimum allowed step size.
527526
max_step_size: Maximum allowed step size.
528-
adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1).
529527
530528
Returns:
531529
new_state: Updated state after one Euler-Maruyama step.
@@ -541,7 +539,7 @@ def euler_maruyama_step(
541539
new_step_size = stochastic_adaptive_step_size_controller(
542540
state=state,
543541
drift=drift,
544-
adaptive_factor=adaptive_factor,
542+
adaptive_factor=max_step_size,
545543
min_step_size=min_step_size,
546544
max_step_size=max_step_size,
547545
)
@@ -561,7 +559,7 @@ def euler_maruyama_step(
561559
return new_state, time + new_step_size, new_step_size
562560

563561

564-
def fast_adaptive_step(
562+
def two_step_adaptive_step(
565563
drift_fn: Callable,
566564
diffusion_fn: Callable,
567565
state: StateDict,
@@ -572,8 +570,8 @@ def fast_adaptive_step(
572570
use_adaptive_step_size: bool = True,
573571
min_step_size: float = -float("inf"),
574572
max_step_size: float = float("inf"),
575-
e_abs: float = 0.01,
576-
e_rel: float = 0.01,
573+
e_rel: float = 0.1,
574+
e_abs: float = None,
577575
r: float = 0.9,
578576
adapt_safety: float = 0.9,
579577
**kwargs,
@@ -608,8 +606,8 @@ def fast_adaptive_step(
608606
use_adaptive_step_size: Whether to adapt step size.
609607
min_step_size: Minimum allowed step size.
610608
max_step_size: Maximum allowed step size.
611-
e_abs: Absolute error tolerance.
612609
e_rel: Relative error tolerance.
610+
e_abs: Absolute error tolerance. Default assumes standardized targets.
613611
r: Order of the method for step size adaptation.
614612
adapt_safety: Safety factor for step size adaptation.
615613
**kwargs: Additional arguments passed to drift_fn and diffusion_fn.
@@ -650,6 +648,8 @@ def fast_adaptive_step(
650648

651649
# Error estimation
652650
if use_adaptive_step_size:
651+
if e_abs is None:
652+
e_abs = 0.02576 # 1% of 99% CI of standardized unit variance
653653
# Check if we're at minimum step size - if so, force acceptance
654654
at_min_step = keras.ops.less_equal(step_size, min_step_size)
655655

@@ -709,13 +709,33 @@ def fast_adaptive_step(
709709
return state_heun, time_mid, step_size
710710

711711

712+
def compute_levy_area(
    state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike
) -> StateDict:
    """Build the space-time Levy-area increment H_k for each diffused state variable.

    For every key with a diffusion entry, H_k approximates the time integral of the
    Brownian path over one step of size h. With the primary Brownian increment
    W_k = sqrt(|h|) * noise_k and an independent auxiliary standard normal
    Z_k = noise_aux_k, the increment is

        H_k = 0.5 * |h| * W_k + 0.5 * |h|^{3/2} / sqrt(3) * Z_k

    matching the inline computation previously used by ``shark_step``. Keys without a
    diffusion entry receive a zero tensor of the state's shape.

    Args:
        state: Mapping of variable names to current state tensors.
        diffusion: Mapping of variable names to diffusion tensors; only keys present
            here get a nonzero Levy area.
        noise: Mapping of variable names to standard-normal tensors (primary noise).
        noise_aux: Mapping of variable names to independent standard-normal tensors.
        step_size: Time increment h; only its magnitude is used.

    Returns:
        Mapping of variable names to Levy-area tensors H_k.
    """
    step_size_abs = keras.ops.abs(step_size)
    sqrt_step_size = keras.ops.sqrt(step_size_abs)
    inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs))

    # Build Lévy area H_k from W_k = sqrt(h) * noise[k] and Z_k = noise_aux[k]
    H = {}
    for k in state.keys():
        if k in diffusion:
            # BUGFIX: term1 must carry the sqrt_step_size factor from W_k
            # (0.5 * h * W_k = 0.5 * h^{3/2} * noise[k]); the refactor had dropped it,
            # changing the scaling relative to the original shark_step inline code.
            term1 = 0.5 * step_size_abs * sqrt_step_size * noise[k]
            term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k]
            H[k] = term1 + term2
        else:
            H[k] = keras.ops.zeros_like(state[k])
    return H
729+
730+
712731
def sea_step(
713732
drift_fn: Callable,
714733
diffusion_fn: Callable,
715734
state: StateDict,
716735
time: ArrayLike,
717736
step_size: ArrayLike,
718-
noise: StateDict,
737+
noise: StateDict, # standard normals
738+
noise_aux: StateDict, # standard normals
719739
**kwargs,
720740
) -> Tuple[StateDict, ArrayLike, ArrayLike]:
721741
"""
@@ -725,7 +745,7 @@ def sea_step(
725745
which improves the local error and the global error constant for additive noise.
726746
727747
The scheme is
728-
X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n
748+
X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n)) * h + g(t_n) * ΔW_n
729749
730750
[1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023)
731751
Args:
@@ -735,20 +755,23 @@ def sea_step(
735755
time: Current time scalar tensor.
736756
step_size: Time increment dt.
737757
noise: Mapping of variable names to dW noise tensors.
758+
noise_aux: Mapping of variable names to auxiliary noise.
738759
739760
Returns:
740761
new_state: Updated state after one SEA step.
741762
new_time: time + dt.
742763
"""
743-
# Compute diffusion (assumed additive or weakly state dependent)
764+
# Compute diffusion
744765
diffusion = diffusion_fn(time, **filter_kwargs(state, diffusion_fn))
745766
sqrt_step_size = keras.ops.sqrt(keras.ops.abs(step_size))
746767

747-
# Build shifted state: X_shift = X + 0.5 * g * ΔW
768+
la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size)
769+
770+
# Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH)
748771
shifted_state = {}
749772
for key, x in state.items():
750773
if key in diffusion:
751-
shifted_state[key] = x + 0.5 * diffusion[key] * sqrt_step_size * noise[key]
774+
shifted_state[key] = x + diffusion[key] * (0.5 * sqrt_step_size * noise[key] + la[key])
752775
else:
753776
shifted_state[key] = x
754777

@@ -810,33 +833,18 @@ def shark_step(
810833
"""
811834
h = step_size
812835
t = time
813-
814-
# Magnitude of the time step for stochastic scaling
815836
h_mag = keras.ops.abs(h)
816-
# h_sign = keras.ops.sign(h)
817837
sqrt_h_mag = keras.ops.sqrt(h_mag)
818-
inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(h_mag))
819838

820-
# g(y_k)
821-
g0 = diffusion_fn(t, **filter_kwargs(state, diffusion_fn))
839+
diffusion = diffusion_fn(t, **filter_kwargs(state, diffusion_fn))
822840

823-
# Build H_k from w_k and Z_k
824-
H = {}
825-
for k in state.keys():
826-
if k in g0:
827-
w_k = sqrt_h_mag * noise[k]
828-
z_k = noise_aux[k] # standard normal
829-
term1 = 0.5 * h_mag * w_k
830-
term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k
831-
H[k] = term1 + term2
832-
else:
833-
H[k] = keras.ops.zeros_like(state[k])
841+
la = compute_levy_area(state=state, diffusion=diffusion, noise=noise, noise_aux=noise_aux, step_size=step_size)
834842

835843
# === 1) shifted initial state ===
836844
y_tilde_k = {}
837845
for k in state.keys():
838-
if k in g0:
839-
y_tilde_k[k] = state[k] + g0[k] * H[k]
846+
if k in diffusion:
847+
y_tilde_k[k] = state[k] + diffusion[k] * la[k]
840848
else:
841849
y_tilde_k[k] = state[k]
842850

@@ -866,12 +874,12 @@ def shark_step(
866874

867875
# stochastic parts
868876
sto1 = (
869-
g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * H[k])
877+
g_tilde_k[k] * ((2.0 / 5.0) * sqrt_h_mag * noise[k] + (6.0 / 5.0) * la[k])
870878
if k in g_tilde_k
871879
else keras.ops.zeros_like(det)
872880
)
873881
sto2 = (
874-
g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * H[k])
882+
g_tilde_mid[k] * ((3.0 / 5.0) * sqrt_h_mag * noise[k] - (6.0 / 5.0) * la[k])
875883
if k in g_tilde_mid
876884
else keras.ops.zeros_like(det)
877885
)
@@ -1154,7 +1162,7 @@ def integrate_stochastic(
11541162
seed: keras.random.SeedGenerator,
11551163
steps: int | Literal["adaptive"] = 100,
11561164
method: str = "euler_maruyama",
1157-
min_steps: int = 20,
1165+
min_steps: int = 10,
11581166
max_steps: int = 10_000,
11591167
score_fn: Callable = None,
11601168
corrector_steps: int = 0,
@@ -1229,8 +1237,8 @@ def integrate_stochastic(
12291237
step_fn_raw = shark_step
12301238
if is_adaptive:
12311239
raise ValueError("SHARK SDE solver does not support adaptive steps.")
1232-
case "fast_adaptive":
1233-
step_fn_raw = fast_adaptive_step
1240+
case "two_step_adaptive":
1241+
step_fn_raw = two_step_adaptive_step
12341242
case "langevin":
12351243
if is_adaptive:
12361244
raise ValueError("Langevin sampling does not support adaptive steps.")
@@ -1269,7 +1277,7 @@ def integrate_stochastic(
12691277
for key, val in state.items():
12701278
shape = keras.ops.shape(val)
12711279
z_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
1272-
if method == "shark":
1280+
if method in ["sea", "shark"]:
12731281
z_extra_history[key] = keras.random.normal((loop_steps, *shape), dtype=keras.ops.dtype(val), seed=seed)
12741282

12751283
if is_adaptive:

tests/test_utils/test_integrate.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def fn(t, x):
101101
("euler_maruyama", True),
102102
("sea", False),
103103
("shark", False),
104-
("fast_adaptive", False),
105-
("fast_adaptive", True),
104+
("two_step_adaptive", False),
105+
("two_step_adaptive", True),
106106
],
107107
)
108108
def test_forward_additive_ou_weak_means_and_vars(method, use_adapt):
@@ -167,8 +167,8 @@ def diffusion_fn(t, x):
167167
("euler_maruyama", True),
168168
("sea", False),
169169
("shark", False),
170-
("fast_adaptive", False),
171-
("fast_adaptive", True),
170+
("two_step_adaptive", False),
171+
("two_step_adaptive", True),
172172
],
173173
)
174174
def test_backward_additive_ou_weak_means_and_vars(method, use_adapt):
@@ -218,6 +218,7 @@ def diffusion_fn(t, x):
218218
seed=seed,
219219
method=method,
220220
max_steps=1_000,
221+
min_steps=100,
221222
)
222223

223224
x_0 = np.array(out["x"])
@@ -235,8 +236,8 @@ def diffusion_fn(t, x):
235236
("euler_maruyama", True),
236237
("sea", False),
237238
("shark", False),
238-
("fast_adaptive", False),
239-
("fast_adaptive", True),
239+
("two_step_adaptive", False),
240+
("two_step_adaptive", True),
240241
],
241242
)
242243
def test_zero_noise_reduces_to_deterministic(method, use_adapt):

0 commit comments

Comments
 (0)