Skip to content

Commit 77e950c

Browse files
committed
merge new samplers
1 parent 1129b1c commit 77e950c

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

bayesflow/networks/diffusion_model/compositional_diffusion_model.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
from keras import ops
66

77
from bayesflow.types import Tensor
8-
from bayesflow.utils import (
9-
expand_right_as,
10-
integrate,
11-
integrate_stochastic,
12-
)
8+
from bayesflow.utils import expand_right_as, integrate, integrate_stochastic, STOCHASTIC_METHODS
139
from bayesflow.utils.serialization import serializable
1410
from .diffusion_model import DiffusionModel
1511
from .schedules.noise_schedule import NoiseSchedule
@@ -318,7 +314,7 @@ def _inverse_compositional(
318314
z = z / ops.sqrt(ops.cast(scale_latent, dtype=ops.dtype(z)))
319315

320316
if density:
321-
if integrate_kwargs["method"] == "euler_maruyama":
317+
if integrate_kwargs["method"] in STOCHASTIC_METHODS:
322318
raise ValueError("Stochastic methods are not supported for density computation.")
323319

324320
def deltas(time, xz):
@@ -346,7 +342,7 @@ def deltas(time, xz):
346342

347343
state = {"xz": z}
348344

349-
if integrate_kwargs["method"] == "euler_maruyama":
345+
if integrate_kwargs["method"] in STOCHASTIC_METHODS:
350346

351347
def deltas(time, xz):
352348
return {
@@ -365,20 +361,19 @@ def diffusion(time, xz):
365361
return {"xz": self.diffusion_term(xz, time=time, training=training)}
366362

367363
score_fn = None
368-
if "corrector_steps" in integrate_kwargs:
369-
if integrate_kwargs["corrector_steps"] > 0:
370-
371-
def score_fn(time, xz):
372-
return {
373-
"xz": self.compositional_score(
374-
xz,
375-
time=time,
376-
conditions=conditions,
377-
compute_prior_score=compute_prior_score,
378-
mini_batch_size=mini_batch_size,
379-
training=training,
380-
)
381-
}
364+
if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin":
365+
366+
def score_fn(time, xz):
367+
return {
368+
"xz": self.compositional_score(
369+
xz,
370+
time=time,
371+
conditions=conditions,
372+
compute_prior_score=compute_prior_score,
373+
mini_batch_size=mini_batch_size,
374+
training=training,
375+
)
376+
}
382377

383378
state = integrate_stochastic(
384379
drift_fn=deltas,
@@ -390,7 +385,6 @@ def score_fn(time, xz):
390385
**integrate_kwargs,
391386
)
392387
else:
393-
integrate_kwargs.pop("corrector_steps", None)
394388

395389
def deltas(time, xz):
396390
return {

bayesflow/workflows/basic_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def compositional_sample(
291291
*,
292292
num_samples: int,
293293
conditions: Mapping[str, np.ndarray],
294-
compute_prior_score: Callable[[Mapping[str, np.ndarray]], np.ndarray],
294+
compute_prior_score: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]],
295295
**kwargs,
296296
) -> dict[str, np.ndarray]:
297297
"""

0 commit comments

Comments
 (0)