Skip to content

Commit 08853fb

Browse files
committed
improved initial step size
1 parent a7adea2 commit 08853fb

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

bayesflow/utils/integrate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,7 @@ def two_step_adaptive_step(
640640
if e_abs is None:
641641
e_abs = 0.02576 # 1% of 99% CI of standardized unit variance
642642
# Check if we're at minimum step size - if so, force acceptance
643-
at_min_step = keras.ops.less_equal(step_size, min_step_size)
643+
at_min_step = keras.ops.less_equal(keras.ops.abs(step_size), min_step_size)
644644

645645
# Compute error tolerance for each component
646646
e_abs_tensor = keras.ops.cast(e_abs, dtype=keras.ops.dtype(list(state.values())[0]))
@@ -681,9 +681,8 @@ def two_step_adaptive_step(
681681
new_step_candidate = step_size * adapt_factor
682682

683683
# Clamp to valid range
684-
sign_step = keras.ops.sign(step_size)
685-
new_step_size = keras.ops.minimum(keras.ops.maximum(new_step_candidate, min_step_size), max_step_size)
686-
new_step_size = sign_step * keras.ops.abs(new_step_size)
684+
new_step_size = keras.ops.clip(keras.ops.abs(new_step_candidate), min_step_size, max_step_size)
685+
new_step_size = keras.ops.sign(step_size) * new_step_size
687686

688687
# Return appropriate state based on acceptance
689688
new_state = keras.ops.cond(accepted, lambda: state_heun, lambda: state)
@@ -1147,7 +1146,7 @@ def integrate_stochastic(
11471146
seed: keras.random.SeedGenerator,
11481147
steps: int | Literal["adaptive"] = 100,
11491148
method: str = "euler_maruyama",
1150-
min_steps: int = 10,
1149+
min_steps: int = 50,
11511150
max_steps: int = 10_000,
11521151
score_fn: Callable = None,
11531152
corrector_steps: int = 0,

tests/test_utils/test_integrate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,6 @@ def diffusion_fn(t, x):
218218
seed=seed,
219219
method=method,
220220
max_steps=1_000,
221-
min_steps=100,
222221
)
223222

224223
x_0 = np.array(out["x"])

0 commit comments

Comments
 (0)