2222STOCHASTIC_METHODS = ["euler_maruyama" , "sea" , "shark" , "two_step_adaptive" , "langevin" ]
2323
2424
25+ def _check_all_nans (state : StateDict ):
26+ all_nans_flags = []
27+ for v in state .values ():
28+ all_nans_flags .append (keras .ops .all (keras .ops .isnan (v )))
29+ return keras .ops .all (keras .ops .stack (all_nans_flags ))
30+
31+
2532def euler_step (
2633 fn : Callable ,
2734 state : StateDict ,
@@ -218,7 +225,8 @@ def integrate_fixed(
218225 def body (_loop_var , _loop_state ):
219226 _state , _time = _loop_state
220227 _state , _time , _ , _ = step_fn (_state , _time , step_size )
221-
228+ if _check_all_nans (_state ):
229+ raise RuntimeError (f"All values are NaNs in state during integration at { _time } ." )
222230 return _state , _time
223231
224232 state , time = keras .ops .fori_loop (0 , steps , body , (state , time ))
@@ -251,6 +259,9 @@ def body(_loop_var, _loop_state):
251259 _time = steps [_loop_var ]
252260 step_size = steps [_loop_var + 1 ] - steps [_loop_var ]
253261 _loop_state , _ , _ , _ = step_fn (_loop_state , _time , step_size )
262+
263+ if _check_all_nans (_loop_state ):
264+ raise RuntimeError (f"All values are NaNs in state during integration at { _time } ." )
254265 return _loop_state
255266
256267 state = keras .ops .fori_loop (0 , len (steps ) - 1 , body , state )
@@ -296,10 +307,12 @@ def cond(_state, _time, _step_size, _step, _k1, _count_not_accepted):
296307 step_lt_min = keras .ops .less (_step , float (min_steps ))
297308 step_lt_max = keras .ops .less (_step , float (max_steps ))
298309
299- return keras .ops .logical_or (
300- step_lt_min ,
301- keras .ops .logical_and (keras .ops .all (time_remaining > 0 ), step_lt_max ),
310+ all_nans = _check_all_nans (_state )
311+
312+ end_now = keras .ops .logical_or (
313+ step_lt_min , keras .ops .logical_and (keras .ops .all (time_remaining > 0 ), step_lt_max )
302314 )
315+ return keras .ops .logical_and (~ all_nans , end_now )
303316
304317 def body (_state , _time , _step_size , _step , _k1 , _count_not_accepted ):
305318 # Time remaining from current point
@@ -347,6 +360,9 @@ def body(_state, _time, _step_size, _step, _k1, _count_not_accepted):
347360 [state , start_time , initial_step , step0 , k1_0 , count_not_accepted ],
348361 )
349362
363+ if _check_all_nans (state ):
364+ raise RuntimeError (f"All values are NaNs in state during integration at { time } ." )
365+
350366 # Final step to hit stop_time exactly
351367 time_diff = stop_time - time
352368 time_remaining = keras .ops .sign (stop_time - start_time ) * time_diff
@@ -974,6 +990,9 @@ def body_fixed(_i, _loop_state):
974990 step_size_factor = step_size_factor ,
975991 corrector_noise_history = corrector_noise_history ,
976992 )
993+ all_nans = _check_all_nans (new_state )
994+ if all_nans :
995+ raise RuntimeError (f"All values are NaNs in state during integration at { _current_time } ." )
977996 return new_state , new_time , initial_step
978997
979998 # Execute the fixed loop
@@ -1004,9 +1023,10 @@ def integrate_stochastic_adaptive(
10041023 initial_loop_state = (keras .ops .zeros ((), dtype = "int32" ), state , start_time , initial_step , 0 , state )
10051024
10061025 def cond (i , current_state , current_time , current_step , counter , last_state ):
1007- # time remaining after the next step
10081026 time_remaining = keras .ops .sign (stop_time - start_time ) * (stop_time - (current_time + current_step ))
1009- return keras .ops .logical_and (keras .ops .all (time_remaining > 0 ), keras .ops .less (i , max_steps ))
1027+ all_nans = _check_all_nans (current_state )
1028+ end_now = keras .ops .logical_and (keras .ops .all (time_remaining > 0 ), keras .ops .less (i , max_steps ))
1029+ return keras .ops .logical_and (~ all_nans , end_now )
10101030
10111031 def body_adaptive (_i , _current_state , _current_time , _current_step , _counter , _last_state ):
10121032 # Step Size Control
@@ -1048,9 +1068,36 @@ def body_adaptive(_i, _current_state, _current_time, _current_step, _counter, _l
10481068 return _i + 1 , new_state , new_time , new_step , _counter , _new_current_state
10491069
10501070 # Execute the adaptive loop
1051- _ , final_state , final_time , _ , final_counter , _ = keras .ops .while_loop (cond , body_adaptive , initial_loop_state )
1071+ _ , final_state , final_time , _ , final_counter , final_k1 = keras .ops .while_loop (
1072+ cond , body_adaptive , initial_loop_state
1073+ )
1074+
1075+ if _check_all_nans (final_state ):
1076+ raise RuntimeError (f"All values are NaNs in state during integration at { final_time } ." )
1077+
1078+ # Final step to hit stop_time exactly
1079+ time_diff = stop_time - final_time
1080+ time_remaining = keras .ops .sign (stop_time - start_time ) * time_diff
1081+ if keras .ops .all (time_remaining > 0 ):
1082+ noise_final = {k : z_history [k ][- 1 ] for k in final_state .keys ()}
1083+ noise_extra_final = None
1084+ if len (z_extra_history ) > 0 :
1085+ noise_extra_final = {k : z_extra_history [k ][- 1 ] for k in final_state .keys ()}
1086+
1087+ final_state , _ , _ = step_fn (
1088+ state = final_state ,
1089+ time = final_time ,
1090+ step_size = time_diff ,
1091+ last_state = final_k1 ,
1092+ min_step_size = min_step_size ,
1093+ max_step_size = time_remaining ,
1094+ noise = noise_final ,
1095+ noise_aux = noise_extra_final ,
1096+ use_adaptive_step_size = False ,
1097+ )
1098+ final_counter = final_counter + 1
10521099
1053- logging .debug (f"Finished integration after { final_counter } steps at { final_time } ." )
1100+ logging .debug (f"Finished integration after { final_counter } ." )
10541101 return final_state
10551102
10561103
@@ -1094,13 +1141,12 @@ def integrate_langevin(
10941141
10951142 def body (_i , loop_state ):
10961143 current_state , current_time = loop_state
1097- t = current_time
10981144
10991145 # score at current time
1100- score = score_fn (t , ** filter_kwargs (current_state , score_fn ))
1146+ score = score_fn (current_time , ** filter_kwargs (current_state , score_fn ))
11011147
11021148 # noise schedule
1103- log_snr_t = noise_schedule .get_log_snr (t = t , training = False )
1149+ log_snr_t = noise_schedule .get_log_snr (t = current_time , training = False )
11041150 _ , sigma_t = noise_schedule .get_alpha_sigma (log_snr_t = log_snr_t )
11051151
11061152 new_state : StateDict = {}
@@ -1125,6 +1171,8 @@ def body(_i, loop_state):
11251171 step_size_factor = step_size_factor ,
11261172 corrector_noise_history = corrector_noise_history ,
11271173 )
1174+ if _check_all_nans (new_state ):
1175+ raise RuntimeError (f"All values are NaNs in state during integration at { current_time } ." )
11281176
11291177 return new_state , new_time
11301178
0 commit comments