Skip to content

Commit 8c9e6b7

Browse files
committed
Add basic workflow tests and reduce coupling depth for tests
1 parent 129792a commit 8c9e6b7

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

tests/test_networks/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ def flow_matching_subnet(subnet):
3838
def coupling_flow():
3939
from bayesflow.networks import CouplingFlow
4040

41-
return CouplingFlow()
41+
return CouplingFlow(depth=2)
4242

4343

4444
@pytest.fixture()
4545
def coupling_flow_subnet(subnet):
4646
from bayesflow.networks import CouplingFlow
4747

48-
return CouplingFlow(subnet=subnet)
48+
return CouplingFlow(depth=2, subnet=subnet)
4949

5050

5151
@pytest.fixture()

tests/test_workflows/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
4+
@pytest.fixture()
5+
def inference_network():
6+
from bayesflow.networks import CouplingFlow
7+
8+
return CouplingFlow(depth=2)
9+
10+
11+
@pytest.fixture()
12+
def summary_network():
13+
from bayesflow.networks import TimeSeriesTransformer
14+
15+
return TimeSeriesTransformer(embed_dims=(8, 8), mlp_widths=(32, 32), mlp_depths=(1, 1))
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import bayesflow as bf
2+
3+
4+
def test_classifier_two_sample_test(inference_network, summary_network):
5+
workflow = bf.BasicWorkflow(
6+
inference_network=inference_network,
7+
summary_network=summary_network,
8+
inference_variables=["parameters"],
9+
summary_variables=["observables"],
10+
simulator=bf.simulators.SIR(),
11+
)
12+
13+
history = workflow.fit_online(epochs=2, batch_size=32, num_batches_per_epoch=2)
14+
plots = workflow.plot_default_diagnostics(test_data=50, num_samples=50)
15+
metrics = workflow.compute_default_diagnostics(test_data=50, num_samples=50, variable_names=["p1", "p2"])
16+
17+
assert "loss" in list(history.history.keys())
18+
assert len(history.history["loss"]) == 2
19+
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "z_score_contraction"]
20+
assert list(metrics.columns) == ["p1", "p2"]
21+
assert metrics.values.shape == (3, 2)

0 commit comments

Comments
 (0)