@@ -270,7 +270,7 @@ def test_pymc_var_names(backend, gradient_backend):
270270
271271@pytest .mark .pymc
272272@pytest .mark .flow
273- @pytest .mark .parametrize ("kind" , ["masked" , "subset" ])
273+ @pytest .mark .parametrize ("kind" , ["masked" ])
274274def test_normalizing_flow (kind ):
275275 with pm .Model () as model :
276276 pm .HalfNormal ("x" , shape = 2 )
@@ -280,7 +280,7 @@ def test_normalizing_flow(kind):
280280 ).with_transform_adapt (
281281 verbose = True ,
282282 coupling_type = kind ,
283- num_layers = 4 ,
283+ num_layers = 2 ,
284284 )
285285 trace = nutpie .sample (
286286 compiled ,
@@ -290,13 +290,38 @@ def test_normalizing_flow(kind):
290290 seed = 1 ,
291291 draws = 2000 ,
292292 )
293- draws = trace .posterior .x .isel (x_dim_0 = 0 , chain = 0 )
294- kstest = stats .ks_1samp (draws , stats .halfnorm .cdf )
295- assert kstest .pvalue > 0.01
296293
297- draws = trace .posterior .x .isel (x_dim_0 = 1 , chain = 0 )
298- kstest = stats .ks_1samp (draws , stats .halfnorm .cdf )
299- assert kstest .pvalue > 0.01
294+ compiled = nutpie .compile_pymc_model (
295+ model , backend = "jax" , gradient_backend = "jax"
296+ ).with_transform_adapt (
297+ verbose = True ,
298+ coupling_type = kind ,
299+ num_layers = 2 ,
300+ )
301+ trace2 = nutpie .sample (
302+ compiled ,
303+ chains = 1 ,
304+ transform_adapt = True ,
305+ window_switch_freq = 128 ,
306+ seed = 1 ,
307+ draws = 2000 ,
308+ )
309+ draws1 = trace .posterior .x
310+ draws2 = trace2 .posterior .x
311+
312+ # Check that the two draws are the same
313+ assert np .allclose (draws1 , draws2 )
314+
315+ # Compare to precompute values to make sure it is reproducible
316+ # accross architectures
317+ expected = np .array ([
318+ [1.81033486 , 1.18735544 ],
319+ [0.12551686 , 0.04161655 ],
320+ [1.07813544 , 0.12578679 ],
321+ [0.71503155 , 0.37380833 ],
322+ [0.83237662 , 0.67041153 ]
323+ ])
324+ assert np .allclose (draws1 .isel (chain = 0 , draw = slice (0 , 5 )), expected , atol = 1e-5 )
300325
301326
302327@pytest .mark .pymc
0 commit comments