Skip to content

Commit 1632a9e

Browse files
committed
Remove the need to pass inference variables to workflow
1 parent 22c75d1 commit 1632a9e

25 files changed

+241
-173
lines changed

bayesflow/adapters/adapter.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from collections.abc import MutableSequence, Sequence
1+
from collections.abc import MutableSequence, Sequence, Mapping
22

33
import numpy as np
4+
45
from keras.saving import (
56
deserialize_keras_object as deserialize,
67
register_keras_serializable as serializable,
@@ -121,16 +122,16 @@ def inverse(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, any]:
121122

122123
return data
123124

124-
def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
125+
def __call__(self, data: Mapping[str, any], *, inverse: bool = False, **kwargs) -> dict[str, np.ndarray]:
125126
"""Apply the transforms in the given direction.
126127
127128
Parameters
128129
----------
129-
data : dict
130+
data : Mapping[str, any]
130131
The data to be transformed.
131132
inverse : bool, optional
132133
If False, apply the forward transform, else apply the inverse transform (default False).
133-
**kwargs : dict
134+
**kwargs
134135
Additional keyword arguments passed to each transform.
135136
136137
Returns

bayesflow/approximators/approximator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import keras
1+
from collections.abc import Mapping
2+
23
import multiprocessing as mp
34

5+
import keras
6+
47
from bayesflow.adapters import Adapter
58
from bayesflow.datasets import OnlineDataset
69
from bayesflow.simulators import Simulator
@@ -19,7 +22,7 @@ def build_adapter(cls, **kwargs) -> Adapter:
1922
# implemented by each respective architecture
2023
raise NotImplementedError
2124

22-
def build_from_data(self, data: dict[str, any]) -> None:
25+
def build_from_data(self, data: Mapping[str, any]) -> None:
2326
self.compute_metrics(**data, stage="training")
2427
self.built = True
2528

@@ -72,7 +75,7 @@ def fit(self, *, dataset: keras.utils.PyDataset = None, simulator: Simulator = N
7275
A dataset containing simulations for training. If provided, `simulator` must be None.
7376
simulator : Simulator, optional
7477
A simulator used to generate a dataset. If provided, `dataset` must be None.
75-
**kwargs : dict
78+
**kwargs
7679
Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
7780
7881
batch_size : int or None, default='auto'

bayesflow/approximators/continuous_approximator.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from collections.abc import Sequence
1+
from collections.abc import Mapping, Sequence, Callable
22

3-
import keras
43
import numpy as np
4+
5+
import keras
56
from keras.saving import (
67
deserialize_keras_object as deserialize,
78
register_keras_serializable as serializable,
@@ -23,7 +24,7 @@ class ContinuousApproximator(Approximator):
2324
2425
Parameters
2526
----------
26-
adapter : Adapter
27+
adapter : bayesflow.adapters.Adapter
2728
Adapter for data processing. You can use :py:meth:`build_adapter`
2829
to create it.
2930
inference_network : InferenceNetwork
@@ -53,7 +54,7 @@ def build_adapter(
5354
inference_variables: Sequence[str],
5455
inference_conditions: Sequence[str] = None,
5556
summary_variables: Sequence[str] = None,
56-
sample_weight: Sequence[str] = None,
57+
sample_weight: str = None,
5758
) -> Adapter:
5859
"""Create an :py:class:`~bayesflow.adapters.Adapter` suited for the approximator.
5960
@@ -159,7 +160,7 @@ def fit(self, *args, **kwargs):
159160
A dataset containing simulations for training. If provided, `simulator` must be None.
160161
simulator : Simulator, optional
161162
A simulator used to generate a dataset. If provided, `dataset` must be None.
162-
**kwargs : dict
163+
**kwargs
163164
Additional keyword arguments passed to `keras.Model.fit()`, including (see also `build_dataset`):
164165
165166
batch_size : int or None, default='auto'
@@ -221,12 +222,50 @@ def get_config(self):
221222

222223
def estimate(
223224
self,
224-
conditions: dict[str, np.ndarray],
225+
conditions: Mapping[str, np.ndarray],
225226
split: bool = False,
226-
estimators: dict[str, callable] = None,
227+
estimators: Mapping[str, Callable] = None,
227228
num_samples: int = 1000,
228229
**kwargs,
229230
) -> dict[str, dict[str, np.ndarray]]:
231+
"""
232+
Estimate summary statistics for variables based on given conditions.
233+
234+
This function samples data using the object's ``sample`` method according to the provided
235+
conditions and then computes summary statistics for each variable using a set of estimator
236+
functions. By default, it calculates the mean, median, and selected quantiles (10th, 50th,
237+
and 90th percentiles). Users can also supply custom estimators that override or extend the
238+
default ones.
239+
240+
Parameters
241+
----------
242+
conditions : Mapping[str, np.ndarray]
243+
A mapping from variable names to numpy arrays representing the conditions under which
244+
samples should be generated.
245+
split : bool, optional
246+
If True, indicates that the data sampling process should split the samples based on an
247+
internal logic. The default is False.
248+
estimators : Mapping[str, Callable], optional
249+
A dictionary where keys are estimator names and values are callables. Each callable must
250+
accept an array and an axis parameter, and return a dictionary with the computed statistic.
251+
If not provided, a default set of estimators is used:
252+
- 'mean': Computes the mean along the specified axis.
253+
- 'median': Computes the median along the specified axis.
254+
- 'quantiles': Computes the 10th, 50th, and 90th percentiles along the specified axis,
255+
then rearranges the axes for convenience.
256+
num_samples : int, optional
257+
The number of samples to generate for each variable. The default is 1000.
258+
**kwargs
259+
Additional keyword arguments passed to the ``sample`` method.
260+
261+
Returns
262+
-------
263+
dict[str, dict[str, np.ndarray]]
264+
A nested dictionary where the outer keys correspond to variable names and the inner keys
265+
correspond to estimator names. Each inner dictionary contains the computed statistic(s) for
266+
the variable, potentially with reduced nesting via ``squeeze_inner_estimates_dict``.
267+
"""
268+
230269
estimators = estimators or {}
231270
estimators = (
232271
dict(
@@ -261,7 +300,7 @@ def sample(
261300
self,
262301
*,
263302
num_samples: int,
264-
conditions: dict[str, np.ndarray],
303+
conditions: Mapping[str, np.ndarray],
265304
split: bool = False,
266305
**kwargs,
267306
) -> dict[str, np.ndarray]:
@@ -338,14 +377,14 @@ def _sample(
338377
**filter_kwargs(kwargs, self.inference_network.sample),
339378
)
340379

341-
def log_prob(self, data: dict[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
380+
def log_prob(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray | dict[str, np.ndarray]:
342381
"""
343382
Computes the log-probability of given data under the model. The `data` dictionary is preprocessed using the
344383
`adapter`. Log-probabilities are returned as NumPy arrays.
345384
346385
Parameters
347386
----------
348-
data : dict[str, np.ndarray]
387+
data : Mapping[str, np.ndarray]
349388
Dictionary of observed data as NumPy arrays.
350389
**kwargs : dict
351390
Additional keyword arguments for the adapter and log-probability computation.

bayesflow/approximators/point_approximator.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
import keras
1+
from collections.abc import Mapping
2+
23
import numpy as np
4+
5+
import keras
36
from keras.saving import (
47
register_keras_serializable as serializable,
58
)
@@ -21,7 +24,7 @@ class PointApproximator(ContinuousApproximator):
2124

2225
def estimate(
2326
self,
24-
conditions: dict[str, np.ndarray],
27+
conditions: Mapping[str, np.ndarray],
2528
split: bool = False,
2629
**kwargs,
2730
) -> dict[str, dict[str, np.ndarray | dict[str, np.ndarray]]]:
@@ -33,7 +36,7 @@ def estimate(
3336
3437
Parameters
3538
----------
36-
conditions : dict[str, np.ndarray]
39+
conditions : Mapping[str, np.ndarray]
3740
A dictionary mapping variable names to arrays representing the conditions
3841
for the estimation process.
3942
split : bool, optional
@@ -71,7 +74,7 @@ def sample(
7174
self,
7275
*,
7376
num_samples: int,
74-
conditions: dict[str, np.ndarray],
77+
conditions: Mapping[str, np.ndarray],
7578
split: bool = False,
7679
**kwargs,
7780
) -> dict[str, dict[str, np.ndarray]]:
@@ -111,7 +114,7 @@ def sample(
111114
# Optionally split the arrays along the last axis.
112115
if split:
113116
raise NotImplementedError("split=True is currently not supported for `PointApproximator`.")
114-
samples = split_arrays(samples, axis=-1)
117+
115118
# Squeeze sample dictionary if there's only one key-value pair.
116119
samples = self._squeeze_parametric_score_major_dict(samples)
117120

@@ -120,7 +123,7 @@ def sample(
120123
def log_prob(
121124
self,
122125
*,
123-
data: dict[str, np.ndarray],
126+
data: Mapping[str, np.ndarray],
124127
**kwargs,
125128
) -> np.ndarray | dict[str, np.ndarray]:
126129
"""
@@ -152,14 +155,14 @@ def log_prob(
152155

153156
return log_prob
154157

155-
def _prepare_conditions(self, conditions: dict[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
158+
def _prepare_conditions(self, conditions: Mapping[str, np.ndarray], **kwargs) -> dict[str, Tensor]:
156159
"""Adapts and converts the conditions to tensors."""
157160
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
158161
conditions.pop("inference_variables", None)
159162
return keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
160163

161164
def _apply_inverse_adapter_to_estimates(
162-
self, estimates: dict[str, dict[str, Tensor]], **kwargs
165+
self, estimates: Mapping[str, Mapping[str, Tensor]], **kwargs
163166
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
164167
"""Applies the inverse adapter on each inner element of the _estimate output dictionary."""
165168
estimates = keras.tree.map_structure(keras.ops.convert_to_numpy, estimates)
@@ -183,7 +186,7 @@ def _apply_inverse_adapter_to_estimates(
183186
return processed
184187

185188
def _apply_inverse_adapter_to_samples(
186-
self, samples: dict[str, Tensor], **kwargs
189+
self, samples: Mapping[str, Tensor], **kwargs
187190
) -> dict[str, dict[str, np.ndarray]]:
188191
"""Applies the inverse adapter to a dictionary of samples."""
189192
samples = keras.tree.map_structure(keras.ops.convert_to_numpy, samples)
@@ -198,7 +201,7 @@ def _apply_inverse_adapter_to_samples(
198201
return processed
199202

200203
def _reorder_estimates(
201-
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
204+
self, estimates: Mapping[str, Mapping[str, Mapping[str, np.ndarray]]]
202205
) -> dict[str, dict[str, dict[str, np.ndarray]]]:
203206
"""Reorders the nested dictionary so that the inference variable names become the top-level keys."""
204207
# Grab the variable names from one sample inner dictionary.
@@ -212,7 +215,7 @@ def _reorder_estimates(
212215
return reordered
213216

214217
def _squeeze_estimates(
215-
self, estimates: dict[str, dict[str, dict[str, np.ndarray]]]
218+
self, estimates: Mapping[str, Mapping[str, Mapping[str, np.ndarray]]]
216219
) -> dict[str, dict[str, np.ndarray]]:
217220
"""Squeezes each inner estimate dictionary to remove unnecessary nesting."""
218221
squeezed = {}
@@ -224,7 +227,7 @@ def _squeeze_estimates(
224227
return squeezed
225228

226229
def _squeeze_parametric_score_major_dict(
227-
self, samples: dict[str, np.ndarray]
230+
self, samples: Mapping[str, np.ndarray]
228231
) -> np.ndarray or dict[str, np.ndarray]:
229232
"""Squeezes the dictionary to just the value if there is only one key-value pair."""
230233
if len(samples) == 1:

bayesflow/datasets/offline_dataset.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import keras
1+
from collections.abc import Mapping
2+
23
import numpy as np
34

5+
import keras
6+
47
from bayesflow.adapters import Adapter
58
from bayesflow.utils import logging
69

@@ -13,7 +16,12 @@ class OfflineDataset(keras.utils.PyDataset):
1316
"""
1417

1518
def __init__(
16-
self, data: dict[str, np.ndarray], batch_size: int, adapter: Adapter | None, num_samples: int = None, **kwargs
19+
self,
20+
data: Mapping[str, np.ndarray],
21+
batch_size: int,
22+
adapter: Adapter | None,
23+
num_samples: int = None,
24+
**kwargs,
1725
):
1826
super().__init__(**kwargs)
1927
self.batch_size = batch_size
@@ -60,7 +68,7 @@ def shuffle(self) -> None:
6068
np.random.shuffle(self.indices)
6169

6270
@staticmethod
63-
def _get_num_samples_from_data(data: dict) -> int:
71+
def _get_num_samples_from_data(data: Mapping) -> int:
6472
for key, value in data.items():
6573
if hasattr(value, "shape"):
6674
ndim = len(value.shape)

bayesflow/diagnostics/metrics/calibration_error.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence, Any, Mapping, Callable
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import numpy as np
44

@@ -14,7 +14,7 @@ def calibration_error(
1414
aggregation: Callable = np.median,
1515
min_quantile: float = 0.005,
1616
max_quantile: float = 0.995,
17-
) -> Mapping[str, Any]:
17+
) -> dict[str, any]:
1818
"""
1919
Computes an aggregate score for the marginal calibration error over an ensemble of approximate
2020
posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation

bayesflow/diagnostics/metrics/expected_calibration_error.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from collections.abc import Sequence
2+
13
import numpy as np
24
from keras import ops
3-
from typing import Sequence, Any, Mapping
45

56
from ...utils.exceptions import ShapeError
67
from sklearn.calibration import calibration_curve
@@ -12,7 +13,7 @@ def expected_calibration_error(
1213
model_names: Sequence[str] = None,
1314
n_bins: int = 10,
1415
return_probs: bool = False,
15-
) -> Mapping[str, Any]:
16+
) -> dict[str, any]:
1617
"""
1718
Estimates the expected calibration error (ECE) of a model comparison network according to [1].
1819

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence, Any, Mapping, Callable
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import numpy as np
44

@@ -11,7 +11,7 @@ def posterior_contraction(
1111
variable_keys: Sequence[str] = None,
1212
variable_names: Sequence[str] = None,
1313
aggregation: Callable = np.median,
14-
) -> Mapping[str, Any]:
14+
) -> dict[str, any]:
1515
"""
1616
Computes the posterior contraction (PC) from prior to posterior for the given samples.
1717

bayesflow/diagnostics/metrics/root_mean_squared_error.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Sequence, Any, Mapping, Callable
1+
from collections.abc import Sequence, Mapping, Callable
22

33
import numpy as np
44

@@ -12,7 +12,7 @@ def root_mean_squared_error(
1212
variable_names: Sequence[str] = None,
1313
normalize: bool = True,
1414
aggregation: Callable = np.median,
15-
) -> Mapping[str, Any]:
15+
) -> dict[str, any]:
1616
"""
1717
Computes the (Normalized) Root Mean Squared Error (RMSE/NRMSE) for the given posterior and prior samples.
1818

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
from collections.abc import Mapping, Sequence
2+
13
import numpy as np
24
import matplotlib.pyplot as plt
35

4-
from typing import Sequence
56
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
67
from ...utils.ecdf import simultaneous_ecdf_bands
78
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
89

910

1011
def calibration_ecdf(
11-
estimates: dict[str, np.ndarray] | np.ndarray,
12-
targets: dict[str, np.ndarray] | np.ndarray,
12+
estimates: Mapping[str, np.ndarray] | np.ndarray,
13+
targets: Mapping[str, np.ndarray] | np.ndarray,
1314
variable_keys: Sequence[str] = None,
1415
variable_names: Sequence[str] = None,
1516
difference: bool = False,

0 commit comments

Comments
 (0)