@@ -224,17 +224,22 @@ def integrate_fixed(
224224 step_fn = partial (step_fn , fn , ** kwargs , use_adaptive_step_size = False )
225225 step_size = (stop_time - start_time ) / steps
226226
227- time = start_time
228-
229- def body (_loop_var , _loop_state ):
230- _state , _time = _loop_state
231- _state , _time , _ , _ = step_fn (_state , _time , step_size )
232- if _check_all_nans (_state ):
233- raise RuntimeError (f"All values are NaNs in state during integration at { _time } ." )
234- return _state , _time
227+ def cond (_loop_var , _loop_state , _loop_time ):
228+ all_nans = _check_all_nans (_loop_state )
229+ end_now = keras .ops .less (_loop_var , steps )
230+ return keras .ops .logical_and (~ all_nans , end_now )
235231
236- state , time = keras .ops .fori_loop (0 , steps , body , (state , time ))
232+ def body (_loop_var , _loop_state , _loop_time ):
233+ _loop_state , _loop_time , _ , _ = step_fn (_loop_state , _loop_time , step_size )
234+ return _loop_var + 1 , _loop_state , _loop_time
237235
236+ _ , state , _ = keras .ops .while_loop (
237+ cond ,
238+ body ,
239+ [0 , state , start_time ],
240+ )
241+ if _check_all_nans (state ):
242+ raise RuntimeError ("All values are NaNs in state during integration." )
238243 return state
239244
240245
@@ -259,16 +264,25 @@ def integrate_scheduled(
259264
260265 step_fn = partial (step_fn , fn , ** kwargs , use_adaptive_step_size = False )
261266
267+ def cond (_loop_var , _loop_state ):
268+ all_nans = _check_all_nans (_loop_state )
269+ end_now = keras .ops .less (_loop_var , len (steps ) - 1 )
270+ return keras .ops .logical_and (~ all_nans , end_now )
271+
262272 def body (_loop_var , _loop_state ):
263273 _time = steps [_loop_var ]
264274 step_size = steps [_loop_var + 1 ] - steps [_loop_var ]
265275 _loop_state , _ , _ , _ = step_fn (_loop_state , _time , step_size )
276+ return _loop_var + 1 , _loop_state
266277
267- if _check_all_nans (_loop_state ):
268- raise RuntimeError (f"All values are NaNs in state during integration at { _time } ." )
269- return _loop_state
278+ _ , state = keras .ops .while_loop (
279+ cond ,
280+ body ,
281+ [0 , state ],
282+ )
270283
271- state = keras .ops .fori_loop (0 , len (steps ) - 1 , body , state )
284+ if _check_all_nans (state ):
285+ raise RuntimeError ("All values are NaNs in state during integration." )
272286 return state
273287
274288
@@ -635,7 +649,7 @@ def two_step_adaptive_step(
635649 min_step_size = min_step_size ,
636650 max_step_size = max_step_size ,
637651 noise = noise ,
638- use_adaptive_step_size = True ,
652+ use_adaptive_step_size = False ,
639653 )
640654
641655 # Compute drift and diffusion at new state, but update from old state
@@ -957,9 +971,12 @@ def integrate_stochastic_fixed(
957971 """
958972 initial_step = (stop_time - start_time ) / float (steps )
959973
960- def body_fixed (_i , _loop_state ):
961- _current_state , _current_time , _current_step = _loop_state
974+ def cond (_loop_var , _loop_state , _loop_time , _loop_step ):
975+ all_nans = _check_all_nans (_loop_state )
976+ end_now = keras .ops .less (_loop_var , steps )
977+ return keras .ops .logical_and (~ all_nans , end_now )
962978
979+ def body (_i , _current_state , _current_time , _current_step ):
963980 # Determine step size: either the constant size or the remainder to reach stop_time
964981 remaining = keras .ops .abs (stop_time - _current_time )
965982 sign = keras .ops .sign (_current_step )
@@ -994,13 +1011,16 @@ def body_fixed(_i, _loop_state):
9941011 step_size_factor = step_size_factor ,
9951012 corrector_noise_history = corrector_noise_history ,
9961013 )
997- all_nans = _check_all_nans (new_state )
998- if all_nans :
999- raise RuntimeError (f"All values are NaNs in state during integration at { _current_time } ." )
1000- return new_state , new_time , initial_step
1014+ return _i + 1 , new_state , new_time , initial_step
1015+
1016+ _ , final_state , final_time , _ = keras .ops .while_loop (
1017+ cond ,
1018+ body ,
1019+ [0 , state , start_time , initial_step ],
1020+ )
1021+ if _check_all_nans (final_state ):
1022+ raise RuntimeError (f"All values are NaNs in state during integration at { final_time } ." )
10011023
1002- # Execute the fixed loop
1003- final_state , final_time , _ = keras .ops .fori_loop (0 , steps , body_fixed , (state , start_time , initial_step ))
10041024 return final_state
10051025
10061026
@@ -1024,22 +1044,21 @@ def integrate_stochastic_adaptive(
10241044 """
10251045 Performs adaptive-step SDE integration.
10261046 """
1027- initial_loop_state = (keras .ops .zeros ((), dtype = "int32" ), state , start_time , initial_step , 0 , state )
1047+ initial_loop_state = (keras .ops .zeros ((), dtype = "int32" ), state , start_time , initial_step , state )
10281048
1029- def cond (i , current_state , current_time , current_step , counter , last_state ):
1049+ def cond (i , current_state , current_time , current_step , last_state ):
10301050 time_remaining = keras .ops .sign (stop_time - start_time ) * (stop_time - (current_time + current_step ))
10311051 all_nans = _check_all_nans (current_state )
10321052 end_now = keras .ops .logical_and (keras .ops .all (time_remaining > 0 ), keras .ops .less (i , max_steps ))
10331053 return keras .ops .logical_and (~ all_nans , end_now )
10341054
1035- def body_adaptive (_i , _current_state , _current_time , _current_step , _counter , _last_state ):
1055+ def body_adaptive (_i , _current_state , _current_time , _current_step , _last_state ):
10361056 # Step Size Control
10371057 remaining = keras .ops .abs (stop_time - _current_time )
10381058 sign = keras .ops .sign (_current_step )
10391059 # Ensure the next step does not overshoot the stop_time
10401060 dt_mag = keras .ops .minimum (keras .ops .abs (_current_step ), remaining )
10411061 dt = sign * dt_mag
1042- _counter += 1
10431062
10441063 _noise_i = {k : z_history [k ][_i ] for k in _current_state .keys ()}
10451064 _noise_extra_i = None
@@ -1069,12 +1088,10 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l
10691088 corrector_noise_history = corrector_noise_history ,
10701089 )
10711090
1072- return _i + 1 , new_state , new_time , new_step , _counter , _new_current_state
1091+ return _i + 1 , new_state , new_time , new_step , _new_current_state
10731092
10741093 # Execute the adaptive loop
1075- _ , final_state , final_time , _ , final_counter , final_k1 = keras .ops .while_loop (
1076- cond , body_adaptive , initial_loop_state
1077- )
1094+ final_counter , final_state , final_time , _ , final_k1 = keras .ops .while_loop (cond , body_adaptive , initial_loop_state )
10781095
10791096 if _check_all_nans (final_state ):
10801097 raise RuntimeError (f"All values are NaNs in state during integration at { final_time } ." )
@@ -1143,27 +1160,30 @@ def integrate_langevin(
11431160 dt = (stop_time - start_time ) / float (steps )
11441161 effective_factor = step_size_factor * 100 / np .sqrt (steps )
11451162
1146- def body (_i , loop_state ):
1147- current_state , current_time = loop_state
1163+ def cond (_loop_var , _loop_state , _loop_time ):
1164+ all_nans = _check_all_nans (_loop_state )
1165+ end_now = keras .ops .less (_loop_var , steps )
1166+ return keras .ops .logical_and (~ all_nans , end_now )
11481167
1168+ def body (_i , _loop_state , _loop_time ):
11491169 # score at current time
1150- score = score_fn (current_time , ** filter_kwargs (current_state , score_fn ))
1170+ score = score_fn (_loop_time , ** filter_kwargs (_loop_state , score_fn ))
11511171
11521172 # noise schedule
1153- log_snr_t = noise_schedule .get_log_snr (t = current_time , training = False )
1173+ log_snr_t = noise_schedule .get_log_snr (t = _loop_time , training = False )
11541174 _ , sigma_t = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
11551175
11561176 new_state : StateDict = {}
1157- for k in current_state .keys ():
1177+ for k in _loop_state .keys ():
11581178 s_k = score .get (k , None )
11591179 if s_k is None :
1160- new_state [k ] = current_state [k ]
1180+ new_state [k ] = _loop_state [k ]
11611181 continue
11621182
11631183 e = effective_factor * sigma_t ** 2
1164- new_state [k ] = current_state [k ] + e * s_k + keras .ops .sqrt (2.0 * e ) * z_history [k ][_i ]
1184+ new_state [k ] = _loop_state [k ] + e * s_k + keras .ops .sqrt (2.0 * e ) * z_history [k ][_i ]
11651185
1166- new_time = current_time + dt
1186+ new_time = _loop_time + dt
11671187
11681188 new_state = _apply_corrector (
11691189 new_state = new_state ,
@@ -1175,17 +1195,16 @@ def body(_i, loop_state):
11751195 step_size_factor = step_size_factor ,
11761196 corrector_noise_history = corrector_noise_history ,
11771197 )
1178- if _check_all_nans (new_state ):
1179- raise RuntimeError (f"All values are NaNs in state during integration at { current_time } ." )
11801198
1181- return new_state , new_time
1199+ return _i + 1 , new_state , new_time
11821200
1183- final_state , _ = keras .ops .fori_loop (
1184- 0 ,
1185- steps ,
1201+ _ , final_state , final_time = keras .ops .while_loop (
1202+ cond ,
11861203 body ,
1187- (state , start_time ),
1204+ (0 , state , start_time ),
11881205 )
1206+ if _check_all_nans (final_state ):
1207+ raise RuntimeError (f"All values are NaNs in state during integration at { final_time } ." )
11891208 return final_state
11901209
11911210
0 commit comments