Skip to content

Commit a7adea2

Browse files
committed
improved initial step size
1 parent a771e32 commit a7adea2

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

bayesflow/utils/integrate.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -284,12 +284,7 @@ def integrate_adaptive(
284284

285285
tolerance = keras.ops.convert_to_tensor(kwargs.get("tolerance", 1e-6), dtype="float32")
286286
step_fn = partial(step_fn, fn, **kwargs, use_adaptive_step_size=True)
287-
288-
# Initial (conservative) step size guess
289-
total_time = stop_time - start_time
290-
step_size0 = keras.ops.convert_to_tensor(total_time / max_steps, dtype="float32")
291-
292-
# Track step count as scalar tensor
287+
initial_step = (stop_time - start_time) / float(min_steps)
293288
step0 = keras.ops.convert_to_tensor(0.0, dtype="float32")
294289
count_not_accepted = 0
295290

@@ -308,18 +303,10 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
308303

309304
def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
310305
# Time remaining from current point
311-
time_remaining = stop_time - _time
312-
313-
# Per-step min/max step sizes (like original code)
306+
time_remaining = keras.ops.abs(stop_time - _time)
314307
min_step_size = time_remaining / (max_steps - _step)
315308
max_step_size = time_remaining / keras.ops.maximum(min_steps - _step, 1.0)
316-
317-
# Ensure ordering: min_step_size <= max_step_size
318-
lower = keras.ops.minimum(min_step_size, max_step_size)
319-
upper = keras.ops.maximum(min_step_size, max_step_size)
320-
min_step_size = lower
321-
max_step_size = upper
322-
h = keras.ops.clip(_step_size, min_step_size, max_step_size)
309+
h = keras.ops.sign(_step_size) * keras.ops.clip(keras.ops.abs(_step_size), min_step_size, max_step_size)
323310

324311
# Take one trial step
325312
new_state, new_time, new_k1, err = step_fn(
@@ -330,7 +317,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
330317
)
331318

332319
new_step_size = h * keras.ops.clip(0.9 * (tolerance / (err + 1e-12)) ** 0.2, 0.2, 5.0)
333-
new_step_size = keras.ops.clip(new_step_size, min_step_size, max_step_size)
320+
new_step_size = keras.ops.sign(new_step_size) * keras.ops.clip(
321+
keras.ops.abs(new_step_size), min_step_size, max_step_size
322+
)
334323

335324
# Error control: reject if err > tolerance
336325
too_big = keras.ops.greater(err, tolerance)
@@ -355,7 +344,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
355344
state, time, step_size, step, k1, count_not_accepted = keras.ops.while_loop(
356345
cond,
357346
body,
358-
[state, start_time, step_size0, step0, k1_0, count_not_accepted],
347+
[state, start_time, initial_step, step0, k1_0, count_not_accepted],
359348
)
360349

361350
# Final step to hit stop_time exactly

0 commit comments

Comments
 (0)