@@ -640,7 +640,7 @@ def two_step_adaptive_step(
640640 if e_abs is None :
641641 e_abs = 0.02576 # 1% of 99% CI of standardized unit variance
642642 # Check if we're at minimum step size - if so, force acceptance
643- at_min_step = keras .ops .less_equal (step_size , min_step_size )
643+ at_min_step = keras .ops .less_equal (keras . ops . abs ( step_size ) , min_step_size )
644644
645645 # Compute error tolerance for each component
646646 e_abs_tensor = keras .ops .cast (e_abs , dtype = keras .ops .dtype (list (state .values ())[0 ]))
@@ -681,9 +681,8 @@ def two_step_adaptive_step(
681681 new_step_candidate = step_size * adapt_factor
682682
683683 # Clamp to valid range
684- sign_step = keras .ops .sign (step_size )
685- new_step_size = keras .ops .minimum (keras .ops .maximum (new_step_candidate , min_step_size ), max_step_size )
686- new_step_size = sign_step * keras .ops .abs (new_step_size )
684+ new_step_size = keras .ops .clip (keras .ops .abs (new_step_candidate ), min_step_size , max_step_size )
685+ new_step_size = keras .ops .sign (step_size ) * new_step_size
687686
688687 # Return appropriate state based on acceptance
689688 new_state = keras .ops .cond (accepted , lambda : state_heun , lambda : state )
@@ -1147,7 +1146,7 @@ def integrate_stochastic(
11471146 seed : keras .random .SeedGenerator ,
11481147 steps : int | Literal ["adaptive" ] = 100 ,
11491148 method : str = "euler_maruyama" ,
1150- min_steps : int = 10 ,
1149+ min_steps : int = 50 ,
11511150 max_steps : int = 10_000 ,
11521151 score_fn : Callable = None ,
11531152 corrector_steps : int = 0 ,
0 commit comments