From 3b887dc8ff8e3a7c8970a8e7a5445db73bc49aa3 Mon Sep 17 00:00:00 2001 From: arrjon Date: Mon, 27 Oct 2025 11:20:26 +0100 Subject: [PATCH 1/2] fix scm --- .../stable_consistency_model/stable_consistency_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index 6ce27fdf4..dc092ab4e 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -307,7 +307,7 @@ def compute_metrics( r = 1.0 # TODO: if consistency distillation training (not supported yet) is unstable, add schedule here def f_teacher(x, t): - o = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + o = self._apply_subnet(x, self.time_emb(t), conditions, training=stage == "training") return self.subnet_projector(o) primals = (xt / self.sigma, t) @@ -321,7 +321,7 @@ def f_teacher(x, t): cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt) # calculate output of the network - subnet_out = self._apply_subnet(x / self.sigma, self.time_emb(t), conditions, training=stage == "training") + subnet_out = self._apply_subnet(xt / self.sigma, self.time_emb(t), conditions, training=stage == "training") student_out = self.subnet_projector(subnet_out) # calculate the tangent From 64516a464ba828a6dd45db7f03f5e4cf4c7cb79a Mon Sep 17 00:00:00 2001 From: arrjon Date: Wed, 29 Oct 2025 17:31:44 +0100 Subject: [PATCH 2/2] fix saving --- .../stable_consistency_model/stable_consistency_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py index dc092ab4e..0f787d44f 100644 --- a/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py +++ b/bayesflow/experimental/stable_consistency_model/stable_consistency_model.py @@ -105,7 +105,6 @@ def __init__( ) embedding_kwargs = embedding_kwargs or {} - self.embedding_kwargs = embedding_kwargs self.time_emb = FourierEmbedding(**embedding_kwargs) self.time_emb_dim = self.time_emb.embed_dim @@ -123,13 +122,14 @@ def get_config(self): config = { "subnet": self.subnet, "sigma": self.sigma, - "embedding_kwargs": self.embedding_kwargs, + "time_emb": self.time_emb, "concatenate_subnet_input": self._concatenate_subnet_input, } return base_config | serialize(config) - def _discretize_time(self, num_steps: int, rho: float = 3.5, **kwargs): + @staticmethod + def _discretize_time(num_steps: int, rho: float = 3.5): t = keras.ops.linspace(0.0, pi / 2, num_steps) times = keras.ops.exp((t - pi / 2) * rho) * pi / 2 times = keras.ops.concatenate([keras.ops.zeros((1,)), times[1:]], axis=0)