55from keras import ops
66
77from bayesflow .types import Tensor
8- from bayesflow .utils import (
9- expand_right_as ,
10- integrate ,
11- integrate_stochastic ,
12- )
8+ from bayesflow .utils import expand_right_as , integrate , integrate_stochastic , STOCHASTIC_METHODS
139from bayesflow .utils .serialization import serializable
1410from .diffusion_model import DiffusionModel
1511from .schedules .noise_schedule import NoiseSchedule
@@ -318,7 +314,7 @@ def _inverse_compositional(
318314 z = z / ops .sqrt (ops .cast (scale_latent , dtype = ops .dtype (z )))
319315
320316 if density :
321- if integrate_kwargs ["method" ] == "euler_maruyama" :
317+ if integrate_kwargs ["method" ] in STOCHASTIC_METHODS :
322318 raise ValueError ("Stochastic methods are not supported for density computation." )
323319
324320 def deltas (time , xz ):
@@ -346,7 +342,7 @@ def deltas(time, xz):
346342
347343 state = {"xz" : z }
348344
349- if integrate_kwargs ["method" ] == "euler_maruyama" :
345+ if integrate_kwargs ["method" ] in STOCHASTIC_METHODS :
350346
351347 def deltas (time , xz ):
352348 return {
@@ -365,20 +361,19 @@ def diffusion(time, xz):
365361 return {"xz" : self .diffusion_term (xz , time = time , training = training )}
366362
367363 score_fn = None
368- if "corrector_steps" in integrate_kwargs :
369- if integrate_kwargs ["corrector_steps" ] > 0 :
370-
371- def score_fn (time , xz ):
372- return {
373- "xz" : self .compositional_score (
374- xz ,
375- time = time ,
376- conditions = conditions ,
377- compute_prior_score = compute_prior_score ,
378- mini_batch_size = mini_batch_size ,
379- training = training ,
380- )
381- }
364+ if "corrector_steps" in integrate_kwargs or integrate_kwargs .get ("method" ) == "langevin" :
365+
366+ def score_fn (time , xz ):
367+ return {
368+ "xz" : self .compositional_score (
369+ xz ,
370+ time = time ,
371+ conditions = conditions ,
372+ compute_prior_score = compute_prior_score ,
373+ mini_batch_size = mini_batch_size ,
374+ training = training ,
375+ )
376+ }
382377
383378 state = integrate_stochastic (
384379 drift_fn = deltas ,
@@ -390,7 +385,6 @@ def score_fn(time, xz):
390385 ** integrate_kwargs ,
391386 )
392387 else :
393- integrate_kwargs .pop ("corrector_steps" , None )
394388
395389 def deltas (time , xz ):
396390 return {
0 commit comments