Skip to content

Commit 129792a

Browse files
committed
Fixed merge conflict
2 parents 1632a9e + effbb00 commit 129792a

34 files changed

+556
-118
lines changed

.github/workflows/style.yaml

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
name: Check Code Style
2+
3+
on:
4+
workflow_dispatch:
5+
pull_request:
6+
branches:
7+
- main
8+
- dev
9+
- update-workflows
10+
push:
11+
branches:
12+
- main
13+
- dev
14+
- update-workflows
15+
16+
jobs:
17+
check-code-style:
18+
name: Check Code Style
19+
20+
runs-on: ubuntu-latest
21+
22+
steps:
23+
- name: Checkout Repository
24+
uses: actions/checkout@v4
25+
26+
- name: Set up Python
27+
uses: actions/setup-python@v5
28+
with:
29+
python-version: "3.10"
30+
31+
- name: Install Ruff
32+
run: |
33+
pip install -U pip setuptools wheel
34+
pip install ruff
35+
36+
- name: Run Linter
37+
run: ruff check --config pyproject.toml --verbose
38+
39+
- name: Run Formatter
40+
run: ruff format --config pyproject.toml --check --verbose

.github/workflows/tests.yaml

Lines changed: 34 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
name: Tests
2+
name: Multi-Backend Tests
33

44
on:
55
workflow_dispatch:
@@ -12,105 +12,82 @@ on:
1212
- main
1313
- dev
1414

15+
defaults:
16+
run:
17+
shell: bash
1518

16-
jobs:
17-
lint:
18-
runs-on: ubuntu-latest
19-
defaults:
20-
run:
21-
shell: bash -el {0}
22-
23-
steps:
24-
- name: Checkout code
25-
uses: actions/checkout@v4
26-
27-
- name: Set up Conda
28-
uses: conda-incubator/setup-miniconda@v3
29-
with:
30-
environment-file: environment.yaml
31-
python-version: "3.10"
32-
33-
- name: Show Environment Info
34-
run: |
35-
conda info
36-
conda list
37-
conda config --show-sources
38-
conda config --show
39-
printenv | sort
40-
41-
- name: Run Linter
42-
run: |
43-
ruff check --config pyproject.toml --verbose
44-
45-
- name: Run Formatter
46-
run: |
47-
ruff format --check --config pyproject.toml --verbose
4819

20+
jobs:
4921
test:
50-
runs-on: ${{ matrix.os }}
22+
name: Run Multi-Backend Tests
23+
5124
strategy:
52-
fail-fast: false
5325
matrix:
5426
os: [ubuntu-latest, windows-latest]
5527
python-version: ["3.10", "3.11"]
5628
backend: ["jax", "tensorflow", "torch"]
57-
defaults:
58-
run:
59-
shell: bash -el {0}
29+
30+
runs-on: ${{ matrix.os }}
31+
6032
env:
6133
KERAS_BACKEND: ${{ matrix.backend }}
6234

6335
steps:
64-
- name: Checkout code
36+
- name: Checkout Repository
6537
uses: actions/checkout@v4
6638

67-
- name: Set up Conda
68-
uses: conda-incubator/setup-miniconda@v3
39+
- name: Set up Python
40+
uses: actions/setup-python@v5
6941
with:
70-
environment-file: environment.yaml
7142
python-version: ${{ matrix.python-version }}
7243

44+
- name: Install Dependencies
45+
run: |
46+
pip install -U pip setuptools wheel
47+
pip install .[test]
48+
7349
- name: Install JAX
7450
if: ${{ matrix.backend == 'jax' }}
7551
run: |
7652
pip install -U jax
77-
- name: Install NumPy
78-
if: ${{ matrix.backend == 'numpy' }}
79-
run: |
80-
conda install numpy
81-
- name: Install Tensorflow
53+
54+
- name: Install TensorFlow
8255
if: ${{ matrix.backend == 'tensorflow' }}
8356
run: |
8457
pip install -U tensorflow
58+
8559
- name: Install PyTorch
8660
if: ${{ matrix.backend == 'torch' }}
8761
run: |
88-
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
62+
pip install -U torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
8963
9064
- name: Show Environment Info
9165
run: |
92-
conda info
93-
conda list
94-
conda config --show-sources
95-
conda config --show
66+
python --version
67+
pip --version
9668
printenv | sort
97-
conda env export
9869
pip list
9970
10071
- name: Run Tests
10172
run: |
102-
pytest -x
73+
pytest -x -m "not slow"
10374
104-
- name: Create Coverage Report
75+
- name: Run Slow Tests
76+
# run all slow tests only on manual trigger
77+
if: github.event_name == 'workflow_dispatch'
10578
run: |
106-
coverage xml
79+
pytest -m "slow"
10780
10881
- name: Upload test results to Codecov
10982
if: ${{ !cancelled() }}
110-
uses: codecov/test-results-action@v1
83+
uses: codecov/codecov-action@v4
11184
with:
11285
token: ${{ secrets.CODECOV_TOKEN }}
11386

87+
- name: Create Coverage Report
88+
run: |
89+
coverage xml
90+
11491
- name: Upload Coverage Reports to CodeCov
11592
uses: codecov/codecov-action@v4
11693
with:

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ For an in-depth exposition, check out our walkthrough notebooks below.
5454
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
5555
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
5656
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
57-
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_point_estimation_and_expert_stats.ipynb)
57+
4. [Rapid iteration with point estimators](examples/Lotka_Volterra_Point_Estimation_and_Expert_Stats.ipynb)
5858
5. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
5959
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
6060
7. [Simple model comparison example](examples/One_Sample_TTest.ipynb)

bayesflow/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,12 @@ def setup():
4040
torch.autograd.set_grad_enabled(False)
4141

4242
logging.warning(
43+
"\n"
4344
"When using torch backend, we need to disable autograd by default to avoid excessive memory usage. Use\n"
45+
"\n"
4446
"with torch.enable_grad():\n"
47+
" ...\n"
48+
"\n"
4549
"in contexts where you need gradients (e.g. custom training loops)."
4650
)
4751

bayesflow/adapters/adapter.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,24 @@ def rename(self, from_key: str, to_key: str):
556556
self.transforms.append(Rename(from_key, to_key))
557557
return self
558558

559+
def scale(self, keys: str | Sequence[str], by: float | np.ndarray):
560+
from .transforms import Scale
561+
562+
if isinstance(keys, str):
563+
keys = [keys]
564+
565+
self.transforms.append(MapTransform({key: Scale(scale=by) for key in keys}))
566+
return self
567+
568+
def shift(self, keys: str | Sequence[str], by: float | np.ndarray):
569+
from .transforms import Shift
570+
571+
if isinstance(keys, str):
572+
keys = [keys]
573+
574+
self.transforms.append(MapTransform({key: Shift(shift=by) for key in keys}))
575+
return self
576+
559577
def sqrt(self, keys: str | Sequence[str]):
560578
"""Append an :py:class:`~transforms.Sqrt` transform to the adapter.
561579

bayesflow/adapters/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .numpy_transform import NumpyTransform
1515
from .one_hot import OneHot
1616
from .rename import Rename
17+
from .scale import Scale
18+
from .shift import Shift
1719
from .sqrt import Sqrt
1820
from .standardize import Standardize
1921
from .to_array import ToArray
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
6+
import numpy as np
7+
8+
from .elementwise_transform import ElementwiseTransform
9+
10+
11+
@serializable(package="bayesflow.adapters")
12+
class Scale(ElementwiseTransform):
13+
def __init__(self, scale: np.typing.ArrayLike):
14+
self.scale = np.array(scale)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(scale=deserialize(config["scale"]))
19+
20+
def get_config(self) -> dict:
21+
return {"scale": serialize(self.scale)}
22+
23+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
24+
return data * self.scale
25+
26+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
27+
return data / self.scale
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from keras.saving import (
2+
deserialize_keras_object as deserialize,
3+
register_keras_serializable as serializable,
4+
serialize_keras_object as serialize,
5+
)
6+
import numpy as np
7+
8+
from .elementwise_transform import ElementwiseTransform
9+
10+
11+
@serializable(package="bayesflow.adapters")
12+
class Shift(ElementwiseTransform):
13+
def __init__(self, shift: np.typing.ArrayLike):
14+
self.shift = np.array(shift)
15+
16+
@classmethod
17+
def from_config(cls, config: dict, custom_objects=None) -> "ElementwiseTransform":
18+
return cls(shift=deserialize(config["shift"]))
19+
20+
def get_config(self) -> dict:
21+
return {"shift": serialize(self.shift)}
22+
23+
def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
24+
return data + self.shift
25+
26+
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
27+
return data - self.shift

bayesflow/diagnostics/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .posterior_contraction import posterior_contraction
33
from .root_mean_squared_error import root_mean_squared_error
44
from .expected_calibration_error import expected_calibration_error
5+
from .classifier_two_sample_test import classifier_two_sample_test

0 commit comments

Comments
 (0)