Commit a661204

Merge branch 'new_samplers' into compositional_sampling_diffusion
# Conflicts:
#   bayesflow/networks/diffusion_model/diffusion_model.py
#   bayesflow/utils/integrate.py
2 parents 67f1175 + f9823f8

File tree

14 files changed: +1616 -294 lines


bayesflow/diagnostics/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     calibration_error,
     calibration_log_gamma,
     posterior_contraction,
+    posterior_z_score,
     summary_space_comparison,
 )

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@
 from .classifier_two_sample_test import classifier_two_sample_test
 from .model_misspecification import bootstrap_comparison, summary_space_comparison
 from .calibration_log_gamma import calibration_log_gamma, gamma_null_distribution, gamma_discrepancy
+from .posterior_z_score import posterior_z_score
bayesflow/diagnostics/metrics/posterior_z_score.py

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+from collections.abc import Sequence, Mapping, Callable
+
+import numpy as np
+
+from ...utils.dict_utils import dicts_to_arrays, compute_test_quantities
+
+
+def posterior_z_score(
+    estimates: Mapping[str, np.ndarray] | np.ndarray,
+    targets: Mapping[str, np.ndarray] | np.ndarray,
+    variable_keys: Sequence[str] = None,
+    variable_names: Sequence[str] = None,
+    test_quantities: dict[str, Callable] = None,
+    aggregation: Callable | None = np.median,
+) -> dict[str, any]:
+    """
+    Computes the posterior z-score for the given samples according to [1]:
+
+        post_z_score = (posterior_mean - true_parameters) / posterior_std
+
+    The score is adequate if it centers around zero and spreads roughly
+    within the interval [-3, 3].
+
+    [1] Schad, D. J., Betancourt, M., & Vasishth, S. (2021).
+        Toward a principled Bayesian workflow in cognitive science.
+        Psychological Methods, 26(1), 103.
+        Paper also available at https://arxiv.org/abs/1904.12765
+
+    Parameters
+    ----------
+    estimates : np.ndarray of shape (num_datasets, num_draws_post, num_variables)
+        Posterior samples, comprising `num_draws_post` random draws from the posterior
+        distribution for each of the `num_datasets` data sets.
+    targets : np.ndarray of shape (num_datasets, num_variables)
+        Prior samples, comprising `num_datasets` ground truths.
+    variable_keys : Sequence[str], optional (default = None)
+        Select keys from the dictionaries provided in estimates and targets.
+        By default, select all keys.
+    variable_names : Sequence[str], optional (default = None)
+        Optional variable names to show in the output.
+    test_quantities : dict or None, optional, default: None
+        A dict that maps plot titles to functions that compute test quantities
+        based on estimate/target draws.
+
+        The dict keys are automatically added to ``variable_keys`` and
+        ``variable_names``. Test quantity functions are expected to accept a
+        dict of draws with shape ``(batch_size, ...)`` as the first (typically
+        only) positional argument and return a NumPy array of shape
+        ``(batch_size,)``. The functions do not have to deal with an additional
+        sample dimension, as appropriate reshaping is done internally.
+    aggregation : callable or None, optional (default = np.median)
+        Function to aggregate the z-scores across data sets. Typically
+        `np.mean` or `np.median`. If None is provided, the individual values
+        are returned.
+
+    Returns
+    -------
+    result : dict
+        Dictionary containing:
+
+        - "values" : float or np.ndarray
+            The (optionally aggregated) posterior z-score per variable.
+        - "metric_name" : str
+            The name of the metric ("Posterior z-score").
+        - "variable_names" : str
+            The (inferred) variable names.
+
+    Notes
+    -----
+    The posterior z-score quantifies how far the posterior mean lies from the
+    true generating parameter, in units of the posterior standard deviation.
+    Values near 0 (in absolute terms) mean the posterior mean is close to the
+    truth; large absolute values indicate substantial deviation. The sign shows
+    the direction of the bias.
+    """
+
+    # Optionally, compute and prepend test quantities from draws
+    if test_quantities is not None:
+        updated_data = compute_test_quantities(
+            targets=targets,
+            estimates=estimates,
+            variable_keys=variable_keys,
+            variable_names=variable_names,
+            test_quantities=test_quantities,
+        )
+        variable_names = updated_data["variable_names"]
+        variable_keys = updated_data["variable_keys"]
+        estimates = updated_data["estimates"]
+        targets = updated_data["targets"]
+
+    samples = dicts_to_arrays(
+        estimates=estimates,
+        targets=targets,
+        variable_keys=variable_keys,
+        variable_names=variable_names,
+    )
+
+    post_vars = samples["estimates"].var(axis=1, ddof=1)
+    post_means = samples["estimates"].mean(axis=1)
+    post_stds = np.sqrt(post_vars)
+    z_score = (post_means - samples["targets"]) / post_stds
+    if aggregation is not None:
+        z_score = aggregation(z_score, axis=0)
+    variable_names = samples["estimates"].variable_names
+    return {"values": z_score, "metric_name": "Posterior z-score", "variable_names": variable_names}

bayesflow/diagnostics/plots/plot_quantity.py

Lines changed: 6 additions & 1 deletion
@@ -213,7 +213,12 @@ def _prepare_values(

     if estimates is not None:
         if is_values_callable:
-            values = values(estimates=estimates, targets=targets, **filter_kwargs({"aggregation": None}, values))
+            values = values(
+                estimates=estimates,
+                targets=targets,
+                variable_keys=variable_keys,
+                **filter_kwargs({"aggregation": None}, values),
+            )

     data = dicts_to_arrays(
         estimates=estimates,
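
For context, `filter_kwargs({"aggregation": None}, values)` forwards `aggregation=None` only when the metric callable accepts such a keyword. A minimal stand-in with the same observable behavior (a sketch, not the library's actual implementation):

import inspect

def filter_kwargs_sketch(kwargs, f):
    # Keep only the keyword arguments that appear in f's signature.
    accepted = inspect.signature(f).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}

def metric_without_aggregation(estimates, targets):
    return estimates

filter_kwargs_sketch({"aggregation": None}, metric_without_aggregation)  # -> {}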

bayesflow/experimental/stable_consistency_model/stable_consistency_model.py

Lines changed: 2 additions & 2 deletions
@@ -222,9 +222,9 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, **kwargs) -> Tensor:
         z : Tensor
             Samples from a standard normal distribution
         conditions : Tensor, optional, default: None
-            Conditions for a approximate conditional distribution
+            Conditions for an approximate conditional distribution
         **kwargs : dict, optional, default: {}
-            Additional keyword arguments. Include `steps` (default: 30) to
+            Additional keyword arguments. Include `steps` (default: 15) to
             adjust the number of sampling steps.

         Returns

bayesflow/networks/diffusion_model/diffusion_model.py

Lines changed: 18 additions & 20 deletions
@@ -16,6 +16,7 @@
     integrate_stochastic,
     logging,
     tensor_utils,
+    STOCHASTIC_METHODS,
 )
 from bayesflow.utils.serialization import serialize, deserialize, serializable

@@ -39,13 +40,13 @@ class DiffusionModel(InferenceNetwork):
         "activation": "mish",
         "kernel_initializer": "he_normal",
         "residual": True,
-        "dropout": 0.0,
+        "dropout": 0.05,
         "spectral_normalization": False,
     }

     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": 100,
+        "method": "two_step_adaptive",
+        "steps": "adaptive",
     }

     def __init__(

@@ -402,14 +403,13 @@ def _forward(
         conditions: Tensor = None,
         density: bool = False,
         training: bool = False,
-        compositional: bool = False,
         **kwargs,
     ) -> Tensor | tuple[Tensor, Tensor]:
         integrate_kwargs = {"start_time": 0.0, "stop_time": 1.0}
         integrate_kwargs = integrate_kwargs | self.integrate_kwargs
         integrate_kwargs = integrate_kwargs | kwargs

-        if integrate_kwargs["method"] == "euler_maruyama":
+        if integrate_kwargs["method"] in STOCHASTIC_METHODS:
             raise ValueError("Stochastic methods are not supported for forward integration.")

         if density:

@@ -453,14 +453,13 @@ def _inverse(
         conditions: Tensor = None,
         density: bool = False,
         training: bool = False,
-        compositional: bool = False,
         **kwargs,
     ) -> Tensor | tuple[Tensor, Tensor]:
         integrate_kwargs = {"start_time": 1.0, "stop_time": 0.0}
         integrate_kwargs = integrate_kwargs | self.integrate_kwargs
         integrate_kwargs = integrate_kwargs | kwargs
         if density:
-            if integrate_kwargs["method"] == "euler_maruyama":
+            if integrate_kwargs["method"] in STOCHASTIC_METHODS:
                 raise ValueError("Stochastic methods are not supported for density computation.")

             def deltas(time, xz):

@@ -479,7 +478,7 @@ def deltas(time, xz):
             return x, log_density

         state = {"xz": z}
-        if integrate_kwargs["method"] == "euler_maruyama":
+        if integrate_kwargs["method"] in STOCHASTIC_METHODS:

             def deltas(time, xz):
                 return {

@@ -490,18 +489,17 @@ def diffusion(time, xz):
                 return {"xz": self.diffusion_term(xz, time=time, training=training)}

             score_fn = None
-            if "corrector_steps" in integrate_kwargs:
-                if integrate_kwargs["corrector_steps"] > 0:
-
-                    def score_fn(time, xz):
-                        return {
-                            "xz": self.score(
-                                xz,
-                                time=time,
-                                conditions=conditions,
-                                training=training,
-                            )
-                        }
+            if "corrector_steps" in integrate_kwargs or integrate_kwargs.get("method") == "langevin":
+
+                def score_fn(time, xz):
+                    return {
+                        "xz": self.score(
+                            xz,
+                            time=time,
+                            conditions=conditions,
+                            training=training,
+                        )
+                    }

             state = integrate_stochastic(
                 drift_fn=deltas,
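
With the guard generalized from `euler_maruyama` to every method registered in `STOCHASTIC_METHODS`, stochastic samplers stay available for inverse (sampling) passes while remaining rejected for forward integration and density computation. A hedged sketch of how the new defaults and a stochastic override might be used (method names are taken from this diff; the `sample` call assumes the usual `InferenceNetwork` interface):

diffusion = DiffusionModel()
# conditions: a tensor of conditioning variables prepared elsewhere

# Default: deterministic "two_step_adaptive" integration with adaptive step count.
samples = diffusion.sample((128,), conditions=conditions)

# Stochastic override: "langevin" now triggers the score_fn construction above
# even without an explicit corrector_steps entry.
samples = diffusion.sample((128,), conditions=conditions, method="langevin", steps=250)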

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 8 additions & 6 deletions
@@ -53,8 +53,8 @@ class FlowMatching(InferenceNetwork):
     }

     INTEGRATE_DEFAULT_CONFIG = {
-        "method": "rk45",
-        "steps": 100,
+        "method": "tsit5",
+        "steps": "adaptive",
     }

     def __init__(

@@ -236,14 +236,15 @@ def f(x):
     def _forward(
         self, x: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
+        integrate_kwargs = self.integrate_kwargs | kwargs
         if density:

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": x, "trace": keras.ops.zeros(keras.ops.shape(x)[:-1] + (1,), dtype=keras.ops.dtype(x))}
-            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

             z = state["xz"]
             log_density = self.base_distribution.log_prob(z) + keras.ops.squeeze(state["trace"], axis=-1)

@@ -254,7 +255,7 @@ def deltas(time, xz):
             return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

         state = {"xz": x}
-        state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **(self.integrate_kwargs | kwargs))
+        state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **integrate_kwargs)

         z = state["xz"]

@@ -263,14 +264,15 @@ def deltas(time, xz):
     def _inverse(
         self, z: Tensor, conditions: Tensor = None, density: bool = False, training: bool = False, **kwargs
     ) -> Tensor | tuple[Tensor, Tensor]:
+        integrate_kwargs = self.integrate_kwargs | kwargs
         if density:

             def deltas(time, xz):
                 v, trace = self._velocity_trace(xz, time=time, conditions=conditions, training=training)
                 return {"xz": v, "trace": trace}

             state = {"xz": z, "trace": keras.ops.zeros(keras.ops.shape(z)[:-1] + (1,), dtype=keras.ops.dtype(z))}
-            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+            state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

             x = state["xz"]
             log_density = self.base_distribution.log_prob(z) - keras.ops.squeeze(state["trace"], axis=-1)

@@ -281,7 +283,7 @@ def deltas(time, xz):
             return {"xz": self.velocity(xz, time=time, conditions=conditions, training=training)}

         state = {"xz": z}
-        state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **(self.integrate_kwargs | kwargs))
+        state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **integrate_kwargs)

         x = state["xz"]
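
The refactor computes `integrate_kwargs = self.integrate_kwargs | kwargs` once instead of re-merging at every `integrate` call. The precedence is plain dict merging, where call-site values win; a minimal sketch with hypothetical values:

defaults = {"method": "tsit5", "steps": "adaptive"}  # self.integrate_kwargs
overrides = {"method": "rk45", "steps": 64}          # **kwargs passed by the caller
merged = defaults | overrides                        # right-hand operand wins on conflicts
assert merged == {"method": "rk45", "steps": 64}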

bayesflow/simulators/benchmark_simulators/lotka_volterra.py

Lines changed: 14 additions & 9 deletions
@@ -10,10 +10,10 @@ def __init__(
         X0: int = 30,
         Y0: int = 1,
         T: int | None = 20,
-        subsample: int = 10,
+        subsample: int | str = "original",
         flatten: bool = True,
         obs_noise: float = 0.1,
-        dt: float = None,
+        dt: float = 0.1,
         rng: np.random.Generator = None,
     ):
         """Lotka Volterra simulated benchmark.

@@ -27,14 +27,17 @@ def __init__(
             Initial number of predator species.
         T: int, optional, default: 20
             The duration (time horizon) of the simulation.
-        subsample: int or None, optional, default: 10
+        subsample: int, str or None, optional, default: 'original'
             The number of evenly spaced time points to return.
             If None, no subsampling will be performed and all T timepoints will be returned.
+            If 'original', the original benchmark task subsampling of 20 points is used.
         flatten: bool, optional, default: True
             A flag to indicate whether a 1D (`flatten=True`) or 2D (`flatten=False`)
             representation of the simulated data is returned.
         obs_noise: float, optional, default: 0.1
             The standard deviation of the log-normal likelihood.
+        dt: float, optional, default: 0.1
+            The time step size for the ODE solver.
         rng: np.random.Generator or None, optional, default: None
             An optional random number generator to use.
         """

@@ -95,21 +98,23 @@ def observation_model(self, params: np.ndarray) -> np.ndarray:
         # Unpack parameter vector into scalars
         alpha, beta, gamma, delta = params

-        # Prepate time vector between 0 and T of length T
-        t_vec = np.linspace(0, self.T, int(1 / self.dt))
+        # Prepare time vector between 0 and T with step size dt
+        t_vec = np.arange(0, self.T + self.dt, self.dt)

         # Integrate using scipy and retain only infected (2-nd dimension)
         pp = odeint(self._deriv, x0, t_vec, args=(alpha, beta, gamma, delta))

         # Subsample evenly the specified number of points, if specified
-        if self.subsample is not None:
+        if self.subsample == "original":
+            pp = pp[::21]
+        elif self.subsample is not None:
             pp = pp[:: (self.T // self.subsample)]

-        # Ensure minimum count is 0, which will later pass by log(0 + 1)
-        pp[pp < 0] = 0.0
+        # Ensure minimum count is 0
+        pp = np.clip(pp, a_min=1e-10, a_max=10000.0)

         # Add noise, decide whether to flatten and return
-        x = self.rng.lognormal(np.log1p(pp), sigma=self.obs_noise)
+        x = self.rng.lognormal(np.log(pp), sigma=self.obs_noise)
         if self.flatten:
             return x.flatten()
         return x
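
The clip floor is what makes the switch from `log1p` to `log` safe: `np.log(0)` is `-inf`, so flooring the ODE solution at a small positive value keeps the log-normal location parameter finite. A minimal sketch with hypothetical numbers:

import numpy as np

rng = np.random.default_rng(0)
pp = np.array([[30.0, 1.0], [0.0, 2.5]])      # raw ODE output; zeros can occur
pp = np.clip(pp, a_min=1e-10, a_max=10000.0)  # floor keeps np.log(pp) finite
x = rng.lognormal(np.log(pp), sigma=0.1)      # multiplicative noise around pp
assert np.isfinite(x).all()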
