Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@

All notable changes to this project will be documented in this file.

## [0.15.1] - 2025-05-28

### Bug Fixes

- Incorrect results with non-contiguous shared variable (Adrian Seyboldt)

- Allow upper case backend string (Adrian Seyboldt)

- Allow data named x with unfrozen model (Adrian Seyboldt)


### Styling

- Fix small typing issue (Adrian Seyboldt)


## [0.15.0] - 2025-05-27

### Bug Fixes
Expand Down Expand Up @@ -39,7 +55,9 @@ All notable changes to this project will be documented in this file.

- Add entries to gitignore (Adrian Seyboldt)

- Bump dependencies (Adrian Seyboldt)
- Update dependencies (Adrian Seyboldt)

- Update changelog (Adrian Seyboldt)


### Styling
Expand Down
10 changes: 5 additions & 5 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "nutpie"
version = "0.15.0"
version = "0.15.1"
authors = [
"Adrian Seyboldt <adrian.seyboldt@gmail.com>",
"PyMC Developers <pymc.devs@gmail.com>",
Expand Down
17 changes: 11 additions & 6 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def intrinsic(f):


def _rv_dict_to_flat_array_wrapper(
fn: Callable[[SeedType], dict[str, np.ndarray]],
fn: Callable[[SeedType | None], dict[str, np.ndarray]],
names: list[str],
shapes: list[tuple[int]],
) -> Callable[[SeedType], np.ndarray]:
Expand Down Expand Up @@ -61,7 +61,7 @@ def _rv_dict_to_flat_array_wrapper(
"""

@wraps(fn)
def seeded_array_fn(seed: SeedType = None):
def seeded_array_fn(seed: SeedType | None = None):
initial_value_dict = fn(seed)
total_size = sum(np.prod(shape).astype(int) for shape in shapes)
flat_array = np.empty(total_size, dtype="float64", order="C")
Expand Down Expand Up @@ -406,8 +406,8 @@ def logp_fn_jax_grad(x, *shared):
seen.add(val)

def make_logp_func():
def logp(x, **shared):
logp, grad = logp_fn(x, *[shared[name] for name in logp_shared_names])
def logp(_x, **shared):
logp, grad = logp_fn(_x, *[shared[name] for name in logp_shared_names])
return float(logp), np.asarray(grad, dtype="float64", order="C")

return logp
Expand All @@ -418,8 +418,8 @@ def logp(x, **shared):

def make_expand_func(seed1, seed2, chain):
# TODO handle seeds
def expand(x, **shared):
values = expand_fn(x, *[shared[name] for name in expand_shared_names])
def expand(_x, **shared):
values = expand_fn(_x, *[shared[name] for name in expand_shared_names])
return {
name: np.asarray(val, order="C", dtype=dtype).ravel()
for name, val, dtype in zip(names, values, dtypes, strict=True)
Expand Down Expand Up @@ -499,6 +499,11 @@ def compile_pymc_model(
"and restart your kernel in case you are in an interactive session."
)

if gradient_backend is not None:
gradient_backend = gradient_backend.lower() # type: ignore[assignment]
if backend is not None:
backend = backend.lower() # type: ignore[assignment]

from pymc.model.transform.optimization import freeze_dims_and_data
from pymc.initial_point import make_initial_point_fn

Expand Down
15 changes: 15 additions & 0 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def test_pymc_model(backend, gradient_backend):
trace.posterior.a # noqa: B018


@pytest.mark.pymc
@parameterize_backends
def test_name_x(backend, gradient_backend):
with pm.Model() as model:
x = pm.Data("x", 1.0)
a = pm.Normal("a", mu=x)
pm.Deterministic("z", x * a)

compiled = nutpie.compile_pymc_model(
model, backend=backend, gradient_backend=gradient_backend, freeze_model=False
)
trace = nutpie.sample(compiled, chains=1)
trace.posterior.a # noqa: B018


@pytest.mark.pymc
def test_order_shared():
a_val = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
Expand Down