@@ -284,12 +284,7 @@ def integrate_adaptive(
284284
285285 tolerance = keras .ops .convert_to_tensor (kwargs .get ("tolerance" , 1e-6 ), dtype = "float32" )
286286 step_fn = partial (step_fn , fn , ** kwargs , use_adaptive_step_size = True )
287-
288- # Initial (conservative) step size guess
289- total_time = stop_time - start_time
290- step_size0 = keras .ops .convert_to_tensor (total_time / max_steps , dtype = "float32" )
291-
292- # Track step count as scalar tensor
287+ initial_step = (stop_time - start_time ) / float (min_steps )
293288 step0 = keras .ops .convert_to_tensor (0.0 , dtype = "float32" )
294289 count_not_accepted = 0
295290
@@ -308,18 +303,10 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
308303
309304 def body (_state , _time , _step_size , _step , _k1 , _count_not_accepted ):
310305 # Time remaining from current point
311- time_remaining = stop_time - _time
312-
313- # Per-step min/max step sizes (like original code)
306+ time_remaining = keras .ops .abs (stop_time - _time )
314307 min_step_size = time_remaining / (max_steps - _step )
315308 max_step_size = time_remaining / keras .ops .maximum (min_steps - _step , 1.0 )
316-
317- # Ensure ordering: min_step_size <= max_step_size
318- lower = keras .ops .minimum (min_step_size , max_step_size )
319- upper = keras .ops .maximum (min_step_size , max_step_size )
320- min_step_size = lower
321- max_step_size = upper
322- h = keras .ops .clip (_step_size , min_step_size , max_step_size )
309+ h = keras .ops .sign (_step_size ) * keras .ops .clip (keras .ops .abs (_step_size ), min_step_size , max_step_size )
323310
324311 # Take one trial step
325312 new_state , new_time , new_k1 , err = step_fn (
@@ -330,7 +317,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
330317 )
331318
332319 new_step_size = h * keras .ops .clip (0.9 * (tolerance / (err + 1e-12 )) ** 0.2 , 0.2 , 5.0 )
333- new_step_size = keras .ops .clip (new_step_size , min_step_size , max_step_size )
320+ new_step_size = keras .ops .sign (new_step_size ) * keras .ops .clip (
321+ keras .ops .abs (new_step_size ), min_step_size , max_step_size
322+ )
334323
335324 # Error control: reject if err > tolerance
336325 too_big = keras .ops .greater (err , tolerance )
@@ -355,7 +344,7 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
355344 state , time , step_size , step , k1 , count_not_accepted = keras .ops .while_loop (
356345 cond ,
357346 body ,
358- [state , start_time , step_size0 , step0 , k1_0 , count_not_accepted ],
347+ [state , start_time , initial_step , step0 , k1_0 , count_not_accepted ],
359348 )
360349
361350 # Final step to hit stop_time exactly
0 commit comments