Skip to content

Commit 0ae81f7

Browse files
committed
test: check that normalizing flows are reproducible
1 parent fc199cb commit 0ae81f7

File tree

1 file changed

+33
-8
lines changed

1 file changed

+33
-8
lines changed

tests/test_pymc.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_pymc_var_names(backend, gradient_backend):
270270

271271
@pytest.mark.pymc
272272
@pytest.mark.flow
273-
@pytest.mark.parametrize("kind", ["masked", "subset"])
273+
@pytest.mark.parametrize("kind", ["masked"])
274274
def test_normalizing_flow(kind):
275275
with pm.Model() as model:
276276
pm.HalfNormal("x", shape=2)
@@ -280,7 +280,7 @@ def test_normalizing_flow(kind):
280280
).with_transform_adapt(
281281
verbose=True,
282282
coupling_type=kind,
283-
num_layers=4,
283+
num_layers=2,
284284
)
285285
trace = nutpie.sample(
286286
compiled,
@@ -290,13 +290,38 @@ def test_normalizing_flow(kind):
290290
seed=1,
291291
draws=2000,
292292
)
293-
draws = trace.posterior.x.isel(x_dim_0=0, chain=0)
294-
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
295-
assert kstest.pvalue > 0.01
296293

297-
draws = trace.posterior.x.isel(x_dim_0=1, chain=0)
298-
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
299-
assert kstest.pvalue > 0.01
294+
compiled = nutpie.compile_pymc_model(
295+
model, backend="jax", gradient_backend="jax"
296+
).with_transform_adapt(
297+
verbose=True,
298+
coupling_type=kind,
299+
num_layers=2,
300+
)
301+
trace2 = nutpie.sample(
302+
compiled,
303+
chains=1,
304+
transform_adapt=True,
305+
window_switch_freq=128,
306+
seed=1,
307+
draws=2000,
308+
)
309+
draws1 = trace.posterior.x
310+
draws2 = trace2.posterior.x
311+
312+
# Check that the two draws are the same
313+
assert np.allclose(draws1, draws2)
314+
315+
# Compare to precompute values to make sure it is reproducible
316+
# accross architectures
317+
expected = np.array([
318+
[1.81033486, 1.18735544],
319+
[0.12551686, 0.04161655],
320+
[1.07813544, 0.12578679],
321+
[0.71503155, 0.37380833],
322+
[0.83237662, 0.67041153]
323+
])
324+
assert np.allclose(draws1.isel(chain=0, draw=slice(0, 5)), expected, atol=1e-5)
300325

301326

302327
@pytest.mark.pymc

0 commit comments

Comments
 (0)