From 18e5a07583b38c3659d8dc580ec58a6b9356c564 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 28 May 2025 14:33:55 +0200 Subject: [PATCH] fix(numba): non-contiguous shared variable Shared variables (ie pm.Data) that were non-contiguous could lead to incorrect results in the pymc numba backend. We now ensure that they are always c-contiguous by copying if they are not. --- python/nutpie/compile_pymc.py | 4 ++-- tests/test_pymc.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 7cfe121..c714146 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -131,7 +131,7 @@ def with_data(self, **updates): if name not in shared_data: raise KeyError(f"Unknown shared variable: {name}") old_val = shared_data[name] - new_val = np.asarray(new_val, dtype=old_val.dtype).copy() + new_val = np.array(new_val, dtype=old_val.dtype, order="C", copy=True) new_val.flags.writeable = False if old_val.ndim != new_val.ndim: raise ValueError( @@ -256,7 +256,7 @@ def _compile_pymc_model_numba( for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]: if val.name in shared_data and val not in seen: raise ValueError(f"Shared variables must have unique names: {val.name}") - shared_data[val.name] = val.get_value() + shared_data[val.name] = np.array(val.get_value(), order="C", copy=True) shared_vars[val.name] = val seen.add(val) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 51e2159..938f36c 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -31,6 +31,38 @@ def test_pymc_model(backend, gradient_backend): trace.posterior.a # noqa: B018 +@pytest.mark.pymc +def test_order_shared(): + a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + with pm.Model() as model: + a = pm.Data("a", np.copy(a_val, order="C")) + b = pm.Normal("b", shape=(2, 5)) + pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) + + compiled = nutpie.compile_pymc_model(model, backend="numba") + trace = nutpie.sample(compiled) + np.testing.assert_allclose( + ( + trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] + ).sum(-1), + trace.posterior.c.values, + ) + + with pm.Model() as model: + a = pm.Data("a", np.copy(a_val, order="F")) + b = pm.Normal("b", shape=(2, 5)) + pm.Deterministic("c", (a[:, None, :] * b[:, :, None]).sum(-1)) + + compiled = nutpie.compile_pymc_model(model, backend="numba") + trace = nutpie.sample(compiled) + np.testing.assert_allclose( + ( + trace.posterior.b.values[:, :, :, :, None] * a_val[None, None, :, None, :] + ).sum(-1), + trace.posterior.c.values, + ) + + @pytest.mark.pymc @parameterize_backends def test_low_rank(backend, gradient_backend):