Skip to content

Commit 49dfa2a

Browse files
committed
fix: fix normalizing flows for 1d posteriors
1 parent 32985e8 commit 49dfa2a

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

python/nutpie/transform_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ def update(self, seed, positions, gradients, logps):
666666

667667
if self._debug_save_bijection:
668668
_BIJECTION_TRACE.append(
669-
(self.index, fit, (positions, gradients, logps))
669+
(self.index, base, (positions, gradients, logps))
670670
)
671671
return
672672

tests/test_pymc.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,33 @@ def test_normalizing_flow(kind):
299299
assert kstest.pvalue > 0.01
300300

301301

302+
@pytest.mark.pymc
303+
@pytest.mark.flow
304+
@pytest.mark.parametrize("kind", ["masked", "subset"])
305+
def test_normalizing_flow_1d(kind):
306+
with pm.Model() as model:
307+
pm.HalfNormal("x")
308+
309+
compiled = nutpie.compile_pymc_model(
310+
model, backend="jax", gradient_backend="jax"
311+
).with_transform_adapt(
312+
num_diag_windows=6,
313+
verbose=True,
314+
coupling_type=kind,
315+
)
316+
trace = nutpie.sample(
317+
compiled,
318+
chains=1,
319+
transform_adapt=True,
320+
window_switch_freq=150,
321+
tune=600,
322+
seed=1,
323+
)
324+
draws = trace.posterior.x.isel(chain=0)
325+
kstest = stats.ks_1samp(draws, stats.halfnorm.cdf)
326+
assert kstest.pvalue > 0.01
327+
328+
302329
@pytest.mark.pymc
303330
@pytest.mark.parametrize(
304331
("backend", "gradient_backend"),

0 commit comments

Comments
 (0)