|
8 | 8 | import numpy as np |
9 | 9 | import pymc as pm |
10 | 10 | import pytest |
11 | | -from scipy import stats |
12 | 11 |
|
13 | 12 | import nutpie |
14 | 13 | import nutpie.compile_pymc |
@@ -270,38 +269,76 @@ def test_pymc_var_names(backend, gradient_backend): |
270 | 269 |
|
271 | 270 | @pytest.mark.pymc |
272 | 271 | @pytest.mark.flow |
273 | | -@pytest.mark.parametrize("kind", ["masked", "subset"]) |
| 272 | +@pytest.mark.parametrize("kind", ["masked"]) |
274 | 273 | def test_normalizing_flow(kind): |
275 | | - with pm.Model() as model: |
276 | | - pm.HalfNormal("x", shape=2) |
| 274 | + import jax |
277 | 275 |
|
278 | | - compiled = nutpie.compile_pymc_model( |
279 | | - model, backend="jax", gradient_backend="jax" |
280 | | - ).with_transform_adapt( |
281 | | - verbose=True, |
282 | | - coupling_type=kind, |
283 | | - num_layers=4, |
284 | | - ) |
285 | | - trace = nutpie.sample( |
286 | | - compiled, |
287 | | - chains=1, |
288 | | - transform_adapt=True, |
289 | | - window_switch_freq=128, |
290 | | - seed=1, |
291 | | - draws=2000, |
292 | | - ) |
293 | | - draws = trace.posterior.x.isel(x_dim_0=0, chain=0) |
294 | | - kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) |
295 | | - assert kstest.pvalue > 0.01 |
| 276 | + old_x64 = jax.config.update("jax_enable_x64", True) |
| 277 | + |
| 278 | + try: |
| 279 | + with pm.Model() as model: |
| 280 | + pm.HalfNormal("x", shape=2) |
296 | 281 |
|
297 | | - draws = trace.posterior.x.isel(x_dim_0=1, chain=0) |
298 | | - kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) |
299 | | - assert kstest.pvalue > 0.01 |
| 282 | + compiled = nutpie.compile_pymc_model( |
| 283 | + model, backend="jax", gradient_backend="jax" |
| 284 | + ).with_transform_adapt( |
| 285 | + verbose=True, |
| 286 | + coupling_type=kind, |
| 287 | + num_layers=2, |
| 288 | + ) |
| 289 | + trace = nutpie.sample( |
| 290 | + compiled, |
| 291 | + chains=1, |
| 292 | + transform_adapt=True, |
| 293 | + window_switch_freq=128, |
| 294 | + seed=1, |
| 295 | + draws=2000, |
| 296 | + ) |
| 297 | + |
| 298 | + compiled = nutpie.compile_pymc_model( |
| 299 | + model, backend="jax", gradient_backend="jax" |
| 300 | + ).with_transform_adapt( |
| 301 | + verbose=True, |
| 302 | + coupling_type=kind, |
| 303 | + num_layers=2, |
| 304 | + ) |
| 305 | + trace2 = nutpie.sample( |
| 306 | + compiled, |
| 307 | + chains=1, |
| 308 | + transform_adapt=True, |
| 309 | + window_switch_freq=128, |
| 310 | + seed=1, |
| 311 | + draws=2000, |
| 312 | + ) |
| 313 | + draws1 = trace.posterior.x |
| 314 | + draws2 = trace2.posterior.x |
| 315 | + |
| 316 | + # Check that the two draws are the same |
| 317 | + np.testing.assert_allclose(draws1, draws2) |
| 318 | + |
| 319 | + # Compare to precompute values to make sure it is reproducible |
| 320 | + # accross architectures |
| 321 | + expected = np.array( |
| 322 | + [ |
| 323 | + [1.81033486, 1.18735544], |
| 324 | + [0.12551686, 0.04161655], |
| 325 | + [1.07813544, 0.12578679], |
| 326 | + [0.71503155, 0.37380833], |
| 327 | + [0.83237662, 0.67041153], |
| 328 | + ] |
| 329 | + ) |
| 330 | + print(expected) |
| 331 | + np.testing.assert_allclose( |
| 332 | + draws1.isel(chain=0, draw=slice(0, 5)).values, expected, atol=1e-5 |
| 333 | + ) |
| 334 | + finally: |
| 335 | + # Restore the original config |
| 336 | + jax.config.update("jax_enable_x64", old_x64) |
300 | 337 |
|
301 | 338 |
|
302 | 339 | @pytest.mark.pymc |
303 | 340 | @pytest.mark.flow |
304 | | -@pytest.mark.parametrize("kind", ["masked", "subset"]) |
| 341 | +@pytest.mark.parametrize("kind", ["masked"]) |
305 | 342 | def test_normalizing_flow_1d(kind): |
306 | 343 | with pm.Model() as model: |
307 | 344 | pm.HalfNormal("x") |
|
0 commit comments