From 76439f33e7b7f11d616efb7e9b675c29b15dbcdb Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Mar 2025 19:05:40 +0100 Subject: [PATCH 01/10] feat: Add masked coupling flow --- python/nutpie/normalizing_flow.py | 549 ++++++++++++++++++++++++++--- python/nutpie/transform_adapter.py | 12 +- tests/test_pymc.py | 4 +- 3 files changed, 513 insertions(+), 52 deletions(-) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index e92f7b0..9d043a9 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -1,7 +1,8 @@ -from typing import ClassVar, Union, Literal, Callable +from typing import Any, ClassVar, Union, Literal, Callable import math import itertools +from flowjax.bijections.bijection import AbstractBijection from flowjax.bijections.coupling import get_ravelled_pytree_constructor from flowjax.utils import arraylike_to_array import jax @@ -10,9 +11,9 @@ from flowjax import bijections import flowjax.distributions import flowjax.flows -from jaxtyping import Array, ArrayLike +from jaxtyping import Array, ArrayLike, PyTree import numpy as np -from paramax import NonTrainable, Parameterize +from paramax import NonTrainable, Parameterize, unwrap from equinox.nn import Linear from paramax.wrappers import AbstractUnwrappable @@ -446,6 +447,158 @@ def inverse_and_log_det(self, y: Array, condition: Array | None = None): return x, -jnp.log(scale) +class MaskedVmap(AbstractBijection): + bijection: AbstractBijection + in_axes: tuple + axis_size: int + cond_shape: tuple[int, ...] | None + mask: Array + + def __init__( + self, + bijection: AbstractBijection, + mask: Array, + *, + in_axes: PyTree | None | int | Callable = None, + axis_size: int | None = None, + in_axes_condition: int | None = None, + ): + if in_axes is not None and axis_size is not None: + raise ValueError("Cannot specify both in_axes and axis_size.") + + if axis_size is None: + if in_axes is None: + raise ValueError("Either axis_size or in_axes must be provided.") + # _check_no_unwrappables(in_axes) + from flowjax.bijections.jax_transforms import _infer_axis_size_from_params + + axis_size = _infer_axis_size_from_params(unwrap(bijection), in_axes) + + self.in_axes = (0, in_axes, 0, in_axes_condition) + self.bijection = bijection + self.axis_size = axis_size + self.cond_shape = self.get_cond_shape(in_axes_condition) + self.mask = mask + + def vmap(self, f: Callable): + return eqx.filter_vmap(f, in_axes=self.in_axes, axis_size=self.axis_size) + + def transform_and_log_det(self, x, condition=None): + def _transform_and_log_det(mask, bijection, x, condition): + y, det = bijection.transform_and_log_det(x, condition) + return jnp.where(mask, y, x), jnp.where(mask, det, jnp.zeros(())) + + y, log_det = self.vmap(_transform_and_log_det)( + self.mask, self.bijection, x, condition + ) + return y, jnp.sum(log_det) + + def inverse_and_log_det(self, y, condition=None): + def _inverse_and_log_det(mask, bijection, y, condition): + x, det = bijection.inverse_and_log_det(y, condition) + return jnp.where(mask, x, y), jnp.where(mask, det, jnp.zeros(())) + + x, log_det = self.vmap(_inverse_and_log_det)( + self.mask, self.bijection, y, condition + ) + return x, jnp.sum(log_det) + + @property + def shape(self): + return (self.axis_size, *self.bijection.shape) + + def get_cond_shape(self, cond_ax): + if self.bijection.cond_shape is None or cond_ax is None: + return self.bijection.cond_shape + return ( + *self.bijection.cond_shape[:cond_ax], + self.axis_size, + *self.bijection.cond_shape[cond_ax:], + ) + + +class Mask(eqx.Module): + mask: Array + + def __init__(self, mask: Array): + assert mask.dtype == jnp.bool_ + self.mask = mask + + def __call__(self, x: Array, *, key=None) -> Array: + return x * self.mask + + +class Scan(AbstractBijection): + """Repeatedly apply the same bijection with different parameter values. + + Internally, uses `jax.lax.scan` to reduce compilation time. Often it is convenient + to construct these using ``equinox.filter_vmap``. + + Args: + bijection: A bijection, in which the arrays leaves have an additional leading + axis to scan over. It is often can convenient to create compatible + bijections with ``equinox.filter_vmap``. + + Example: + Below is equivilent to ``Chain([Affine(p) for p in params])``. + + .. doctest:: + + >>> from flowjax.bijections import Scan, Affine + >>> import jax.numpy as jnp + >>> import equinox as eqx + >>> params = jnp.ones((3, 2)) + >>> affine = eqx.filter_vmap(Affine)(params) + >>> affine = Scan(affine) + """ + + bijection: AbstractBijection + filter_spec: Any = None + + def transform_and_log_det(self, x, condition=None): + def step(carry, bijection): + x, log_det = carry + y, log_det_i = bijection.transform_and_log_det(x, condition) + return ((y, log_det + log_det_i.sum()), None) + + (y, log_det), _ = _filter_scan( + step, (x, 0), self.bijection, filter_spec=self.filter_spec + ) + return y, log_det + + def inverse_and_log_det(self, y, condition=None): + def step(carry, bijection): + y, log_det = carry + x, log_det_i = bijection.inverse_and_log_det(y, condition) + return ((x, log_det + log_det_i.sum()), None) + + (y, log_det), _ = _filter_scan( + step, (y, 0), self.bijection, reverse=True, filter_spec=self.filter_spec + ) + return y, log_det + + @property + def shape(self): + return self.bijection.shape + + @property + def cond_shape(self): + return self.bijection.cond_shape + + +def _filter_scan(f, init, xs, *, reverse=False, filter_spec=None): + if filter_spec is None: + filter_spec = eqx.is_array + params, static = eqx.partition(xs, filter_spec=filter_spec) + + def _scan_fn(carry, x): + module = eqx.combine(x, static) + carry, y = f(carry, module) + return carry, y + + return jax.lax.scan(_scan_fn, init, params, reverse=reverse) + + class Coupling(bijections.AbstractBijection): """Coupling layer implementation (https://arxiv.org/abs/1605.08803). @@ -565,6 +718,115 @@ def _flat_params_to_transformer(self, params: Array): return transformer +class MaskedCoupling(bijections.AbstractBijection): + """Coupling layer implementation (https://arxiv.org/abs/1605.08803). + + Args: + key: Jax key + transformer: Unconditional bijection with shape () to be parameterised by the + conditioner neural netork. Parameters wrapped with ``NonTrainable`` + are excluded from being parameterized. + untransformed_dim: Number of untransformed conditioning variables (e.g. dim//2). + dim: Total dimension. + cond_dim: Dimension of additional conditioning variables. Defaults to None. + nn_width: Neural network hidden layer width. + nn_depth: Neural network hidden layer size. + nn_activation: Neural network activation function. Defaults to jnn.relu. + """ + + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None + untransformed_mask: Array + dim: int + transformer_constructor: Callable + requires_vmap: bool + conditioner: eqx.nn.MLP | eqx.Module + + @classmethod + def conditioner_output_size(cls, dim, transformer): + constructor, num_params = get_ravelled_pytree_constructor( + transformer, + filter_spec=eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + return num_params * dim + + def __init__( + self, + key, + *, + transformer: bijections.AbstractBijection, + untransformed_mask: Array, + dim: int, + nn_width: int, + nn_depth: int, + nn_activation: Callable = jax.nn.relu, + conditioner: eqx.Module | None = None, + ): + if transformer.cond_shape is not None: + raise ValueError( + "Only unconditional transformers are supported.", + ) + + constructor, num_params = get_ravelled_pytree_constructor( + transformer, + filter_spec=eqx.is_inexact_array, + is_leaf=lambda leaf: isinstance(leaf, NonTrainable), + ) + + assert transformer.shape == () + self.requires_vmap = True + conditioner_output_size = num_params * dim + + self.transformer_constructor = constructor + self.dim = dim + self.shape = (dim,) + self.cond_shape = None + self.untransformed_mask = untransformed_mask + + if conditioner is None: + self.conditioner = eqx.nn.Sequential( + [ + Mask(untransformed_mask), + eqx.nn.MLP( + in_size=dim, + out_size=conditioner_output_size, + width_size=nn_width, + depth=nn_depth, + activation=nn_activation, + key=key, + ), + ] + ) + else: + self.conditioner = eqx.nn.Sequential( + [ + Mask(untransformed_mask), + conditioner, + ] + ) + + def transform_and_log_det(self, x, condition=None): + transformer_params = self.conditioner(x) + transformer = self._flat_params_to_transformer(transformer_params) + return transformer.transform_and_log_det(x) + + def inverse_and_log_det(self, y, condition=None): + transformer_params = self.conditioner(y) + transformer = self._flat_params_to_transformer(transformer_params) + return transformer.inverse_and_log_det(y) + + def _flat_params_to_transformer(self, params: Array): + """Reshape to dim X params_per_dim, then vmap.""" + assert self.requires_vmap + + transformer_params = jnp.reshape(params, (self.dim, -1)) + transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params) + return MaskedVmap( + transformer, ~self.untransformed_mask, in_axes=eqx.if_array(0) + ) + + def make_mvscale(key, n_dim, size, randomize_base=False): def make_single_hh(key, idx): key1, key2 = jax.random.split(key) @@ -605,7 +867,7 @@ def make_single_hh(key, idx): ) -def make_elemwise_trafo(key, n_dim, *, count=1): +def make_elemwise_trafo(key, n_dim, *, count=1, vmap=True): def make_elemwise(key, loc): key1, key2 = jax.random.split(key) scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) @@ -635,11 +897,16 @@ def make(key): key, keys = keys[0], keys[1:] loc = jax.random.normal(key=key, shape=(count,)) * 2 loc = loc - loc.mean() + if count == 1: + return make_elemwise(key, loc[0]) return bijections.Chain([make_elemwise(key, mu) for key, mu in zip(keys, loc)]) - keys = jax.random.split(key, n_dim) - make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys) - return bijections.Vmap(make_affine, in_axes=eqx.if_array(0)) + if vmap: + keys = jax.random.split(key, n_dim) + make_affine = eqx.filter_vmap(make, axis_size=n_dim)(keys) + return bijections.Vmap(make_affine, in_axes=eqx.if_array(0)) + else: + return make(key) def make_coupling(key, dim, n_untransformed, *, inner_mvscale=False, **kwargs): @@ -692,67 +959,159 @@ def make_mlp(out_size): ) -def make_flow( - seed, - positions, - gradients, +def make_flow_scan( + key, + n_dim, *, zero_init=False, - householder_layer=False, - dct_layer=False, - untransformed_dim: int | list[int | None] | None = None, n_layers, nn_width=None, nn_depth=None, + n_embed=None, + n_deembed=None, ): - from flowjax import bijections + dim = n_dim - positions = np.array(positions) - gradients = np.array(gradients) + if nn_width is None: + nn_width = 32 + if n_embed is None: + n_embed = 2 * nn_width + if n_deembed is None: + n_deembed = 2 * nn_width + if nn_depth is None: + nn_depth = 1 - if len(positions) == 0: - return + # Just to get at the size + transformer = AsymmetricAffine() + size = MaskedCoupling.conditioner_output_size(dim, transformer) - n_draws, n_dim = positions.shape + key, key1 = jax.random.split(key) + embed = eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32) + key, key1 = jax.random.split(key) + embed_back = eqx.nn.Linear(n_deembed, size, key=key1, dtype=jnp.float32) - if n_dim < 2: - n_layers = 0 + rng = np.random.default_rng(42) # TODO + order, counts = _generate_permutations(rng, dim, n_layers) + mask = order == 0 + mask[...] = False + for i in range(len(mask)): + mask[i, order[i, : counts[i]]] = True - assert positions.shape == gradients.shape + def make_transformer(): + scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) - if n_draws == 0: - raise ValueError("No draws") - elif n_draws == 1: - assert np.all(gradients != 0) - diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-5, 1e5) - assert np.isfinite(diag).all() - mean = jnp.zeros_like(diag) - else: - pos_std = np.clip(positions.std(0), 1e-8, 1e8) - grad_std = np.clip(gradients.std(0), 1e-8, 1e8) - diag = jnp.sqrt(pos_std / grad_std) - mean = positions.mean(0) + gradients.mean(0) * diag * diag + affine = AsymmetricAffine( + jnp.zeros(()), + jnp.ones(()), + jnp.ones(()), + ) - key = jax.random.PRNGKey(seed % (2**63)) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) - diag_param = Parameterize( - lambda x: x + jnp.sqrt(1 + x**2), - (diag**2 - 1) / (2 * diag), + return bijections.Invert(affine) + + def make_mvscale(key, n_dim): + params = jax.random.normal(key, (n_dim,)) + params = params / jnp.linalg.norm(params) + return MvScale(params) + + def make_layer(key, mask, embed, embed_back): + key1, key2, key3, key4 = jax.random.split(key, 4) + transformer = make_transformer() + + conditioner = eqx.nn.Sequential( + [ + embed, + eqx.nn.MLP( + n_embed, + n_deembed, + width_size=nn_width, + depth=nn_depth, + key=key2, + dtype=jnp.float32, + activation=_NN_ACTIVATION, + ), + embed_back, + ] + ) + + coupling = MaskedCoupling( + key=key3, + transformer=transformer, + untransformed_mask=mask, + dim=dim, + conditioner=conditioner, + nn_width=nn_width, + nn_depth=nn_depth, + ) + + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + + mvscale = make_mvscale(key4, dim) + return bijections.Chain([coupling, mvscale]) + + keys = jax.random.split(key, n_layers) + + base = make_layer(key, mask[0], embed, embed_back) + out_axes = eqx.tree_at( + lambda tree: tree.bijections[0].conditioner.layers[1].layers[0], + pytree=base, + replace=None, ) - diag_affine = bijections.Affine(mean, diag) - diag_affine = eqx.tree_at( - where=lambda aff: aff.scale, - pytree=diag_affine, - replace=diag_param, + out_axes = eqx.tree_at( + lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1], + pytree=out_axes, + replace=None, ) + out_axes = jax.tree.map(lambda leaf: eqx.if_array(0)(leaf), out_axes) - flows = [ - diag_affine, - ] + vectorized = eqx.filter_vmap( + make_layer, in_axes=(0, 0, None, None), out_axes=out_axes + ) - if n_layers == 0: - return bijections.Chain(flows) + vectorize = jax.tree.map(lambda leaf: eqx.is_array(leaf), base) + vectorize = eqx.tree_at( + lambda tree: tree.bijections[0].conditioner.layers[1].layers[0], + pytree=vectorize, + replace=False, + ) + vectorize = eqx.tree_at( + lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1], + pytree=vectorize, + replace=False, + ) + + return Scan( + vectorized(keys, mask, embed, embed_back), + filter_spec=vectorize, + ) + +def make_flow_loop( + key, + n_dim, + *, + zero_init=False, + householder_layer=False, + dct_layer=False, + untransformed_dim: int | list[int | None] | None = None, + n_layers, + nn_width=None, + nn_depth=None, +): def make_layer(key, untransformed_dim: int | None, permutation=None): key, key_couple, key_permute, key_hh = jax.random.split(key, 4) @@ -837,7 +1196,97 @@ def add_default_permute(bijection, dim, key): bijection = bijections.Chain(layers) - return bijections.Chain([bijection, *flows]) + return bijection + + +def make_flow( + seed, + positions, + gradients, + *, + zero_init=False, + householder_layer=False, + dct_layer=False, + untransformed_dim: int | list[int | None] | None = None, + n_layers, + nn_width=None, + nn_depth=None, + n_embed=None, + n_deembed=None, + kind="subset", +): + positions = np.array(positions) + gradients = np.array(gradients) + + if len(positions) == 0: + return + + n_draws, n_dim = positions.shape + + if n_dim < 2: + n_layers = 0 + + assert positions.shape == gradients.shape + + if n_draws == 0: + raise ValueError("No draws") + elif n_draws == 1: + assert np.all(gradients != 0) + diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-5, 1e5) + assert np.isfinite(diag).all() + mean = jnp.zeros_like(diag) + else: + pos_std = np.clip(positions.std(0), 1e-8, 1e8) + grad_std = np.clip(gradients.std(0), 1e-8, 1e8) + diag = jnp.sqrt(pos_std / grad_std) + mean = positions.mean(0) + gradients.mean(0) * diag * diag + + key = jax.random.PRNGKey(seed % (2**63)) + + diag_param = Parameterize( + lambda x: x + jnp.sqrt(1 + x**2), + (diag**2 - 1) / (2 * diag), + ) + diag_affine = bijections.Affine(mean, diag) + diag_affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=diag_affine, + replace=diag_param, + ) + + flows = [ + diag_affine, + ] + + if n_layers == 0: + return bijections.Chain(flows) + + if kind == "subset": + inner = make_flow_loop( + key, + n_dim, + zero_init=zero_init, + householder_layer=householder_layer, + dct_layer=dct_layer, + untransformed_dim=untransformed_dim, + n_layers=n_layers, + nn_width=nn_width, + nn_depth=nn_depth, + ) + elif kind == "masked": + inner = make_flow_scan( + key, + n_dim, + zero_init=zero_init, + n_layers=n_layers, + nn_width=nn_width, + nn_depth=nn_depth, + n_embed=n_embed, + n_deembed=n_deembed, + ) + else: + raise ValueError(f"Unknown flow kind: {kind}") + return bijections.Chain([inner, *flows]) def extend_flow( diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 4cabf2a..fd1e49d 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -893,13 +893,18 @@ def make_transform_adapter( gamma=None, log_inside_batch=False, initial_skip=120, - extension_windows=[], + extension_windows=None, extend_dct=False, extension_var_count=4, extension_var_trafo_count=2, debug_save_bijection=False, make_optimizer=None, + coupling_type="masked", + n_embed=None, + n_deembed=None, ): + if extension_windows is None: + extension_windows = [] return partial( TransformAdapter, verbose=verbose, @@ -908,7 +913,11 @@ def make_transform_adapter( make_flow, householder_layer=householder_layer, dct_layer=dct_layer, + nn_depth=nn_depth, nn_width=nn_width, + n_embed=n_embed, + n_deembed=n_deembed, + kind=coupling_type, ), show_progress=show_progress, num_diag_windows=num_diag_windows, @@ -927,4 +936,5 @@ def make_transform_adapter( extension_var_trafo_count=extension_var_trafo_count, debug_save_bijection=debug_save_bijection, make_optimizer=make_optimizer, + num_layers=num_layers, ) diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 9cb7daa..22c297f 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -270,7 +270,8 @@ def test_pymc_var_names(backend, gradient_backend): @pytest.mark.pymc @pytest.mark.flow -def test_normalizing_flow(): +@pytest.mark.parametrize("kind", ["masked", "subset"]) +def test_normalizing_flow(kind): with pm.Model() as model: pm.HalfNormal("x", shape=2) @@ -279,6 +280,7 @@ def test_normalizing_flow(): ).with_transform_adapt( num_diag_windows=6, verbose=True, + coupling_type=kind, ) trace = nutpie.sample( compiled, From e97922691c0e42ba0f267244a4c8db1c72f979cf Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Mar 2025 19:06:31 +0100 Subject: [PATCH 02/10] feat: expose static trajectory length in nuts --- src/wrapper.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/wrapper.rs b/src/wrapper.rs index bb1a523..f29620c 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -629,6 +629,31 @@ impl PyNutsSettings { } Ok(()) } + + #[getter] + fn check_turning(&self) -> Result { + match &self.inner { + Settings::LowRank(inner) => Ok(inner.check_turning), + Settings::Diag(inner) => Ok(inner.check_turning), + Settings::Transforming(inner) => Ok(inner.check_turning), + } + } + + #[setter(check_turning)] + fn set_check_turning(&mut self, val: bool) -> Result<()> { + match &mut self.inner { + Settings::LowRank(inner) => { + inner.check_turning = val; + } + Settings::Diag(inner) => { + inner.check_turning = val; + } + Settings::Transforming(inner) => { + inner.check_turning = val; + } + } + Ok(()) + } } pub(crate) enum SamplerState { From 7723f4300ba7afa9cf07cc9adc308abf5912168d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Mar 2025 19:48:30 +0100 Subject: [PATCH 03/10] feat: make mvscale layer optional --- python/nutpie/normalizing_flow.py | 10 ++++++++-- python/nutpie/transform_adapter.py | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index 9d043a9..e2ff564 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -969,6 +969,7 @@ def make_flow_scan( nn_depth=None, n_embed=None, n_deembed=None, + mvscale=False, ): dim = n_dim @@ -1060,8 +1061,11 @@ def make_layer(key, mask, embed, embed_back): coupling, ) - mvscale = make_mvscale(key4, dim) - return bijections.Chain([coupling, mvscale]) + if mvscale: + scale = make_mvscale(key4, dim) + return bijections.Chain([coupling, scale]) + else: + return bijections.Chain([coupling]) keys = jax.random.split(key, n_layers) @@ -1214,6 +1218,7 @@ def make_flow( n_embed=None, n_deembed=None, kind="subset", + mvscale=False, ): positions = np.array(positions) gradients = np.array(gradients) @@ -1283,6 +1288,7 @@ def make_flow( nn_depth=nn_depth, n_embed=n_embed, n_deembed=n_deembed, + mvscale=mvscale, ) else: raise ValueError(f"Unknown flow kind: {kind}") diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index fd1e49d..3965067 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -900,6 +900,7 @@ def make_transform_adapter( debug_save_bijection=False, make_optimizer=None, coupling_type="masked", + mvscale_layer=False, n_embed=None, n_deembed=None, ): @@ -917,6 +918,7 @@ def make_transform_adapter( nn_width=nn_width, n_embed=n_embed, n_deembed=n_deembed, + mvscale=mvscale_layer, kind=coupling_type, ), show_progress=show_progress, From 32985e86a510e9fab68aea0516816d1fa8af7b64 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 12 Mar 2025 19:49:42 +0100 Subject: [PATCH 04/10] chore: update dependencies --- Cargo.lock | 75 +++++++++++++++++++++++++++--------------------------- Cargo.toml | 6 ++--- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5c86861..4876f21 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,9 +255,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.6.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d" [[package]] name = "bindgen" @@ -459,18 +459,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.31" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "027bb0d98429ae334a8698531da7077bdf906419543a35a55c2cb1b66437d767" +checksum = "6088f3ae8c3608d19260cd7445411865a485688711b78b5be70d78cd96136f83" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.31" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5589e0cba072e0f3d23791efac0fd8627b49c829c196a492e88168e6a669d863" +checksum = "22a7ef7f676155edfb82daa97f99441f3ebf4a58d5e32f295a56259f1b6facc8" dependencies = [ "anstyle", "clap_lex", @@ -653,9 +653,9 @@ dependencies = [ [[package]] name = "either" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7914353092ddf589ad78f25c5c1c21b7f80b0ff8621e7c814c3485b5306da9d" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "encode_unicode" @@ -963,9 +963,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" -version = "0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" [[package]] name = "hmac" @@ -1029,9 +1029,9 @@ dependencies = [ [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ "hermit-abi", "libc", @@ -1162,9 +1162,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.170" +version = "0.2.171" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" [[package]] name = "libloading" @@ -1503,15 +1503,15 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "password-hash" @@ -1662,11 +1662,11 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.7.35", + "zerocopy 0.8.23", ] [[package]] @@ -1781,9 +1781,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1f1914ce909e1658d9907913b4b91947430c7d9be598b15a1912935b8c04801" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] @@ -1807,7 +1807,7 @@ checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.21", + "zerocopy 0.8.23", ] [[package]] @@ -1973,18 +1973,18 @@ checksum = "1bc711410fbe7399f390ca1c3b60ad0f53f80e95c5eb935e52268a0e2cd49acc" [[package]] name = "serde" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -2051,9 +2051,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.99" +version = "2.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e02e925281e18ffd9d640e234264753c43edc62d64b2d4cf898f1bc5e75f3fc2" +checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" dependencies = [ "proc-macro2", "quote", @@ -2145,9 +2145,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.38" +version = "0.3.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb041120f25f8fbe8fd2dbe4671c7c2ed74d83be2e7a77529bf7e0790ae3f472" +checksum = "dad298b01a40a23aac4580b67e3dbedb7cc8402f3592d7f49469de2ea4aecdd8" dependencies = [ "deranged", "num-conv", @@ -2450,17 +2450,16 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "byteorder", "zerocopy-derive 0.7.35", ] [[package]] name = "zerocopy" -version = "0.8.21" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcf01143b2dd5d134f11f545cf9f1431b13b749695cb33bcce051e7568f99478" +checksum = "fd97444d05a4328b90e75e503a34bad781f14e28a823ad3557f0750df1ebcbc6" dependencies = [ - "zerocopy-derive 0.8.21", + "zerocopy-derive 0.8.23", ] [[package]] @@ -2476,9 +2475,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.21" +version = "0.8.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712c8386f4f4299382c9abee219bee7084f78fb939d88b6840fcc1320d5f6da2" +checksum = "6352c01d0edd5db859a63e2605f4ea3183ddbd15e2c4a9e7d32184df75e4f154" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index a8f2639..4754321 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,19 +29,19 @@ thiserror = "2.0.3" rand_chacha = "0.9.0" rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "54.1.0", default-features = false, features = ["ffi"] } +arrow = { version = "54.2.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.14.0" bridgestan = "2.6.1" rand_distr = "0.5.0" -smallvec = "1.13.0" +smallvec = "1.14.0" upon = { version = "0.9.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" tch = { version = "0.19.0", optional = true } [dependencies.pyo3] -version = "0.23.4" +version = "0.23.5" features = ["extension-module", "anyhow"] [dev-dependencies] From 49dfa2ae7cca8429f79c7ef7f34ab83f6dc08e21 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 13 Mar 2025 11:39:26 +0100 Subject: [PATCH 05/10] fix: fix normalizing flows for 1d posteriors --- python/nutpie/transform_adapter.py | 2 +- tests/test_pymc.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 3965067..5f5c8ed 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -666,7 +666,7 @@ def update(self, seed, positions, gradients, logps): if self._debug_save_bijection: _BIJECTION_TRACE.append( - (self.index, fit, (positions, gradients, logps)) + (self.index, base, (positions, gradients, logps)) ) return diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 22c297f..9a9a61e 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -299,6 +299,33 @@ def test_normalizing_flow(kind): assert kstest.pvalue > 0.01 +@pytest.mark.pymc +@pytest.mark.flow +@pytest.mark.parametrize("kind", ["masked", "subset"]) +def test_normalizing_flow_1d(kind): + with pm.Model() as model: + pm.HalfNormal("x") + + compiled = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax" + ).with_transform_adapt( + num_diag_windows=6, + verbose=True, + coupling_type=kind, + ) + trace = nutpie.sample( + compiled, + chains=1, + transform_adapt=True, + window_switch_freq=150, + tune=600, + seed=1, + ) + draws = trace.posterior.x.isel(chain=0) + kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) + assert kstest.pvalue > 0.01 + + @pytest.mark.pymc @pytest.mark.parametrize( ("backend", "gradient_backend"), From 34ebad40720518c48329e6777dd00468e0b9e25d Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 13 Mar 2025 19:57:40 +0100 Subject: [PATCH 06/10] feat: add layer norm in normalizing flow --- python/nutpie/normalizing_flow.py | 13 ++++++++++--- python/nutpie/transform_adapter.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index e2ff564..7a83c76 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -34,6 +34,8 @@ def _generate_sequences(k, r_vals): Returns: A NumPy boolean array of shape (comb(k, r), k) containing all sequences. """ + if k > 30: + raise ValueError("Too many sequences to enumerate.") all_sequences = [] for r in r_vals: N = math.comb(k, r) # number of sequences @@ -807,12 +809,12 @@ def __init__( ) def transform_and_log_det(self, x, condition=None): - transformer_params = self.conditioner(x) + transformer_params = self.conditioner(x.astype(jnp.float32)) transformer = self._flat_params_to_transformer(transformer_params) return transformer.transform_and_log_det(x) def inverse_and_log_det(self, y, condition=None): - transformer_params = self.conditioner(y) + transformer_params = self.conditioner(y.astype(jnp.float32)) transformer = self._flat_params_to_transformer(transformer_params) return transformer.inverse_and_log_det(y) @@ -987,7 +989,12 @@ def make_flow_scan( size = MaskedCoupling.conditioner_output_size(dim, transformer) key, key1 = jax.random.split(key) - embed = eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32) + embed = eqx.nn.Sequential( + [ + eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32), + eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), + ] + ) key, key1 = jax.random.split(key) embed_back = eqx.nn.Linear(n_deembed, size, key=key1, dtype=jnp.float32) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 5f5c8ed..78b4152 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -878,7 +878,7 @@ def make_transform_adapter( verbose=False, window_size=600, show_progress=False, - nn_depth=1, + nn_depth=None, nn_width=None, num_layers=9, num_diag_windows=9, From b407982d33e329db2a0d05a79637e98a0fe2dd63 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 14 Mar 2025 16:21:02 +0100 Subject: [PATCH 07/10] feat: small improvements to the masked normalizing flow --- python/nutpie/normalizing_flow.py | 109 ++++++++++++++++++++--------- python/nutpie/transform_adapter.py | 8 ++- tests/test_pymc.py | 12 ++-- 3 files changed, 87 insertions(+), 42 deletions(-) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index 7a83c76..9c1c4f4 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -579,6 +579,28 @@ def step(carry, bijection): ) return y, log_det + def inverse_gradient_and_val( + self, + y: Array, + y_grad: Array, + y_logp: Array, + condition: Array | None = None, + ) -> tuple[Array, Array, Array]: + def step(carry, bijection): + from nutpie.transform_adapter import inverse_gradient_and_val + + carry = inverse_gradient_and_val(bijection, *carry) + return (carry, None) + + (y, y_grad, y_logp), _ = _filter_scan( + step, + (y, y_grad, y_logp), + self.bijection, + reverse=True, + filter_spec=self.filter_spec, + ) + return y, y_grad, y_logp + @property def shape(self): return self.bijection.shape @@ -961,6 +983,16 @@ def make_mlp(out_size): ) +class Add(eqx.Module): + bias: Array + + def __init__(self, bias): + self.bias = bias + + def __call__(self, x: Array, *, key=None) -> Array: + return x + self.bias + + def make_flow_scan( key, n_dim, @@ -984,15 +1016,47 @@ def make_flow_scan( if nn_depth is None: nn_depth = 1 + def make_transformer(): + elemwises = [] + # loc = bijections.Loc(jnp.zeros(())) + # elemwises.append(loc) + + for loc in [0.0]: + scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + + affine = AsymmetricAffine( + jnp.zeros(()) + loc, + jnp.ones(()), + jnp.ones(()), + ) + + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + affine = eqx.tree_at( + where=lambda aff: aff.theta, + pytree=affine, + replace=theta, + ) + elemwises.append(bijections.Invert(affine)) + + if len(elemwises) == 1: + return elemwises[0] + return bijections.Chain(elemwises) + # Just to get at the size - transformer = AsymmetricAffine() + transformer = make_transformer() size = MaskedCoupling.conditioner_output_size(dim, transformer) key, key1 = jax.random.split(key) embed = eqx.nn.Sequential( [ eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32), - eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), + # Activation(_NN_ACTIVATION), + # eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), ] ) key, key1 = jax.random.split(key) @@ -1005,37 +1069,15 @@ def make_flow_scan( for i in range(len(mask)): mask[i, order[i, : counts[i]]] = True - def make_transformer(): - scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) - theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) - - affine = AsymmetricAffine( - jnp.zeros(()), - jnp.ones(()), - jnp.ones(()), - ) - - affine = eqx.tree_at( - where=lambda aff: aff.scale, - pytree=affine, - replace=scale, - ) - affine = eqx.tree_at( - where=lambda aff: aff.theta, - pytree=affine, - replace=theta, - ) - - return bijections.Invert(affine) - def make_mvscale(key, n_dim): params = jax.random.normal(key, (n_dim,)) params = params / jnp.linalg.norm(params) return MvScale(params) def make_layer(key, mask, embed, embed_back): - key1, key2, key3, key4 = jax.random.split(key, 4) + key1, key2, key3, key4, key5 = jax.random.split(key, 5) transformer = make_transformer() + bias = Add(jax.random.normal(key5, (size,)) * 0.01) conditioner = eqx.nn.Sequential( [ @@ -1049,7 +1091,12 @@ def make_layer(key, mask, embed, embed_back): dtype=jnp.float32, activation=_NN_ACTIVATION, ), - embed_back, + eqx.nn.Sequential( + [ + embed_back, + bias, + ] + ), ] ) @@ -1083,7 +1130,7 @@ def make_layer(key, mask, embed, embed_back): replace=None, ) out_axes = eqx.tree_at( - lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1], + lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0], pytree=out_axes, replace=None, ) @@ -1100,7 +1147,7 @@ def make_layer(key, mask, embed, embed_back): replace=False, ) vectorize = eqx.tree_at( - lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1], + lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0], pytree=vectorize, replace=False, ) @@ -1234,10 +1281,6 @@ def make_flow( return n_draws, n_dim = positions.shape - - if n_dim < 2: - n_layers = 0 - assert positions.shape == gradients.shape if n_draws == 0: diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 78b4152..0ce30c9 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -29,7 +29,7 @@ import optax from paramax import unwrap, NonTrainable -from nutpie.normalizing_flow import Coupling, extend_flow, make_flow +from nutpie.normalizing_flow import Coupling, Scan, extend_flow, make_flow import tqdm _BIJECTION_TRACE = [] @@ -241,6 +241,8 @@ def inner(bijection, y, y_grad, y_logp): axis_size=bijection.axis_size, )(bijection.bijection, draw, grad, jnp.zeros(())) return y, y_grad, jnp.sum(log_det) + logp + elif isinstance(bijection, Scan): + return bijection.inverse_gradient_and_val(draw, grad, logp) elif isinstance(bijection, bijections.Sandwich): draw, grad, logp = inverse_gradient_and_val( bijections.Invert(bijection.outer), draw, grad, logp @@ -880,8 +882,8 @@ def make_transform_adapter( show_progress=False, nn_depth=None, nn_width=None, - num_layers=9, - num_diag_windows=9, + num_layers=20, + num_diag_windows=6, learning_rate=5e-4, untransformed_dim=None, zero_init=True, diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 9a9a61e..2135194 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -278,17 +278,17 @@ def test_normalizing_flow(kind): compiled = nutpie.compile_pymc_model( model, backend="jax", gradient_backend="jax" ).with_transform_adapt( - num_diag_windows=6, verbose=True, coupling_type=kind, + num_layers=2, ) trace = nutpie.sample( compiled, chains=1, transform_adapt=True, - window_switch_freq=150, - tune=600, + window_switch_freq=128, seed=1, + draws=2000, ) draws = trace.posterior.x.isel(x_dim_0=0, chain=0) kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) @@ -309,17 +309,17 @@ def test_normalizing_flow_1d(kind): compiled = nutpie.compile_pymc_model( model, backend="jax", gradient_backend="jax" ).with_transform_adapt( - num_diag_windows=6, verbose=True, coupling_type=kind, + num_layers=2, ) trace = nutpie.sample( compiled, chains=1, transform_adapt=True, - window_switch_freq=150, - tune=600, + window_switch_freq=128, seed=1, + draws=2000, ) draws = trace.posterior.x.isel(chain=0) kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) From de2cd950975f74b5b67ba729c85f35e071681a4f Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 14 Mar 2025 19:42:17 +0100 Subject: [PATCH 08/10] ci: split some test into sections with optional deps --- .github/workflows/ci.yml | 19 ++++++++++++++----- tests/test_pymc.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cca5f5f..c252dad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -182,9 +182,14 @@ jobs: set -e python3 -m venv .venv source .venv/Scripts/activate - uv pip install "nutpie[all]" --find-links dist --force-reinstall + uv pip install "nutpie[stan]" --find-links dist --force-reinstall uv pip install pytest pytest-timeout - pytest + pytest -m "stan and not flow" + uv pip install "nutpie[pymc]" --find-links dist --force-reinstall + uv pip install jax + pytest -m "pymc and not flow" + uv pip install "nutpie[all]" --find-links dist --force-reinstall + pytest -m flow macos: runs-on: ${{ matrix.platform.runner }} @@ -226,10 +231,14 @@ jobs: set -e python3 -m venv .venv source .venv/bin/activate - uv pip install 'nutpie[all]' --find-links dist --force-reinstall + uv pip install 'nutpie[stan]' --find-links dist --force-reinstall uv pip install pytest pytest-timeout - pytest -m "not (flow and stan)" # The stan tests seem to run out of memory on macOS? - + pytest -m "stan and not flow" + uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall + uv pip install jax + pytest -m "pymc and not flow" + uv pip install 'nutpie[all]' --find-links dist --force-reinstall + pytest -m flow sdist: runs-on: ubuntu-latest steps: diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 2135194..6a0f242 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -280,7 +280,7 @@ def test_normalizing_flow(kind): ).with_transform_adapt( verbose=True, coupling_type=kind, - num_layers=2, + num_layers=4, ) trace = nutpie.sample( compiled, @@ -311,7 +311,7 @@ def test_normalizing_flow_1d(kind): ).with_transform_adapt( verbose=True, coupling_type=kind, - num_layers=2, + num_layers=4, ) trace = nutpie.sample( compiled, From d85992f3e77d6840b475d842b31c88be2445c652 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 18 Mar 2025 15:34:09 +0100 Subject: [PATCH 09/10] fix: better initialization of masked flows --- python/nutpie/normalizing_flow.py | 40 +++++++++++++++++------------- python/nutpie/transform_adapter.py | 30 ++++------------------ tests/test_pymc.py | 4 +-- 3 files changed, 29 insertions(+), 45 deletions(-) diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index 9c1c4f4..ea6981b 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -1054,13 +1054,19 @@ def make_transformer(): key, key1 = jax.random.split(key) embed = eqx.nn.Sequential( [ - eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32), + eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32, use_bias=True), # Activation(_NN_ACTIVATION), # eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), ] ) key, key1 = jax.random.split(key) - embed_back = eqx.nn.Linear(n_deembed, size, key=key1, dtype=jnp.float32) + embed_back = eqx.nn.Linear( + n_deembed, size, key=key1, dtype=jnp.float32, use_bias=True + ) + embed_back = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + embed_back, + ) rng = np.random.default_rng(42) # TODO order, counts = _generate_permutations(rng, dim, n_layers) @@ -1077,20 +1083,25 @@ def make_mvscale(key, n_dim): def make_layer(key, mask, embed, embed_back): key1, key2, key3, key4, key5 = jax.random.split(key, 5) transformer = make_transformer() - bias = Add(jax.random.normal(key5, (size,)) * 0.01) + bias = Add(jax.random.normal(key5, (size,)) * 0.001) + inner = eqx.nn.MLP( + n_embed, + n_deembed, + width_size=nn_width, + depth=nn_depth, + key=key2, + dtype=jnp.float32, + activation=_NN_ACTIVATION, + ) + inner = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + inner, + ) conditioner = eqx.nn.Sequential( [ embed, - eqx.nn.MLP( - n_embed, - n_deembed, - width_size=nn_width, - depth=nn_depth, - key=key2, - dtype=jnp.float32, - activation=_NN_ACTIVATION, - ), + inner, eqx.nn.Sequential( [ embed_back, @@ -1110,11 +1121,6 @@ def make_layer(key, mask, embed, embed_back): nn_depth=nn_depth, ) - coupling = jax.tree_util.tree_map( - lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, - coupling, - ) - if mvscale: scale = make_mvscale(key4, dim) return bijections.Chain([coupling, scale]) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 0ce30c9..64b8880 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -112,24 +112,14 @@ def fit_to_data( for i in loop: # Shuffle data - start = time.time() key, *subkeys = jr.split(key, 3) train_data = [jr.permutation(subkeys[0], a) for a in train_data] val_data = [jr.permutation(subkeys[1], a) for a in val_data] - if verbose and i == 0: - print("shuffle timing:", time.time() - start) - - start = time.time() key, subkey = jr.split(key) batches = get_batches(train_data, batch_size) batch_losses = [] - if verbose and i == 0: - print("batch timing:", time.time() - start) - - start = time.time() - if True: for batch in zip(*batches, strict=True): key, subkey = jr.split(key) @@ -156,10 +146,6 @@ def fit_to_data( losses["train"].append((sum(batch_losses) / len(batch_losses)).item()) - if verbose and i == 0: - print("step timing:", time.time() - start) - - start = time.time() # Val epoch batch_losses = [] for batch in zip(*get_batches(val_data, batch_size), strict=True): @@ -168,9 +154,6 @@ def fit_to_data( batch_losses.append(loss_i) losses["val"].append(sum(batch_losses) / len(batch_losses)) - if verbose and i == 0: - print("val timing:", time.time() - start) - loop.set_postfix({k: v[-1] for k, v in losses.items()}) if losses["val"][-1] == min(losses["val"]): best_params = params @@ -228,7 +211,7 @@ def inverse_gradient_and_val(bijection, draw, grad, logp): ) elif isinstance(bijection, bijections.Affine): draw, logdet = bijection.inverse_and_log_det(draw) - grad = grad * bijection.scale + grad = grad * unwrap(bijection.scale) return (draw, grad, logp - logdet) elif isinstance(bijection, bijections.Vmap): @@ -710,12 +693,9 @@ def update(self, seed, positions, gradients, logps): ) params, static = eqx.partition(flow, eqx.is_inexact_array) - start = time.time() new_loss = self._loss_fn( params, static, positions[-128:], gradients[-128:], logps[-128:] ) - if self._verbose: - print("new loss function time: ", time.time() - start) if self._verbose: print(f"Chain {self._chain}: New loss {new_loss}, old loss {old_loss}") @@ -903,8 +883,8 @@ def make_transform_adapter( make_optimizer=None, coupling_type="masked", mvscale_layer=False, - n_embed=None, - n_deembed=None, + num_project=None, + num_embed=None, ): if extension_windows is None: extension_windows = [] @@ -918,8 +898,8 @@ def make_transform_adapter( dct_layer=dct_layer, nn_depth=nn_depth, nn_width=nn_width, - n_embed=n_embed, - n_deembed=n_deembed, + n_embed=num_project, + n_deembed=num_embed, mvscale=mvscale_layer, kind=coupling_type, ), diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 6a0f242..dd8274c 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -321,9 +321,7 @@ def test_normalizing_flow_1d(kind): seed=1, draws=2000, ) - draws = trace.posterior.x.isel(chain=0) - kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) - assert kstest.pvalue > 0.01 + assert float(trace.sample_stats.fisher_distance.mean()) < 0.1 @pytest.mark.pymc From 8506ccf34599677e4992479b2c0151093a04f160 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 18 Mar 2025 15:46:20 +0100 Subject: [PATCH 10/10] chore(release): prepare release --- CHANGELOG.md | 32 ++++++++++++++++++ Cargo.lock | 95 ++++++++++++++++++++++++++++------------------------ Cargo.toml | 8 ++--- 3 files changed, 87 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 431f749..8eab10f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,38 @@ All notable changes to this project will be documented in this file. +## [0.14.3] - 2025-03-18 + +### Bug Fixes + +- Fix normalizing flows for 1d posteriors (Adrian Seyboldt) + +- Better initialization of masked flows (Adrian Seyboldt) + + +### Features + +- Add masked coupling flow (Adrian Seyboldt) + +- Expose static trajectory length in nuts (Adrian Seyboldt) + +- Make mvscale layer optional (Adrian Seyboldt) + +- Add layer norm in normalizing flow (Adrian Seyboldt) + +- Small improvements to the masked normalizing flow (Adrian Seyboldt) + + +### Miscellaneous Tasks + +- Update dependencies (Adrian Seyboldt) + + +### Ci + +- Split some test into sections with optional deps (Adrian Seyboldt) + + ## [0.14.2] - 2025-03-06 ### Bug Fixes diff --git a/Cargo.lock b/Cargo.lock index 4876f21..22abddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,9 +255,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb97d56060ee67d285efb8001fec9d2a4c710c32efd2e14b5cbb5ba71930fc2d" +checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" [[package]] name = "bindgen" @@ -624,9 +624,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", ] @@ -717,9 +717,9 @@ dependencies = [ [[package]] name = "faer" -version = "0.21.7" +version = "0.21.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d671941ab57443f46ebe3f153a9fc3ed6cce777926c14e5fdf5da178a35ea476" +checksum = "ebe9ac2a073e05ca749eeea503fae16a91440b20d2e92b6fc6f6c6919b9964eb" dependencies = [ "bytemuck", "dyn-stack", @@ -750,9 +750,9 @@ dependencies = [ [[package]] name = "faer-traits" -version = "0.21.0" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2d0172aefb5f869561e558d5390657f1aa98ca3c51a09be69a4687064ebfb9a" +checksum = "1430e111b20872c7eaa82c7ada071bff1c3e3ac09cc6f4df676065fd2d41eb62" dependencies = [ "bytemuck", "dyn-stack", @@ -921,14 +921,14 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" dependencies = [ "cfg-if", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] @@ -939,9 +939,9 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "half" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "7db2ff139bba50379da6aa0766b52fdcb62cb5b263009b09ed58ba604e14bbd1" dependencies = [ "bytemuck", "cfg-if", @@ -1445,9 +1445,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "numpy" -version = "0.23.0" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b94caae805f998a07d33af06e6a3891e38556051b8045c615470a71590e13e78" +checksum = "a7cfbf3f0feededcaa4d289fe3079b03659e85c5b5a177f4ba6fb01ab4fb3e39" dependencies = [ "libc", "ndarray", @@ -1455,12 +1455,13 @@ dependencies = [ "num-integer", "num-traits", "pyo3", + "pyo3-build-config", "rustc-hash", ] [[package]] name = "nutpie" -version = "0.14.2" +version = "0.14.3" dependencies = [ "anyhow", "arrow", @@ -1484,9 +1485,9 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.15.0" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10e87924d332fce1202087bc67db7ed8f7ef9229da5ec74a5130568f5b7f6ac7" +checksum = "11d3052cf8ae044673a4bb41819943e62e43af7c1443f45c6e2f8c895e9fa994" dependencies = [ "anyhow", "arrow", @@ -1503,9 +1504,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.0" +version = "1.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde51589ab56b20a6f686b2c68f7a0bd6add753d697abf720d63f8db3ab7b1ad" +checksum = "d75b0bedcc4fe52caa0e03d9f1151a323e4aa5e2d78ba3580400cd3c9e2bc4bc" [[package]] name = "oorandom" @@ -1671,9 +1672,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.30" +version = "0.2.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1ccf34da56fc294e7d4ccf69a85992b7dfb826b7cf57bac6a70bba3494cc08a" +checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" dependencies = [ "proc-macro2", "syn", @@ -1717,9 +1718,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.23.5" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872" +checksum = "7f1c6c3591120564d64db2261bec5f910ae454f01def849b9c22835a84695e86" dependencies = [ "anyhow", "cfg-if", @@ -1736,9 +1737,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.23.5" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb" +checksum = "e9b6c2b34cf71427ea37c7001aefbaeb85886a074795e35f161f5aecc7620a7a" dependencies = [ "once_cell", "target-lexicon", @@ -1746,9 +1747,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.23.5" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d" +checksum = "5507651906a46432cdda02cd02dd0319f6064f1374c9147c45b978621d2c3a9c" dependencies = [ "libc", "pyo3-build-config", @@ -1756,9 +1757,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.23.5" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da" +checksum = "b0d394b5b4fd8d97d48336bb0dd2aebabad39f1d294edd6bcd2cccf2eefe6f42" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -1768,9 +1769,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.23.5" +version = "0.24.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028" +checksum = "fd72da09cfa943b1080f621f024d2ef7e2773df7badd51aa30a2be1f8caa7c8e" dependencies = [ "heck", "proc-macro2", @@ -1788,6 +1789,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "rand" version = "0.8.5" @@ -1845,7 +1852,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.2", ] [[package]] @@ -2082,9 +2089,9 @@ checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.16" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tch" @@ -2145,9 +2152,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.39" +version = "0.3.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dad298b01a40a23aac4580b67e3dbedb7cc8402f3592d7f49469de2ea4aecdd8" +checksum = "9d9c75b47bdff86fa3334a3db91356b8d7d86a9b839dab7d0bdc5c3d3a077618" dependencies = [ "deranged", "num-conv", @@ -2158,9 +2165,9 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "765c97a5b985b7c11d7bc27fa927dc4fe6af3a6dfb021d28deb60d3bf51e76ef" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-humanize" @@ -2259,9 +2266,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -2437,9 +2444,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ "bitflags", ] diff --git a/Cargo.toml b/Cargo.toml index 4754321..243f88f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "nutpie" -version = "0.14.2" +version = "0.14.3" authors = [ "Adrian Seyboldt ", "PyMC Developers ", @@ -22,8 +22,8 @@ name = "_lib" crate-type = ["cdylib"] [dependencies] -nuts-rs = "0.15.0" -numpy = "0.23.0" +nuts-rs = "0.15.1" +numpy = "0.24.0" rand = "0.9.0" thiserror = "2.0.3" rand_chacha = "0.9.0" @@ -41,7 +41,7 @@ indicatif = "0.17.8" tch = { version = "0.19.0", optional = true } [dependencies.pyo3] -version = "0.23.5" +version = "0.24.0" features = ["extension-module", "anyhow"] [dev-dependencies]