1919
2020
2121DETERMINISTIC_METHODS = ["euler" , "rk45" , "tsit5" ]
22- STOCHASTIC_METHODS = ["euler_maruyama" , "sea" , "shark" , "langevin " , "fast_adaptive " ]
22+ STOCHASTIC_METHODS = ["euler_maruyama" , "sea" , "shark" , "two_step_adaptive " , "langevin " ]
2323
2424
2525def euler_step (
@@ -509,7 +509,6 @@ def euler_maruyama_step(
509509 use_adaptive_step_size : bool = False ,
510510 min_step_size : float = - float ("inf" ),
511511 max_step_size : float = float ("inf" ),
512- adaptive_factor : float = 0.01 ,
513512 ** kwargs ,
514513) -> Union [Tuple [StateDict , ArrayLike , ArrayLike ], Tuple [StateDict , ArrayLike , ArrayLike , StateDict ]]:
515514 """
@@ -525,7 +524,6 @@ def euler_maruyama_step(
525524 use_adaptive_step_size: Whether to use adaptive step sizing.
526525 min_step_size: Minimum allowed step size.
527526 max_step_size: Maximum allowed step size.
528- adaptive_factor: Factor to compute adaptive step size (0 < adaptive_factor < 1).
529527
530528 Returns:
531529 new_state: Updated state after one Euler-Maruyama step.
@@ -541,7 +539,7 @@ def euler_maruyama_step(
541539 new_step_size = stochastic_adaptive_step_size_controller (
542540 state = state ,
543541 drift = drift ,
544- adaptive_factor = adaptive_factor ,
542+ adaptive_factor = max_step_size ,
545543 min_step_size = min_step_size ,
546544 max_step_size = max_step_size ,
547545 )
@@ -561,7 +559,7 @@ def euler_maruyama_step(
561559 return new_state , time + new_step_size , new_step_size
562560
563561
564- def fast_adaptive_step (
562+ def two_step_adaptive_step (
565563 drift_fn : Callable ,
566564 diffusion_fn : Callable ,
567565 state : StateDict ,
@@ -572,8 +570,8 @@ def fast_adaptive_step(
572570 use_adaptive_step_size : bool = True ,
573571 min_step_size : float = - float ("inf" ),
574572 max_step_size : float = float ("inf" ),
575- e_abs : float = 0.01 ,
576- e_rel : float = 0.01 ,
573+ e_rel : float = 0.1 ,
574+ e_abs : float = None ,
577575 r : float = 0.9 ,
578576 adapt_safety : float = 0.9 ,
579577 ** kwargs ,
@@ -608,8 +606,8 @@ def fast_adaptive_step(
608606 use_adaptive_step_size: Whether to adapt step size.
609607 min_step_size: Minimum allowed step size.
610608 max_step_size: Maximum allowed step size.
611- e_abs: Absolute error tolerance.
612609 e_rel: Relative error tolerance.
610+ e_abs: Absolute error tolerance. Default assumes standardized targets.
613611 r: Order of the method for step size adaptation.
614612 adapt_safety: Safety factor for step size adaptation.
615613 **kwargs: Additional arguments passed to drift_fn and diffusion_fn.
@@ -650,6 +648,8 @@ def fast_adaptive_step(
650648
651649 # Error estimation
652650 if use_adaptive_step_size :
651+ if e_abs is None :
652+ e_abs = 0.02576 # 1% of 99% CI of standardized unit variance
653653 # Check if we're at minimum step size - if so, force acceptance
654654 at_min_step = keras .ops .less_equal (step_size , min_step_size )
655655
@@ -709,13 +709,33 @@ def fast_adaptive_step(
709709 return state_heun , time_mid , step_size
710710
711711
def compute_levy_area(
    state: StateDict, diffusion: StateDict, noise: StateDict, noise_aux: StateDict, step_size: ArrayLike
) -> StateDict:
    """
    Build the Brownian time-integral increments shared by the SEA and ShARK steps.

    For each key ``k`` that has a diffusion entry, the standard-normal tensors
    ``noise[k]`` (driving the Brownian increment ``W_k = sqrt(|h|) * noise[k]``)
    and ``noise_aux[k]`` (independent auxiliary normals ``Z_k``) are combined into

        H_k = 0.5 * |h| * W_k + (|h|^{3/2} / (2 * sqrt(3))) * Z_k

    i.e. the distribution of the path integral of Brownian motion over the step
    (space-time Lévy-area construction, Foster et al. 2023).  Keys without a
    diffusion entry receive zeros of the state's shape.

    Args:
        state: Current state; used for the key set and tensor shapes.
        diffusion: Diffusion coefficients; its keys select the stochastic variables.
        noise: Mapping of variable names to standard-normal tensors driving W.
        noise_aux: Mapping of variable names to independent standard-normal tensors.
        step_size: Time increment h (only |h| is used).

    Returns:
        Mapping of variable names to the Lévy-area tensors H_k.
    """
    step_size_abs = keras.ops.abs(step_size)
    sqrt_step_size = keras.ops.sqrt(step_size_abs)
    inv_sqrt3 = keras.ops.cast(1.0 / np.sqrt(3.0), dtype=keras.ops.dtype(step_size_abs))

    H = {}
    for k in state.keys():
        if k in diffusion:
            # term1 = 0.5 * |h| * W_k with W_k = sqrt(|h|) * noise[k].  The
            # sqrt(|h|) factor was present in the pre-refactor ShARK code
            # (term1 = 0.5 * h_mag * w_k, w_k = sqrt_h_mag * noise[k]); dropping
            # it changed shark_step's behavior and left term1 (O(h)) and
            # term2 (O(h^{3/2})) on inconsistent scales.
            term1 = 0.5 * step_size_abs * sqrt_step_size * noise[k]
            term2 = 0.5 * step_size_abs * sqrt_step_size * inv_sqrt3 * noise_aux[k]
            H[k] = term1 + term2
        else:
            # Deterministic variables carry no stochastic increment.
            H[k] = keras.ops.zeros_like(state[k])
    return H
729+
730+
712731def sea_step (
713732 drift_fn : Callable ,
714733 diffusion_fn : Callable ,
715734 state : StateDict ,
716735 time : ArrayLike ,
717736 step_size : ArrayLike ,
718- noise : StateDict ,
737+ noise : StateDict , # standard normals
738+ noise_aux : StateDict , # standard normals
719739 ** kwargs ,
720740) -> Tuple [StateDict , ArrayLike , ArrayLike ]:
721741 """
@@ -725,7 +745,7 @@ def sea_step(
725745 which improves the local error and the global error constant for additive noise.
726746
727747 The scheme is
728- X_{n+1} = X_n + f(t_n, X_n + 0.5 * g(t_n) * ΔW_n) * h + g(t_n) * ΔW_n
748+ X_{n+1} = X_n + f(t_n, X_n + g(t_n) * (0.5 * ΔW_n + ΔH_n)) * h + g(t_n) * ΔW_n
729749
730750 [1] Foster et al., "High order splitting methods for SDEs satisfying a commutativity condition" (2023)
731751 Args:
@@ -735,20 +755,23 @@ def sea_step(
735755 time: Current time scalar tensor.
736756 step_size: Time increment dt.
737757 noise: Mapping of variable names to dW noise tensors.
758+ noise_aux: Mapping of variable names to auxiliary noise.
738759
739760 Returns:
740761 new_state: Updated state after one SEA step.
741762 new_time: time + dt.
742763 """
743- # Compute diffusion (assumed additive or weakly state dependent)
764+ # Compute diffusion
744765 diffusion = diffusion_fn (time , ** filter_kwargs (state , diffusion_fn ))
745766 sqrt_step_size = keras .ops .sqrt (keras .ops .abs (step_size ))
746767
747- # Build shifted state: X_shift = X + 0.5 * g * ΔW
768+ la = compute_levy_area (state = state , diffusion = diffusion , noise = noise , noise_aux = noise_aux , step_size = step_size )
769+
770+ # Build shifted state: X_shift = X + g * (0.5 * ΔW + ΔH)
748771 shifted_state = {}
749772 for key , x in state .items ():
750773 if key in diffusion :
751- shifted_state [key ] = x + 0.5 * diffusion [key ] * sqrt_step_size * noise [key ]
774+ shifted_state [key ] = x + diffusion [key ] * ( 0.5 * sqrt_step_size * noise [key ] + la [ key ])
752775 else :
753776 shifted_state [key ] = x
754777
@@ -810,33 +833,18 @@ def shark_step(
810833 """
811834 h = step_size
812835 t = time
813-
814- # Magnitude of the time step for stochastic scaling
815836 h_mag = keras .ops .abs (h )
816- # h_sign = keras.ops.sign(h)
817837 sqrt_h_mag = keras .ops .sqrt (h_mag )
818- inv_sqrt3 = keras .ops .cast (1.0 / np .sqrt (3.0 ), dtype = keras .ops .dtype (h_mag ))
819838
820- # g(y_k)
821- g0 = diffusion_fn (t , ** filter_kwargs (state , diffusion_fn ))
839+ diffusion = diffusion_fn (t , ** filter_kwargs (state , diffusion_fn ))
822840
823- # Build H_k from w_k and Z_k
824- H = {}
825- for k in state .keys ():
826- if k in g0 :
827- w_k = sqrt_h_mag * noise [k ]
828- z_k = noise_aux [k ] # standard normal
829- term1 = 0.5 * h_mag * w_k
830- term2 = 0.5 * h_mag * sqrt_h_mag * inv_sqrt3 * z_k
831- H [k ] = term1 + term2
832- else :
833- H [k ] = keras .ops .zeros_like (state [k ])
841+ la = compute_levy_area (state = state , diffusion = diffusion , noise = noise , noise_aux = noise_aux , step_size = step_size )
834842
835843 # === 1) shifted initial state ===
836844 y_tilde_k = {}
837845 for k in state .keys ():
838- if k in g0 :
839- y_tilde_k [k ] = state [k ] + g0 [k ] * H [k ]
846+ if k in diffusion :
847+ y_tilde_k [k ] = state [k ] + diffusion [k ] * la [k ]
840848 else :
841849 y_tilde_k [k ] = state [k ]
842850
@@ -866,12 +874,12 @@ def shark_step(
866874
867875 # stochastic parts
868876 sto1 = (
869- g_tilde_k [k ] * ((2.0 / 5.0 ) * sqrt_h_mag * noise [k ] + (6.0 / 5.0 ) * H [k ])
877+ g_tilde_k [k ] * ((2.0 / 5.0 ) * sqrt_h_mag * noise [k ] + (6.0 / 5.0 ) * la [k ])
870878 if k in g_tilde_k
871879 else keras .ops .zeros_like (det )
872880 )
873881 sto2 = (
874- g_tilde_mid [k ] * ((3.0 / 5.0 ) * sqrt_h_mag * noise [k ] - (6.0 / 5.0 ) * H [k ])
882+ g_tilde_mid [k ] * ((3.0 / 5.0 ) * sqrt_h_mag * noise [k ] - (6.0 / 5.0 ) * la [k ])
875883 if k in g_tilde_mid
876884 else keras .ops .zeros_like (det )
877885 )
@@ -1154,7 +1162,7 @@ def integrate_stochastic(
11541162 seed : keras .random .SeedGenerator ,
11551163 steps : int | Literal ["adaptive" ] = 100 ,
11561164 method : str = "euler_maruyama" ,
1157- min_steps : int = 20 ,
1165+ min_steps : int = 10 ,
11581166 max_steps : int = 10_000 ,
11591167 score_fn : Callable = None ,
11601168 corrector_steps : int = 0 ,
@@ -1229,8 +1237,8 @@ def integrate_stochastic(
12291237 step_fn_raw = shark_step
12301238 if is_adaptive :
12311239 raise ValueError ("SHARK SDE solver does not support adaptive steps." )
1232- case "fast_adaptive " :
1233- step_fn_raw = fast_adaptive_step
1240+ case "two_step_adaptive " :
1241+ step_fn_raw = two_step_adaptive_step
12341242 case "langevin" :
12351243 if is_adaptive :
12361244 raise ValueError ("Langevin sampling does not support adaptive steps." )
@@ -1269,7 +1277,7 @@ def integrate_stochastic(
12691277 for key , val in state .items ():
12701278 shape = keras .ops .shape (val )
12711279 z_history [key ] = keras .random .normal ((loop_steps , * shape ), dtype = keras .ops .dtype (val ), seed = seed )
1272- if method == " shark" :
1280+ if method in [ "sea" , " shark"] :
12731281 z_extra_history [key ] = keras .random .normal ((loop_steps , * shape ), dtype = keras .ops .dtype (val ), seed = seed )
12741282
12751283 if is_adaptive :
0 commit comments