1+ import logging
12from collections .abc import Sequence
23
34import keras
@@ -236,14 +237,21 @@ def f(x):
236237 def _forward (
237238 self , x : Tensor , conditions : Tensor = None , density : bool = False , training : bool = False , ** kwargs
238239 ) -> Tensor | tuple [Tensor , Tensor ]:
240+ integrate_kwargs = self .integrate_kwargs | kwargs
239241 if density :
242+ if integrate_kwargs ["steps" ] == "adaptive" :
243+ logging .warning (
244+ "Using adaptive integration for density estimation can lead to "
245+ "problems with autodiff. Switching to 200 fixed steps instead."
246+ )
247+ integrate_kwargs ["steps" ] = 200
240248
241249 def deltas (time , xz ):
242250 v , trace = self ._velocity_trace (xz , time = time , conditions = conditions , training = training )
243251 return {"xz" : v , "trace" : trace }
244252
245253 state = {"xz" : x , "trace" : keras .ops .zeros (keras .ops .shape (x )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (x ))}
246- state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** ( self . integrate_kwargs | kwargs ) )
254+ state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** integrate_kwargs )
247255
248256 z = state ["xz" ]
249257 log_density = self .base_distribution .log_prob (z ) + keras .ops .squeeze (state ["trace" ], axis = - 1 )
@@ -254,7 +262,7 @@ def deltas(time, xz):
254262 return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
255263
256264 state = {"xz" : x }
257- state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** ( self . integrate_kwargs | kwargs ) )
265+ state = integrate (deltas , state , start_time = 1.0 , stop_time = 0.0 , ** integrate_kwargs )
258266
259267 z = state ["xz" ]
260268
@@ -263,14 +271,21 @@ def deltas(time, xz):
263271 def _inverse (
264272 self , z : Tensor , conditions : Tensor = None , density : bool = False , training : bool = False , ** kwargs
265273 ) -> Tensor | tuple [Tensor , Tensor ]:
274+ integrate_kwargs = self .integrate_kwargs | kwargs
266275 if density :
276+ if integrate_kwargs ["steps" ] == "adaptive" :
277+ logging .warning (
278+ "Using adaptive integration for density estimation can lead to "
279+ "problems with autodiff. Switching to 200 fixed steps instead."
280+ )
281+ integrate_kwargs ["steps" ] = 200
267282
268283 def deltas (time , xz ):
269284 v , trace = self ._velocity_trace (xz , time = time , conditions = conditions , training = training )
270285 return {"xz" : v , "trace" : trace }
271286
272287 state = {"xz" : z , "trace" : keras .ops .zeros (keras .ops .shape (z )[:- 1 ] + (1 ,), dtype = keras .ops .dtype (z ))}
273- state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** ( self . integrate_kwargs | kwargs ) )
288+ state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** integrate_kwargs )
274289
275290 x = state ["xz" ]
276291 log_density = self .base_distribution .log_prob (z ) - keras .ops .squeeze (state ["trace" ], axis = - 1 )
@@ -281,7 +296,7 @@ def deltas(time, xz):
281296 return {"xz" : self .velocity (xz , time = time , conditions = conditions , training = training )}
282297
283298 state = {"xz" : z }
284- state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** ( self . integrate_kwargs | kwargs ) )
299+ state = integrate (deltas , state , start_time = 0.0 , stop_time = 1.0 , ** integrate_kwargs )
285300
286301 x = state ["xz" ]
287302
0 commit comments