Skip to content

Commit 2ddf784

Browse files
committed
test: check that normalizing flows are reproducible
1 parent f735356 commit 2ddf784

File tree

2 files changed

+64
-27
lines changed

2 files changed

+64
-27
lines changed

python/nutpie/normalizing_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1916,7 +1916,7 @@ def make_flow(
19161916
diag = jnp.sqrt(pos_std / grad_std)
19171917
mean = positions.mean(0) + gradients.mean(0) * diag * diag
19181918

1919-
key = jax.random.PRNGKey(seed % (2**63))
1919+
key = jax.random.key(seed % (2**63), impl="threefry2x32")
19201920

19211921
diag_param = Parameterize(
19221922
lambda x: x + jnp.sqrt(1 + x**2),

tests/test_pymc.py

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
import pymc as pm
1010
import pytest
11-
from scipy import stats
1211

1312
import nutpie
1413
import nutpie.compile_pymc
@@ -270,38 +269,76 @@ def test_pymc_var_names(backend, gradient_backend):
270269

271270
@pytest.mark.pymc
272271
@pytest.mark.flow
273-
@pytest.mark.parametrize("kind", ["masked", "subset"])
272+
@pytest.mark.parametrize("kind", ["masked"])
274273
def test_normalizing_flow(kind):
275-
with pm.Model() as model:
276-
pm.HalfNormal("x", shape=2)
274+
import jax
277275

278-
compiled = nutpie.compile_pymc_model(
279-
model, backend="jax", gradient_backend="jax"
280-
).with_transform_adapt(
281-
verbose=True,
282-
coupling_type=kind,
283-
num_layers=4,
284-
)
285-
trace = nutpie.sample(
286-
compiled,
287-
chains=1,
288-
transform_adapt=True,
289-
window_switch_freq=128,
290-
seed=1,
291-
draws=2000,
292-
)
293-
draws = trace.posterior.x.isel(x_dim_0=0, chain=0)
294-
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
295-
assert kstest.pvalue > 0.01
276+
old_x64 = jax.config.update("jax_enable_x64", True)
277+
278+
try:
279+
with pm.Model() as model:
280+
pm.HalfNormal("x", shape=2)
296281

297-
draws = trace.posterior.x.isel(x_dim_0=1, chain=0)
298-
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
299-
assert kstest.pvalue > 0.01
282+
compiled = nutpie.compile_pymc_model(
283+
model, backend="jax", gradient_backend="jax"
284+
).with_transform_adapt(
285+
verbose=True,
286+
coupling_type=kind,
287+
num_layers=2,
288+
)
289+
trace = nutpie.sample(
290+
compiled,
291+
chains=1,
292+
transform_adapt=True,
293+
window_switch_freq=128,
294+
seed=1,
295+
draws=2000,
296+
)
297+
298+
compiled = nutpie.compile_pymc_model(
299+
model, backend="jax", gradient_backend="jax"
300+
).with_transform_adapt(
301+
verbose=True,
302+
coupling_type=kind,
303+
num_layers=2,
304+
)
305+
trace2 = nutpie.sample(
306+
compiled,
307+
chains=1,
308+
transform_adapt=True,
309+
window_switch_freq=128,
310+
seed=1,
311+
draws=2000,
312+
)
313+
draws1 = trace.posterior.x
314+
draws2 = trace2.posterior.x
315+
316+
# Check that the two draws are the same
317+
np.testing.assert_allclose(draws1, draws2)
318+
319+
# Compare to precompute values to make sure it is reproducible
320+
# accross architectures
321+
expected = np.array(
322+
[
323+
[1.81033486, 1.18735544],
324+
[0.12551686, 0.04161655],
325+
[1.07813544, 0.12578679],
326+
[0.71503155, 0.37380833],
327+
[0.83237662, 0.67041153],
328+
]
329+
)
330+
print(expected)
331+
np.testing.assert_allclose(
332+
draws1.isel(chain=0, draw=slice(0, 5)).values, expected, atol=1e-5
333+
)
334+
finally:
335+
# Restore the original config
336+
jax.config.update("jax_enable_x64", old_x64)
300337

301338

302339
@pytest.mark.pymc
303340
@pytest.mark.flow
304-
@pytest.mark.parametrize("kind", ["masked", "subset"])
341+
@pytest.mark.parametrize("kind", ["masked"])
305342
def test_normalizing_flow_1d(kind):
306343
with pm.Model() as model:
307344
pm.HalfNormal("x")

0 commit comments

Comments
 (0)