diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..8f61a8e --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# SCM syntax highlighting +pixi.lock linguist-language=YAML linguist-generated=true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4420c42..2fe0449 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,13 +56,13 @@ jobs: python3 -m venv .venv source .venv/bin/activate uv pip install 'nutpie[stan]' --find-links dist --force-reinstall - uv pip install pytest pytest-timeout - pytest -m "stan and not flow" + uv pip install pytest pytest-timeout pytest-arraydiff + pytest -m "stan and not flow" --arraydiff uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall uv pip install jax - pytest -m "pymc and not flow" + pytest -m "pymc and not flow" --arraydiff uv pip install 'nutpie[all]' --find-links dist --force-reinstall - pytest -m flow + pytest -m flow --arraydiff # pyarrow doesn't currently seem to work on musllinux #musllinux: @@ -183,13 +183,13 @@ jobs: python3 -m venv .venv source .venv/Scripts/activate uv pip install "nutpie[stan]" --find-links dist --force-reinstall - uv pip install pytest pytest-timeout - pytest -m "stan and not flow" + uv pip install pytest pytest-timeout pytest-arraydiff + pytest -m "stan and not flow" --arraydiff uv pip install "nutpie[pymc]" --find-links dist --force-reinstall uv pip install jax - pytest -m "pymc and not flow" + pytest -m "pymc and not flow" --arraydiff uv pip install "nutpie[all]" --find-links dist --force-reinstall - pytest -m flow + pytest -m flow --arraydiff macos: runs-on: ${{ matrix.platform.runner }} @@ -232,13 +232,13 @@ jobs: python3 -m venv .venv source .venv/bin/activate uv pip install 'nutpie[stan]' --find-links dist --force-reinstall - uv pip install pytest pytest-timeout - pytest -m "stan and not flow" + uv pip install pytest pytest-timeout pytest-arraydiff + pytest -m "stan and not flow" --arraydiff uv pip install 'nutpie[pymc]' --find-links dist --force-reinstall uv pip install jax - pytest -m "pymc and not flow" + pytest -m "pymc and not flow" --arraydiff uv pip install 'nutpie[all]' --find-links dist --force-reinstall - pytest -m flow + pytest -m flow --arraydiff sdist: runs-on: ubuntu-latest steps: diff --git a/.gitignore b/.gitignore index bffa4f4..82b98be 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,11 @@ example-iree posteriordb .quarto docs/.quarto +Untitled* +notebooks-local +pixi.lock +pixi.toml +reports +benchmarks* +reports* +results* diff --git a/Cargo.lock b/Cargo.lock index 44faebd..e3a940c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2095,9 +2095,9 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tch" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa1ed622c8f13b0c42f8b1afa0e5e9ccccd82ecb6c0e904120722ab52fdc5234" +checksum = "a760143efe7e4bb5b56e95d01f52ee6773bc315202e7c47db6a6429b0705a1f2" dependencies = [ "half", "lazy_static", @@ -2196,9 +2196,9 @@ dependencies = [ [[package]] name = "torch-sys" -version = "0.19.0" +version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef14f5d239e3d60f4919f536a5dfe1d4f71b27b7abf6fe6875fd3a4b22c2dcd5" +checksum = "ad6fa4ac5662b84047081375b007f102d4968d5a0191f567a9776294445af9ac" dependencies = [ "anyhow", "cc", diff --git a/Cargo.toml b/Cargo.toml index f07fbec..dfc640f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 
+38,7 @@ smallvec = "1.14.0" upon = { version = "0.9.0", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" -tch = { version = "0.19.0", optional = true } +tch = { version = "0.20.0", optional = true } [dependencies.pyo3] version = "0.24.1" @@ -50,7 +50,7 @@ criterion = "0.5.1" [profile.release] lto = "fat" codegen-units = 1 -opt-level = 2 +opt-level = 3 [profile.bench] debug = true diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000..075b254 --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +/.quarto/ diff --git a/pyproject.toml b/pyproject.toml index bd268d2..c53f390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "nutpie" description = "Sample Stan or PyMC models" authors = [{ name = "PyMC Developers", email = "pymc.devs@gmail.com" }] readme = "README.md" -requires-python = ">=3.10,<3.14" +requires-python = ">=3.10" license = { text = "MIT" } classifiers = [ "Programming Language :: Rust", @@ -41,6 +41,7 @@ dev = [ "flowjax >= 17.0.2", "pytest", "pytest-timeout", + "pytest-arraydiff", ] all = [ "bridgestan >= 2.6.1", diff --git a/python/nutpie/normalizing_flow.py b/python/nutpie/normalizing_flow.py index ea6981b..792ffb6 100644 --- a/python/nutpie/normalizing_flow.py +++ b/python/nutpie/normalizing_flow.py @@ -18,9 +18,6 @@ from paramax.wrappers import AbstractUnwrappable -_NN_ACTIVATION = jax.nn.gelu - - def _generate_sequences(k, r_vals): """ Generate all binary sequences of length k with exactly r 1's. @@ -121,7 +118,6 @@ def _generate_permutations(rng, n_dim, n_layers, max_run=3): ) rng.shuffle(valid_sequences, axis=0) is_in_first = valid_sequences[:n_dim] - rng = np.random.default_rng(42) permutations = (~is_in_first).argsort(axis=0, kind="stable") return permutations.T, is_in_first.sum(0) @@ -425,6 +421,58 @@ def inverse(y, mu, sigma, theta): # return x, jnp.log(jac) +class Householder(AbstractBijection): + """A Householder reflection. + + A linear transformation reflecting vectors across a hyperplane defined by a normal + vector (params). The transformation is its own inverse and volume-preserving + (determinant = -1). Given a unit vector :math:`v`, the transformation is + :math:`y = x - 2(x^T v)v`. + + It is often desirable to stack multiple such transforms (e.g. up to the + dimensionality of the data): + + .. doctest:: + + >>> from flowjax.bijections import Householder, Scan + >>> import jax.random as jr + >>> import equinox as eqx + >>> import jax.numpy as jnp + + >>> dim = 5 + >>> keys = jr.split(jr.key(0), dim) + >>> householder_stack = Scan( + ... eqx.filter_vmap(lambda key: Householder(jr.normal(key, dim)))(keys) + ... ) + + Args: + params: Normal vector defining the reflection hyperplane. The vector is + normalized in the transformation, so scaling params will have no effect + on the bijection. + """ + + shape: tuple[int, ...] 
+ params: Array + cond_shape = None + + def __init__(self, params: ArrayLike): + params = arraylike_to_array(params) + if params.ndim != 1: + raise ValueError("params must be a vector.") + self.shape = params.shape + self.params = params + + def _householder(self, x: Array) -> Array: + unit_vec = self.params / jnp.linalg.norm(self.params) + return x - 2 * unit_vec * (x @ unit_vec) + + def transform_and_log_det(self, x: jnp.ndarray, condition: Array | None = None): + return self._householder(x), jnp.zeros(()) + + def inverse_and_log_det(self, y: Array, condition: Array | None = None): + return self._householder(y), jnp.zeros(()) + + class MvScale(bijections.AbstractBijection): shape: tuple[int, ...] params: Array @@ -831,12 +879,12 @@ def __init__( ) def transform_and_log_det(self, x, condition=None): - transformer_params = self.conditioner(x.astype(jnp.float32)) + transformer_params = self.conditioner(x.astype(jnp.float32)).astype(jnp.float64) transformer = self._flat_params_to_transformer(transformer_params) return transformer.transform_and_log_det(x) def inverse_and_log_det(self, y, condition=None): - transformer_params = self.conditioner(y.astype(jnp.float32)) + transformer_params = self.conditioner(y.astype(jnp.float32)).astype(jnp.float64) transformer = self._flat_params_to_transformer(transformer_params) return transformer.inverse_and_log_det(y) @@ -875,8 +923,9 @@ def make_single_hh(key, idx): def make_hh(key, n_dim, size, randomize_base=False): def make_single_hh(key, idx): key1, key2 = jax.random.split(key) - params = jax.random.normal(key1, (n_dim,)) * 1e-2 - return bijections.Householder(params, base_index=idx) + params = jax.random.normal(key1, (n_dim,)) * 1e-3 + params = params.at[idx].set(1.0) + return Householder(params) keys = jax.random.split(key, size) @@ -886,9 +935,18 @@ def make_single_hh(key, idx): else: indices = [val % n_dim for val in range(size)] - return bijections.Chain( - [make_single_hh(key, idx) for key, idx in zip(keys, indices)] - ) + if size == 1: + return make_single_hh(keys[0], indices[0]) + + make_single_hh_vec = eqx.filter_vmap(make_single_hh, axis_size=size)(keys, indices) + return bijections.Scan(make_single_hh_vec) def make_elemwise_trafo(key, n_dim, *, count=1, vmap=True): @@ -933,7 +991,9 @@ def make(key): return make(key) -def make_coupling(key, dim, n_untransformed, *, inner_mvscale=False, **kwargs): +def make_coupling( + key, dim, n_untransformed, *, activation, inner_mvscale=False, **kwargs +): n_transformed = dim - n_untransformed nn_width = kwargs.get("nn_width", None) @@ -970,7 +1030,7 @@ def make_mlp(out_size): depth=nn_depth, key=key, dtype=jnp.float32, - activation=_NN_ACTIVATION, + activation=activation, ) return Coupling( @@ -993,34 +1053,384 @@ def __call__(self, x: Array, *, key=None) -> Array: return x + self.bias
| None = None - if nn_width is None: - nn_width = 32 - if n_embed is None: - n_embed = 2 * nn_width - if n_deembed is None: - n_deembed = 2 * nn_width - if nn_depth is None: - nn_depth = 1 + def __init__(self, loc, unconstrained_scale): + self.loc = loc + self.unconstrained_scale = unconstrained_scale + self.shape = loc.shape + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + scale = self.unconstrained_scale + jnp.sqrt(1 + self.unconstrained_scale**2) + y = self.loc + scale * x + return y, jnp.sum(jnp.log(scale)) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + scale = self.unconstrained_scale + jnp.sqrt(1 + self.unconstrained_scale**2) + x = (y - self.loc) / scale + return x, -jnp.sum(jnp.log(scale)) + + +def pairwise_rotation(x, thetas): + """ + Applies a rotation to each consecutive pair in x. + + Parameters: + x (jnp.ndarray): 1D array containing the values to be rotated. + thetas (jnp.ndarray): 1D array of angles (in radians) for each pair. + Length must equal x.shape[0] // 2. + + Returns: + jnp.ndarray: The rotated vector where each pair (x[2*i], x[2*i+1]) + is rotated by the corresponding angle thetas[i]. If x has + an odd length, the last element is unchanged. + """ + n = x.shape[0] + num_pairs = n // 2 + + # Reshape the first 2*num_pairs elements into pairs + x_pairs = x[: num_pairs * 2].reshape(num_pairs, 2) + + # Compute cosine and sine of each rotation angle for the pairs + cos_thetas = jnp.cos(thetas) + sin_thetas = jnp.sin(thetas) + + # Apply the rotation to each pair without forming a 2x2 matrix: + # rotated_x = x * cos(theta) - y * sin(theta) + # rotated_y = x * sin(theta) + y * cos(theta) + rotated_x = x_pairs[:, 0] * cos_thetas - x_pairs[:, 1] * sin_thetas + rotated_y = x_pairs[:, 0] * sin_thetas + x_pairs[:, 1] * cos_thetas + + # Stack the rotated coordinates and flatten back into a 1D array + rotated_pairs = jnp.stack([rotated_x, rotated_y], axis=1) + y_rotated = rotated_pairs.reshape(-1) + + # If x has an odd length, append the last unchanged element. + if n % 2 == 1: + y = jnp.concatenate([y_rotated, x[num_pairs * 2 :]]) + else: + y = y_rotated + + return y + + +class Rotations(bijections.AbstractBijection): + theta: Array + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None = None + + def __init__(self, key, ndim): + n_rotations = ndim // 2 + self.theta = jax.random.normal(key, (n_rotations,)) / 10 + self.shape = (ndim,) + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + return pairwise_rotation(x, self.theta), jnp.zeros(()) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + return pairwise_rotation(y, -self.theta), jnp.zeros(()) + + +class Orthogonal(bijections.AbstractBijection): + theta: Array + shape: tuple[int, ...] + cond_shape: tuple[int, ...] 
| None = None + + def __init__(self, key, ndim, k): + self.theta = jax.random.normal(key, (ndim, k)) + self.shape = (ndim,) + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + q, _ = jnp.linalg.qr(self.theta, mode="reduced") + q = q.T + qx = q @ x + return x - 2 * q.T @ qx, jnp.zeros(()) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + q, _ = jnp.linalg.qr(self.theta, mode="reduced") + q = q.T + qy = q @ y + return y - 2 * q.T @ qy, jnp.zeros(()) + + +class Planar(bijections.AbstractBijection): + u: Array + + # One dimensional transformation (assumed to operate on shape (..., 1)) + inner: bijections.AbstractBijection + + # The full input shape (e.g. (ndim,)) + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None = None + + def __init__(self, key, inner, ndim: int): + self.inner = inner + self.shape = (ndim,) + # Initialize u as a random vector of shape (ndim,) + self.u = jax.random.normal(key, (ndim,)) + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + # Normalize u to have unit norm. + u = self.u / jnp.linalg.norm(self.u) + # Compute the scalar projection d = . + d = x @ u + # Apply the one-dimensional bijection to d. + f_d, logdet_inner = self.inner.transform_and_log_det(d, condition) + # Lift the 1D transformation to the full space: + # y = x + (f(d) - d) * u + y = x + (f_d - d) * u + return y, logdet_inner + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + # Normalize u. + u = self.u / jnp.linalg.norm(self.u) + # Compute the projected coordinate from y. + d_y = y @ u + # Invert the inner bijection to get d. + d, logdet_inner = self.inner.inverse_and_log_det(d_y, condition) + # Invert the full transformation: + # x = y + (d - f(d)) * u, but note that f(d)=d_y. + x = y + (d - d_y) * u + # The log–determinant of the inverse is the negative of the forward. + return x, -logdet_inner + + +class Contract2(bijections.AbstractBijection): + alpha: Array | None + beta: Array + sigma: Array + mu: Array + nu: Array + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None = None + + def __init__(self, alpha, beta, sigma, mu, nu): + if alpha is not None: + self.alpha = jnp.array(alpha) + else: + self.alpha = None + self.beta = jnp.array(beta) + self.sigma = jnp.array(sigma) + self.mu = jnp.array(mu) + self.nu = jnp.array(nu) + self.shape = beta.shape + assert self.shape == () + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + """ + Forward transformation: + + T(x) = sigma_mod * (delta^2 * z^gamma - delta^(-2) * z^(-gamma)) / gamma + mu, + + where + gamma = exp(alpha), + delta = exp(beta), + sigma_mod = sigma + sqrt(1 + sigma^2), + z = x/2 + sqrt(1 + x^2/4) + (note: z = exp(asinh(x/2))). 
+ + """ + if self.alpha is not None: + gamma = jnp.exp(self.alpha) + else: + gamma = 1 + delta = jnp.exp(self.beta) + sigma_mod = self.sigma + jnp.sqrt(1 + self.sigma * self.sigma) + mu = self.mu + nu = self.nu + + def trafo(x): + x = x - nu + z = x / 2 + jnp.sqrt(1 + x * x / 4) + return ( + sigma_mod + * (delta**2 * z**gamma - delta ** (-2) * z ** (-gamma)) + / gamma + + mu + ) + + y, det = jax.jvp(trafo, [x], [jnp.ones(())]) + return y, jnp.log(det) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + """ + Inverse transformation: + + Given y, we compute x such that + y = T(x) = sigma_mod * (delta^2 * z^gamma - delta^(-2) * z^(-gamma)) / gamma + mu, + with z = x/2 + sqrt(1 + x^2/4) = exp(asinh(x/2)). + + The inverse is computed via: + + 1. Set sigma_mod = sigma + sqrt(1 + sigma^2), gamma = exp(alpha), delta = exp(beta). + 2. Define A = (gamma/sigma_mod) * (y - mu). + 3. Solve for w from: delta^2 * w - delta^(-2) / w = A, + i.e., w = (A + sqrt(A^2 + 4)) / (2 * delta^2), where w = z^gamma. + 4. Recover z = w^(1/gamma). + 5. Then, x = z - 1/z. + """ + if self.alpha is not None: + gamma = jnp.exp(self.alpha) + else: + gamma = 1 + delta = jnp.exp(self.beta) + sigma_mod = self.sigma + jnp.sqrt(1 + self.sigma * self.sigma) + mu = self.mu + nu = self.nu + + def inv_trafo(y): + A = (gamma / sigma_mod) * (y - mu) + w = (A + jnp.sqrt(A * A + 4)) / (2 * delta**2) + z = w ** (1 / gamma) + z = z - 1 / z + return z + nu + + x, det = jax.jvp(inv_trafo, [y], [jnp.ones(())]) + return x, jnp.log(det) + + +class DipBij(bijections.AbstractBijection): + b: jnp.ndarray # raw parameter (scalar) + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None = None + + def __init__(self): + # Store the parameter b (a scalar). Then a = sigmoid(b) ∈ (0,1). + self.b = jnp.zeros(()) + self.shape = self.b.shape + # We expect b to be scalar. + assert self.shape == (), "Parameter b must be a scalar." + + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + # Compute a = sigmoid(b) + a = jnp.tanh(self.b) + + # Define the forward transformation. + def f(x): + return x - (a * x) / (1 + x**2) + + # Use jax.jvp to compute the derivative of f at x. + y, tangent = jax.jvp(f, (x,), (jnp.ones_like(x),)) + # For a 1d transformation the log-det is just the log of the absolute derivative. + logdet = jnp.sum(jnp.log(jnp.abs(tangent))) + return y, logdet + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + # Compute a = sigmoid(b) + a = jnp.tanh(self.b) + + # The forward map is: f(x) = x - (a*x)/(1+x**2). + # Its inverse is given by solving for x in + # x - (a*x)/(1+x^2) = y, + # which can be rearranged to: + # x^3 - y*x^2 + (1 - a)*x - y = 0. + # We solve this cubic by first shifting: let x = z + y/3. + m = y / 3 + # The depressed cubic is: z^3 + P*z + Q = 0, with + P = (1 - a) - y**2 / 3 + Q = -2 * y**3 / 27 - (a + 2) * y / 3 + # Compute the discriminant: + delta = (Q / 2) ** 2 + (P / 3) ** 3 + # Cardano's formula for the real solution: + z = jnp.cbrt(-Q / 2 + jnp.sqrt(delta)) + jnp.cbrt(-Q / 2 - jnp.sqrt(delta)) + x = z + m + + # To compute the log-det for the inverse, note that it is the negative of + # the forward log-det. We compute f'(x) via jax.jvp. 
+ def f(x): + return x - (a * x) / (1 + x**2) + + _, fprime = jax.jvp(f, (x,), (jnp.ones_like(x),)) + logdet = -jnp.sum(jnp.log(jnp.abs(fprime))) + return x, logdet + + +class Contract(bijections.AbstractBijection): + alpha: Array + shape: tuple[int, ...] + cond_shape: tuple[int, ...] | None = None - def make_transformer(): - elemwises = [] - # loc = bijections.Loc(jnp.zeros(())) - # elemwises.append(loc) + def __init__(self, alpha): + self.alpha = jnp.array(alpha) + self.shape = alpha.shape + assert self.shape == () + def transform_and_log_det( + self, x: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + beta = jax.scipy.special.expit(self.alpha) + + def trafo(x): + z = 2 * jnp.asinh(x / 2) + return jnp.sinh(beta * z) / beta + + y, det = jax.jvp(trafo, [x], [jnp.ones(())]) + return y, jnp.sum(jnp.log(det)) + + def inverse_and_log_det( + self, y: ArrayLike, condition: ArrayLike | None = None + ) -> tuple[Array, Array]: + beta = jax.scipy.special.expit(self.alpha) + + def trafo(y): + z = jnp.asinh(beta * y) / beta / 2 + return 2 * jnp.sinh(z) + + x, det = jax.jvp(trafo, [y], [jnp.ones(())]) + return x, jnp.sum(jnp.log(det)) + + +class Activation(eqx.Module): + fn: Callable + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.fn(*args) + + +def make_transformer( + affine_transformer=False, contract_transformer=True, asymmetric_transformer=True +): + elemwises = [] + + if affine_transformer: + affine = bijections.Affine(jnp.zeros(()), jnp.ones(())) + scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) + affine = eqx.tree_at( + where=lambda aff: aff.scale, + pytree=affine, + replace=scale, + ) + elemwises.append(affine) + + if asymmetric_transformer: for loc in [0.0]: scale = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) theta = Parameterize(lambda x: x + jnp.sqrt(1 + x**2), jnp.zeros(())) @@ -1043,38 +1453,183 @@ def make_transformer(): ) elemwises.append(bijections.Invert(affine)) - if len(elemwises) == 1: - return elemwises[0] - return bijections.Chain(elemwises) + if contract_transformer: + elemwises.append( + Contract2( + None, + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + ) + ) + + if len(elemwises) == 1: + return elemwises[0] + return bijections.Chain(elemwises) + + +def make_twin_flow_scan( + key, + n_dim, + *, + zero_init=False, + n_layers, + nn_width=None, + nn_depth=None, + num_householder=1, + affine_transformer=False, + contract_transformer=True, + asymmetric_transformer=True, + activation, +): + if nn_width is None: + nn_width = 32 + if nn_depth is None: + nn_depth = 1 + + def make_layer(key): + keys = jax.random.split(key, 4) + + def make_coupling(key): + transformer = make_transformer( + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + asymmetric_transformer=asymmetric_transformer, + ) + + coupling = bijections.Coupling( + key=key, + transformer=transformer, + untransformed_dim=n_dim // 2, + dim=n_dim, + nn_width=nn_width, + nn_depth=nn_depth, + nn_activation=activation, + ) + + if zero_init: + coupling = jax.tree_util.tree_map( + lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + coupling, + ) + return coupling + + layers = [] + + if num_householder > 0: + layers.append(make_hh(keys[0], n_dim, num_householder, randomize_base=True)) + + # layers.append(Rotations(keys[0], n_dim)) + + # layers.append(Orthogonal(keys[0], n_dim, 8)) + + inner = Contract2( + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + 
) + layers.append(Planar(keys[0], inner, n_dim)) + + layers.append(make_coupling(keys[1])) + + # layers.append(bijections.Flip((n_dim,))) + # layers.append(make_coupling(keys[2])) + + permutation = jax.random.permutation(keys[3], n_dim) + layers.append(bijections.Permute(permutation)) + + return bijections.Chain(layers) + + keys = jax.random.split(key, n_layers) + layers = eqx.filter_vmap(make_layer)(keys) + return bijections.Scan(layers) + + +def make_flow_scan( + key, + n_dim, + *, + zero_init=False, + n_layers, + nn_width=None, + nn_depth=None, + n_embed=None, + n_deembed=None, + mvscale=False, + num_householder=1, + twin_layers=False, + affine_transformer=False, + contract_transformer=True, + asymmetric_transformer=True, + sandwich_householder=False, + activation, + reuse_embed=True, +): + dim = n_dim + + if nn_width is None: + nn_width = 32 + if n_embed is None: + n_embed = 2 * nn_width + if n_deembed is None: + n_deembed = 2 * nn_width + if nn_depth is None: + nn_depth = 1 # Just to get at the size - transformer = make_transformer() + transformer = make_transformer( + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + asymmetric_transformer=asymmetric_transformer, + ) size = MaskedCoupling.conditioner_output_size(dim, transformer) key, key1 = jax.random.split(key) embed = eqx.nn.Sequential( [ eqx.nn.Linear(dim, n_embed, key=key1, dtype=jnp.float32, use_bias=True), - # Activation(_NN_ACTIVATION), - # eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), + eqx.nn.LayerNorm(shape=(n_embed,), dtype=jnp.float32), ] ) key, key1 = jax.random.split(key) embed_back = eqx.nn.Linear( - n_deembed, size, key=key1, dtype=jnp.float32, use_bias=True + n_deembed, size, key=key1, dtype=jnp.float32, use_bias=False ) embed_back = jax.tree_util.tree_map( lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, embed_back, ) - rng = np.random.default_rng(42) # TODO + key, key1 = jax.random.split(key) + seeds = jax.random.randint(key1, (4,), 0, (1 << 31) - 1) + rng = np.random.default_rng([int(seed) for seed in seeds]) order, counts = _generate_permutations(rng, dim, n_layers) mask = order == 0 mask[...]
= False for i in range(len(mask)): mask[i, order[i, : counts[i]]] = True + + if twin_layers: + interleaved = np.empty((mask.shape[0] * 2, mask.shape[1]), dtype=mask.dtype) + interleaved[0::2] = mask + interleaved[1::2] = ~mask + mask = interleaved + n_layers = len(mask) + def make_mvscale(key, n_dim): params = jax.random.normal(key, (n_dim,)) params = params / jnp.linalg.norm(params) @@ -1082,8 +1637,11 @@ def make_mvscale(key, n_dim): def make_layer(key, mask, embed, embed_back): key1, key2, key3, key4, key5 = jax.random.split(key, 5) - transformer = make_transformer() - bias = Add(jax.random.normal(key5, (size,)) * 0.001) + transformer = make_transformer( + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + asymmetric_transformer=asymmetric_transformer, + ) inner = eqx.nn.MLP( n_embed, n_deembed, @@ -1091,10 +1649,10 @@ def make_layer(key, mask, embed, embed_back): depth=nn_depth, key=key2, dtype=jnp.float32, - activation=_NN_ACTIVATION, + activation=activation, ) inner = jax.tree_util.tree_map( - lambda x: x * 1e-3 if eqx.is_inexact_array(x) else x, + lambda x: x * 1e-2 if eqx.is_inexact_array(x) else x, inner, ) @@ -1105,7 +1663,6 @@ def make_layer(key, mask, embed, embed_back): eqx.nn.Sequential( [ embed_back, - bias, ] ) , ] @@ -1121,42 +1678,76 @@ nn_depth=nn_depth, ) - if mvscale: - scale = make_mvscale(key4, dim) - return bijections.Chain([coupling, scale]) - else: + if num_householder == 0: return bijections.Chain([coupling]) + if sandwich_householder: + hh = make_hh(key4, dim, num_householder, randomize_base=True) + return bijections.Sandwich(coupling, hh) + else: + hh = make_hh(key4, dim, num_householder, randomize_base=True) + inner = Contract2( + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + ) + inner = bijections.Chain([inner, DipBij()]) + planar = Planar(key4, inner, n_dim) + return bijections.Chain([coupling, hh, planar]) keys = jax.random.split(key, n_layers) base = make_layer(key, mask[0], embed, embed_back) - out_axes = eqx.tree_at( - lambda tree: tree.bijections[0].conditioner.layers[1].layers[0], - pytree=base, - replace=None, - ) - out_axes = eqx.tree_at( - lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0], - pytree=out_axes, - replace=None, - ) - out_axes = jax.tree.map(lambda leaf: eqx.if_array(0)(leaf), out_axes) + + if sandwich_householder: + + def select_coupling(tree): + return tree.inner + else: + + def select_coupling(tree): + return tree.bijections[0] + + if reuse_embed: + out_axes = eqx.tree_at( + lambda tree: select_coupling(tree).conditioner.layers[1].layers[0], + pytree=base, + replace=None, + ) + + out_axes = eqx.tree_at( + lambda tree: select_coupling(tree) + .conditioner.layers[1] + .layers[-1] + .layers[0], + pytree=out_axes, + replace=None, + ) + out_axes = jax.tree.map(eqx.if_array(0), out_axes) + else: + out_axes = jax.tree.map(eqx.if_array(0), base) vectorized = eqx.filter_vmap( make_layer, in_axes=(0, 0, None, None), out_axes=out_axes ) - vectorize = jax.tree.map(lambda leaf: eqx.is_array(leaf), base) - vectorize = eqx.tree_at( - lambda tree:
tree.bijections[0].conditioner.layers[1].layers[0], - pytree=vectorize, - replace=False, - ) - vectorize = eqx.tree_at( - lambda tree: tree.bijections[0].conditioner.layers[1].layers[-1].layers[0], - pytree=vectorize, - replace=False, - ) + vectorize = jax.tree.map(eqx.is_array, base) + + if reuse_embed: + vectorize = eqx.tree_at( + lambda tree: select_coupling(tree).conditioner.layers[1].layers[0], + pytree=vectorize, + replace=False, + ) + vectorize = eqx.tree_at( + lambda tree: select_coupling(tree) + .conditioner.layers[1] + .layers[-1] + .layers[0], + pytree=vectorize, + replace=False, + ) return Scan( vectorized(keys, mask, embed, embed_back), @@ -1175,6 +1766,7 @@ def make_flow_loop( n_layers, nn_width=None, nn_depth=None, + activation, ): def make_layer(key, untransformed_dim: int | None, permutation=None): key, key_couple, key_permute, key_hh = jax.random.split(key, 4) @@ -1189,9 +1781,10 @@ def make_layer(key, untransformed_dim: int | None, permutation=None): key_couple, n_dim, untransformed_dim, - nn_activation=_NN_ACTIVATION, + nn_activation=activation, nn_width=nn_width, nn_depth=nn_depth, + activation=activation, ) if zero_init: @@ -1279,7 +1872,28 @@ def make_flow( n_deembed=None, kind="subset", mvscale=False, + num_householder=1, + twin_layers=False, + affine_transformer=False, + contract_transformer=False, + asymmetric_transformer=False, + sandwich_householder=False, + activation=None, + reuse_embed=False, ): + if activation is None: + activation = jax.nn.leaky_relu + if activation == "gelu": + activation = jax.nn.gelu + if activation == "relu": + activation = jax.nn.relu + if activation == "leaky_relu": + activation = jax.nn.leaky_relu + if activation == "tanh": + activation = jnp.tanh + if activation == "sigmoid": + activation = jax.nn.sigmoid + positions = np.array(positions) gradients = np.array(gradients) @@ -1293,7 +1907,7 @@ def make_flow( raise ValueError("No draws") elif n_draws == 1: assert np.all(gradients != 0) - diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-5, 1e5) + diag = np.clip(1 / jnp.sqrt(jnp.abs(gradients[0])), 1e-8, 1e8) assert np.isfinite(diag).all() mean = jnp.zeros_like(diag) else: @@ -1302,7 +1916,7 @@ def make_flow( diag = jnp.sqrt(pos_std / grad_std) mean = positions.mean(0) + gradients.mean(0) * diag * diag - key = jax.random.PRNGKey(seed % (2**63)) + key = jax.random.key(seed % (2**63), impl="threefry2x32") diag_param = Parameterize( lambda x: x + jnp.sqrt(1 + x**2), @@ -1333,6 +1947,7 @@ def make_flow( n_layers=n_layers, nn_width=nn_width, nn_depth=nn_depth, + activation=activation, ) elif kind == "masked": inner = make_flow_scan( @@ -1345,6 +1960,55 @@ def make_flow( n_embed=n_embed, n_deembed=n_deembed, mvscale=mvscale, + num_householder=num_householder, + twin_layers=twin_layers, + activation=activation, + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + asymmetric_transformer=asymmetric_transformer, + reuse_embed=reuse_embed, + sandwich_householder=sandwich_householder, + ) + elif kind == "flowjax_coupling": + base_dist = flowjax.distributions.StandardNormal((n_dim,)) + if nn_width is None: + nn_width = 32 + if nn_depth is None: + nn_depth = 1 + + if contract_transformer: + transformer = Contract2( + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + jnp.zeros(()), + ) + else: + transformer = None + + inner = flowjax.flows.coupling_flow( + key, + base_dist=base_dist, + flow_layers=n_layers, + nn_width=nn_width, + nn_depth=nn_depth, + transformer=transformer, + 
nn_activation=activation, + ) + inner = inner.bijection + elif kind == "twin": + inner = make_twin_flow_scan( + key, + n_dim, + zero_init=zero_init, + n_layers=n_layers, + nn_width=nn_width, + nn_depth=nn_depth, + num_householder=num_householder, + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + activation=activation, ) else: raise ValueError(f"Unknown flow kind: {kind}") @@ -1369,6 +2033,7 @@ def extend_flow( verbose: bool = False, nn_width=None, nn_depth=None, + activation, ): n_draws, n_dim = positions.shape @@ -1442,7 +2107,7 @@ def extend_flow( transformer=affine, untransformed_dim=n_dim - extension_var_trafo_count, dim=n_dim, - nn_activation=_NN_ACTIVATION, + nn_activation=activation, nn_width=width, nn_depth=nn_depth, ) @@ -1463,7 +2128,7 @@ def extend_flow( transformer=affine, untransformed_dim=extension_var_trafo_count, dim=n_dim, - nn_activation=_NN_ACTIVATION, + nn_activation=activation, nn_width=width, nn_depth=nn_depth, ) @@ -1508,7 +2173,7 @@ def extend_flow( transformer=affine, untransformed_dim=extension_var_trafo_count, dim=n_dim, - nn_activation=_NN_ACTIVATION, + nn_activation=activation, nn_width=width, nn_depth=nn_depth, ) diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py index 64b8880..41d18f4 100644 --- a/python/nutpie/transform_adapter.py +++ b/python/nutpie/transform_adapter.py @@ -1,7 +1,6 @@ from functools import partial from importlib.util import find_spec from typing import Callable -import time if find_spec("flowjax") is None: raise ImportError( @@ -29,7 +28,7 @@ import optax from paramax import unwrap, NonTrainable -from nutpie.normalizing_flow import Coupling, Scan, extend_flow, make_flow +from nutpie.normalizing_flow import Coupling, Householder, Scan, extend_flow, make_flow import tqdm _BIJECTION_TRACE = [] @@ -44,7 +43,7 @@ def fit_to_data( loss_fn: Callable | None = None, max_epochs: int = 100, max_patience: int = 5, - batch_size: int = 100, + batch_size: int = 128, val_prop: float = 0.1, learning_rate: float = 5e-4, optimizer: optax.GradientTransformation | None = None, @@ -52,6 +51,7 @@ def fit_to_data( show_progress: bool = True, opt_state=None, verbose: bool = False, + stop_value: float | None = None, ): r"""Train a distribution (e.g. a flow) to samples from the target distribution. 
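The next hunk replaces plain optax.adam with optax.apply_if_finite(optax.adamw(learning_rate), 10): the wrapper skips any update whose gradients contain NaN or inf (emitting zero updates instead) and only gives up after 10 consecutive non-finite steps. A minimal standalone sketch of that behaviour, assuming only optax and jax (the parameter pytree is illustrative):

import jax.numpy as jnp
import optax

# adamw wrapped so that non-finite gradients skip the step instead of
# corrupting the parameters.
opt = optax.apply_if_finite(optax.adamw(5e-4), max_consecutive_errors=10)
params = {"w": jnp.ones(3)}
state = opt.init(params)

bad_grads = {"w": jnp.array([jnp.nan, 0.0, 0.0])}
updates, state = opt.update(bad_grads, state, params)
assert jnp.all(updates["w"] == 0.0)  # the NaN step was dropped

good_grads = {"w": jnp.full(3, 0.1)}
updates, state = opt.update(good_grads, state, params)
params = optax.apply_updates(params, updates)  # a normal adamw step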
@@ -88,7 +88,7 @@ def fit_to_data( data = tuple(jnp.asarray(a) for a in data) if optimizer is None: - optimizer = optax.adam(learning_rate) + optimizer = optax.apply_if_finite(optax.adamw(learning_rate), 10) if loss_fn is None: loss_fn = MaximumLikelihoodLoss() @@ -152,7 +152,9 @@ def fit_to_data( key, subkey = jr.split(key) loss_i = loss_fn(params, static, *batch, key=subkey) batch_losses.append(loss_i) - losses["val"].append(sum(batch_losses) / len(batch_losses)) + + loss = sum(batch_losses) / len(batch_losses) + losses["val"].append(loss) loop.set_postfix({k: v[-1] for k, v in losses.items()}) if losses["val"][-1] == min(losses["val"]): @@ -162,6 +164,10 @@ def fit_to_data( loop.set_postfix_str(f"{loop.postfix} (Max patience reached)") break + elif stop_value is not None and loss < stop_value: + loop.set_postfix_str(f"{loop.postfix} (Stop value reached)") + break + params = best_params if return_best else params dist = eqx.combine(params, static) return dist, losses, opt_state @@ -213,6 +219,12 @@ def inverse_gradient_and_val(bijection, draw, grad, logp): draw, logdet = bijection.inverse_and_log_det(draw) grad = grad * unwrap(bijection.scale) return (draw, grad, logp - logdet) + elif isinstance(bijection, Householder): + params = unwrap(bijection.params) + params = params / jnp.linalg.norm(params) + draw = draw - 2 * params * (draw @ params) + grad = grad - 2 * params * (grad @ params) + return (draw, grad, logp) elif isinstance(bijection, bijections.Vmap): def inner(bijection, y, y_grad, y_logp): @@ -372,8 +384,8 @@ def fit_flow(key, bijection, loss_fn, draws, grads, logps, **kwargs): dist=flow, x=(draws, grads, logps), loss_fn=loss_fn, - max_epochs=500, return_best=True, + stop_value=-5, **kwargs, ) return fit.bijection, losses, opt_state @@ -480,6 +492,7 @@ def __init__( debug_save_bijection=False, make_optimizer=None, num_layers=9, + max_epochs=200, ): self._logp_fn = logp_fn self._make_flow_fn = make_flow_fn @@ -490,7 +503,7 @@ def __init__( self._num_layers = num_layers if make_optimizer is None: self._make_optimizer = lambda: optax.apply_if_finite( - optax.adamw(learning_rate), 50 + optax.adamw(learning_rate), 10 ) else: self._make_optimizer = make_optimizer @@ -511,6 +524,7 @@ def __init__( self._extension_var_trafo_count = extension_var_trafo_count self._debug_save_bijection = debug_save_bijection self._layers = 0 + self._max_epochs = max_epochs if extension_windows is None: self._extension_windows = [] @@ -661,14 +675,11 @@ def update(self, seed, positions, gradients, logps): ) params, static = eqx.partition(flow, eqx.is_inexact_array) - start = time.time() old_loss = self._loss_fn( params, static, positions[-128:], gradients[-128:], logps[-128:] ) - if self._verbose: - print("loss function time: ", time.time() - start) - if np.isfinite(old_loss) and old_loss < -5 and self.index > 10: + if np.isfinite(old_loss) and old_loss < -4 and self.index > 10: if self._verbose: print(f"Loss is low ({old_loss}), skipping training") return @@ -686,6 +697,7 @@ def update(self, seed, positions, gradients, logps): batch_size=self._batch_size, opt_state=self._opt_state if self._reuse_opt_state else None, max_patience=self._max_patience, + max_epochs=self._max_epochs, ) flow = flowjax.flows.Transformed( @@ -862,7 +874,7 @@ def make_transform_adapter( show_progress=False, nn_depth=None, nn_width=None, - num_layers=20, + num_layers=8, num_diag_windows=6, learning_rate=5e-4, untransformed_dim=None, @@ -885,9 +897,18 @@ def make_transform_adapter( mvscale_layer=False, num_project=None, 
num_embed=None, + num_householder=8, + twin_layers=False, + activation=None, + max_epochs=200, + affine_transformer=False, + contract_transformer=True, + asymmetric_transformer=False, + reuse_embed=True, ): if extension_windows is None: extension_windows = [] + return partial( TransformAdapter, verbose=verbose, @@ -902,6 +923,13 @@ def make_transform_adapter( n_deembed=num_embed, mvscale=mvscale_layer, kind=coupling_type, + num_householder=num_householder, + twin_layers=twin_layers, + activation=activation, + affine_transformer=affine_transformer, + contract_transformer=contract_transformer, + asymmetric_transformer=asymmetric_transformer, + reuse_embed=reuse_embed, ), show_progress=show_progress, num_diag_windows=num_diag_windows, @@ -921,4 +949,5 @@ def make_transform_adapter( debug_save_bijection=debug_save_bijection, make_optimizer=make_optimizer, num_layers=num_layers, + max_epochs=max_epochs, ) diff --git a/tests/reference/test_deterministic_sampling_jax.txt b/tests/reference/test_deterministic_sampling_jax.txt new file mode 100644 index 0000000..114966e --- /dev/null +++ b/tests/reference/test_deterministic_sampling_jax.txt @@ -0,0 +1,200 @@ +0.941959 +0.559649 +0.534203 +0.561444 +0.561444 +0.418685 +0.827896 +0.847014 +0.738508 +0.961291 +0.923931 +1.00584 +1.16386 +1.10065 +1.6348 +1.13139 +0.993458 +0.993458 +0.966241 +1.10922 +1.10922 +1.05723 +1.05723 +2.32492 +0.0700824 +0.0860656 +1.36431 +0.829624 +0.584658 +0.531506 +0.507961 +0.543701 +0.510104 +2.46898 +0.820341 +0.490474 +0.343958 +0.300549 +2.60267 +0.588131 +0.430013 +0.618032 +1.27527 +1.80449 +1.80449 +0.855217 +0.556106 +1.77619 +2.03761 +1.02106 +0.774811 +1.78438 +1.61398 +0.712683 +1.04966 +1.17936 +1.5425 +1.5425 +1.26262 +1.39659 +0.337024 +0.177694 +0.0424286 +0.180403 +0.140553 +0.367095 +0.348732 +0.341436 +1.82764 +0.692738 +0.629186 +0.245706 +0.732305 +0.56873 +0.498757 +0.204131 +0.417031 +0.184895 +0.208768 +0.238139 +1.95089 +1.95089 +0.593379 +0.593379 +0.750063 +0.69929 +0.490359 +0.478709 +0.361632 +0.346159 +0.728965 +1.58228 +0.985676 +1.58468 +0.709012 +0.700483 +0.805006 +1.70347 +1.26293 +1.24837 +0.23989 +0.881025 +1.39084 +1.37812 +0.969265 +0.969265 +0.938487 +0.846447 +1.61945 +0.108473 +0.173496 +0.897353 +0.455899 +0.571886 +0.891672 +0.891672 +0.864419 +0.739099 +1.49009 +1.49009 +0.385499 +0.228701 +1.83156 +1.83156 +0.947635 +0.805623 +0.714762 +0.853477 +1.45906 +0.908818 +0.540951 +1.40995 +1.22564 +0.26496 +0.159994 +0.423836 +0.350158 +0.388884 +1.39507 +0.727701 +1.80674 +0.466389 +1.61574 +1.61574 +0.42774 +0.217983 +0.14579 +1.01321 +1.01321 +1.19713 +0.390791 +0.223687 +0.149019 +0.103866 +0.153768 +0.12942 +0.346371 +0.814553 +2.41042 +0.42739 +0.322291 +0.248911 +0.854404 +1.35372 +1.35372 +2.00546 +0.0457881 +0.0415644 +0.0797551 +0.0913076 +0.070948 +0.00993872 +0.421448 +0.550377 +0.609387 +0.490487 +2.6607 +0.32804 +0.385999 +0.497294 +1.67109 +1.14328 +1.14328 +0.903063 +0.903063 +0.903063 +0.691269 +2.00151 +0.587672 +0.79679 +1.35563 +0.598471 +0.681826 +0.818296 +1.14265 +0.113094 +0.250861 +0.284491 +0.00420445 +0.00566936 diff --git a/tests/reference/test_deterministic_sampling_numba.txt b/tests/reference/test_deterministic_sampling_numba.txt new file mode 100644 index 0000000..6426e8c --- /dev/null +++ b/tests/reference/test_deterministic_sampling_numba.txt @@ -0,0 +1,200 @@ +0.862203 +0.743827 +0.985284 +0.864159 +1.11537 +1.46228 +1.46228 +0.731645 +0.618394 +0.70658 +1.58816 +1.58816 +1.58816 +1.58816 +1.02597 +1.02597 +2.38965 +0.0442154 +0.0556998 +1.20147 
+0.878239 +0.595919 +0.542086 +0.520452 +0.56279 +0.539904 +0.129453 +0.136407 +0.408806 +0.34263 +0.929525 +0.947864 +0.947864 +1.94444 +0.911973 +0.429576 +0.776378 +0.452981 +0.985476 +1.74745 +1.74095 +1.74095 +0.9855 +0.886535 +0.617313 +0.86405 +2.00577 +0.839407 +0.745118 +1.49611 +1.74491 +1.40854 +0.631877 +1.95302 +1.01379 +1.1063 +0.930275 +0.315935 +0.225544 +0.136821 +0.180021 +0.498635 +0.462448 +0.445633 +0.0878991 +0.105731 +0.355683 +0.750934 +0.750934 +0.874486 +1.15119 +0.657067 +0.500027 +1.28332 +1.28332 +0.919994 +1.09658 +1.73803 +1.13439 +1.21956 +0.643106 +0.329788 +0.456239 +0.596018 +0.180103 +0.388767 +1.03772 +1.03192 +1.03192 +1.04759 +1.04759 +1.13558 +0.673716 +0.871073 +0.50739 +0.625146 +0.999657 +1.00779 +2.06182 +0.707917 +0.107437 +0.0772623 +0.10719 +0.36616 +0.14863 +0.0333724 +0.0295763 +0.0205304 +0.127619 +0.164319 +0.241143 +0.376838 +0.87369 +1.64165 +0.106128 +0.170459 +0.916833 +0.458599 +0.575215 +0.894488 +0.894488 +0.865427 +0.739365 +0.681649 +0.72888 +1.38352 +1.38352 +2.28238 +2.28238 +2.28238 +0.567775 +0.41864 +1.41709 +1.41709 +1.41709 +1.41709 +0.600311 +0.598689 +0.627731 +0.460137 +1.86219 +1.81783 +1.78092 +1.78092 +1.78092 +0.492732 +1.37953 +1.16762 +0.597573 +0.627465 +0.617661 +0.649115 +0.608255 +0.685365 +0.685365 +0.685365 +0.685365 +0.685365 +2.2227 +0.971606 +0.4219 +0.879055 +0.74434 +2.08679 +1.34952 +1.34952 +1.34952 +1.34952 +0.513284 +0.16734 +0.174037 +0.626756 +0.913504 +0.271423 +0.200176 +0.132462 +0.465497 +0.406755 +0.493296 +0.0175891 +0.0234891 +0.0220327 +0.132404 +0.0788943 +0.0949265 +0.103031 +0.0760492 +0.377155 +1.90599 +1.58063 +1.58063 +1.17038 +0.556726 +0.55085 +0.24632 +0.375951 +0.339243 +0.747524 +1.82921 +0.794344 diff --git a/tests/reference/test_deterministic_sampling_stan.txt b/tests/reference/test_deterministic_sampling_stan.txt new file mode 100644 index 0000000..3bed2a2 --- /dev/null +++ b/tests/reference/test_deterministic_sampling_stan.txt @@ -0,0 +1,2 @@ +1.21572 1.03376 1.60518 1.60518 1.59553 1.35023 0.761056 1.41688 1.41688 1.41688 +0.252389 0.999663 0.999663 0.999663 0.740026 0.387763 0.944247 0.289785 1.52909 0.683129 diff --git a/tests/reference/test_normalizing_flow.txt b/tests/reference/test_normalizing_flow.txt new file mode 100644 index 0000000..8a224a0 --- /dev/null +++ b/tests/reference/test_normalizing_flow.txt @@ -0,0 +1,100 @@ +0.324871 +1.16777 +0.102039 +0.0579082 +0.985197 +0.550663 +0.929168 +0.543959 +0.166275 +0.359855 +0.764495 +1.77769 +0.462447 +0.984399 +0.490158 +1.02799 +0.702622 +0.473246 +0.0127807 +1.13249 +1.0929 +0.357403 +1.27519 +0.842248 +1.00152 +2.38541 +0.854202 +0.00735577 +0.218296 +2.20921 +1.74756 +0.119245 +1.74756 +0.119245 +0.381874 +1.82536 +1.29837 +2.52243 +0.86956 +0.0373546 +0.105311 +0.706774 +0.217778 +0.700421 +0.858623 +1.65319 +0.499877 +0.832728 +2.06511 +1.26955 +0.334041 +0.681329 +1.12741 +0.21517 +1.04719 +0.269313 +0.512924 +2.31191 +0.374169 +0.633086 +0.374169 +0.633086 +0.43021 +1.19342 +0.101336 +0.323738 +0.147408 +1.16919 +0.0614138 +1.33695 +1.54323 +0.199492 +0.351604 +0.396807 +1.05526 +1.62499 +0.266035 +0.54486 +0.861217 +0.621417 +0.416124 +0.64435 +0.430111 +0.634412 +0.614077 +0.986153 +0.76649 +0.549664 +0.353417 +1.19541 +0.467103 +1.8358 +1.04574 +0.438734 +0.641016 +0.699005 +1.69749 +0.317435 +0.175226 +0.739452 diff --git a/tests/test_pymc.py b/tests/test_pymc.py index dd8274c..5c5d2b3 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -8,7 +8,6 @@ import numpy as np import pymc as pm import pytest 
-from scipy import stats import nutpie import nutpie.compile_pymc @@ -268,10 +267,13 @@ def test_pymc_var_names(backend, gradient_backend): assert not hasattr(trace.posterior, "c") +# TODO For some reason, the sampling results with jax are +# not reproducible across operating systems. Figure this +# out and add the array_compare marker. +# @pytest.mark.array_compare @pytest.mark.pymc @pytest.mark.flow -@pytest.mark.parametrize("kind", ["masked", "subset"]) -def test_normalizing_flow(kind): +def test_normalizing_flow(): with pm.Model() as model: pm.HalfNormal("x", shape=2) @@ -279,8 +281,7 @@ model, backend="jax", gradient_backend="jax" ).with_transform_adapt( verbose=True, - coupling_type=kind, - num_layers=4, + num_layers=2, ) trace = nutpie.sample( compiled, @@ -288,40 +289,10 @@ transform_adapt=True, window_switch_freq=128, seed=1, - draws=2000, - ) - draws = trace.posterior.x.isel(x_dim_0=0, chain=0) - kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) - assert kstest.pvalue > 0.01 - - draws = trace.posterior.x.isel(x_dim_0=1, chain=0) - kstest = stats.ks_1samp(draws, stats.halfnorm.cdf) - assert kstest.pvalue > 0.01 - - -@pytest.mark.pymc -@pytest.mark.flow -@pytest.mark.parametrize("kind", ["masked", "subset"]) -def test_normalizing_flow_1d(kind): - with pm.Model() as model: - pm.HalfNormal("x") - - compiled = nutpie.compile_pymc_model( - model, backend="jax", gradient_backend="jax" - ).with_transform_adapt( - verbose=True, - coupling_type=kind, - num_layers=4, - ) - trace = nutpie.sample( - compiled, - chains=1, - transform_adapt=True, - window_switch_freq=128, - seed=1, - draws=2000, + draws=500, ) assert float(trace.sample_stats.fisher_distance.mean()) < 0.1 + # return trace.posterior.x.isel(draw=slice(-50, None)).values.ravel() @pytest.mark.pymc @@ -357,3 +328,25 @@ def test_missing(backend, gradient_backend): tr = nutpie.sample(compiled, chains=1, seed=1) print(tr.posterior) assert hasattr(tr.posterior, "y_unobserved") + + +@pytest.mark.pymc +@pytest.mark.array_compare +def test_deterministic_sampling_numba(): + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="numba") + trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) + return trace.posterior.a.values.ravel() + + +@pytest.mark.pymc +@pytest.mark.array_compare +def test_deterministic_sampling_jax(): + with pm.Model() as model: + pm.HalfNormal("a") + + compiled = nutpie.compile_pymc_model(model, backend="jax", gradient_backend="jax") + trace = nutpie.sample(compiled, chains=2, seed=123, draws=100, tune=100) + return trace.posterior.a.values.ravel() diff --git a/tests/test_stan.py b/tests/test_stan.py index f44b755..da84fc9 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -225,13 +225,41 @@ def test_stan_flow(): b ~ normal(0, 1); } """ + import jax - compiled_model = nutpie.compile_stan_model(code=model).with_transform_adapt( - num_layers=2, - nn_width=4, - num_diag_windows=6, - ) - trace = nutpie.sample( - compiled_model, transform_adapt=True, window_switch_freq=150, tune=600, chains=1 - ) - trace.posterior.a # noqa: B018 + old = jax.config.jax_enable_x64 + jax.config.update("jax_enable_x64", True) + try: + compiled_model = nutpie.compile_stan_model(code=model).with_transform_adapt( + num_layers=2, + nn_width=4, + ) + trace = nutpie.sample(compiled_model, transform_adapt=True, tune=2000, chains=1) + assert float(trace.sample_stats.fisher_distance.mean()) < 0.1 + trace.posterior.a # noqa: B018 + finally:
+ jax.config.update("jax_enable_x64", old) + + +# TODO: There are small numerical differences between Linux and Windows. +# We should figure out if they originate in Stan or in nutpie. +@pytest.mark.array_compare(atol=1e-4) +@pytest.mark.stan +def test_deterministic_sampling_stan(): + model = """ + parameters { + real a; + } + model { + a ~ normal(0, 1); + } + generated quantities { + real b = normal_rng(0, 1) + a; + } + """ + + compiled_model = nutpie.compile_stan_model(code=model) + trace = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) + trace2 = nutpie.sample(compiled_model, chains=2, seed=123, draws=100, tune=100) + np.testing.assert_allclose(trace.posterior.a.values, trace2.posterior.a.values) + np.testing.assert_allclose(trace.posterior.b.values, trace2.posterior.b.values) + return trace.posterior.a.isel(draw=slice(None, 10)).values
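What follows are a few standalone sanity checks for the math this diff introduces in python/nutpie/normalizing_flow.py. First, the reflection used by the new Householder bijection (and mirrored in inverse_gradient_and_val) is an involution with an orthogonal Jacobian, which is why both methods return a log-determinant of zero. A minimal sketch, assuming only jax; the helper name is illustrative:

import jax
import jax.numpy as jnp

def householder(x, params):
    # y = x - 2 (x . v) v with v = params / ||params||, as in Householder._householder.
    v = params / jnp.linalg.norm(params)
    return x - 2 * v * (x @ v)

k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (5,))
params = jax.random.normal(k2, (5,))

y = householder(x, params)
# The reflection is its own inverse ...
assert jnp.allclose(householder(y, params), x, atol=1e-5)
# ... and its Jacobian is orthogonal, so log|det| = 0.
jac = jax.jacobian(householder)(x, params)
assert jnp.allclose(jnp.abs(jnp.linalg.det(jac)), 1.0, atol=1e-5)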
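Similarly, pairwise_rotation applies an independent 2-D rotation to each consecutive coordinate pair, so it preserves the norm and is inverted by negating the angles, which is exactly what Rotations.inverse_and_log_det relies on. A standalone re-implementation for checking, assuming only jax.numpy:

import jax.numpy as jnp

def pairwise_rotation(x, thetas):
    # Rotate each pair (x[2i], x[2i+1]) by thetas[i]; an odd trailing
    # element passes through unchanged.
    n = x.shape[0]
    pairs = x[: (n // 2) * 2].reshape(-1, 2)
    c, s = jnp.cos(thetas), jnp.sin(thetas)
    rotated = jnp.stack(
        [pairs[:, 0] * c - pairs[:, 1] * s, pairs[:, 0] * s + pairs[:, 1] * c],
        axis=1,
    ).reshape(-1)
    return jnp.concatenate([rotated, x[(n // 2) * 2 :]])

x = jnp.arange(5.0)
thetas = jnp.array([0.3, -1.2])
y = pairwise_rotation(x, thetas)
assert jnp.allclose(jnp.linalg.norm(y), jnp.linalg.norm(x), atol=1e-5)
assert jnp.allclose(pairwise_rotation(y, -thetas), x, atol=1e-5)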
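The closed-form inverse derived in Contract2's docstring (solve the quadratic in w = z^gamma, then undo z = exp(asinh(x/2))) can also be checked in isolation. A sketch that follows the docstring's formulas with arbitrary scalar parameters, assuming only jax.numpy; the function names are illustrative:

import jax.numpy as jnp

def contract2(x, alpha, beta, sigma, mu, nu):
    # T(x) = sigma_mod * (delta^2 z^gamma - delta^-2 z^-gamma) / gamma + mu,
    # with z = exp(asinh((x - nu) / 2)).
    gamma, delta = jnp.exp(alpha), jnp.exp(beta)
    sigma_mod = sigma + jnp.sqrt(1 + sigma**2)
    x = x - nu
    z = x / 2 + jnp.sqrt(1 + x**2 / 4)
    return sigma_mod * (delta**2 * z**gamma - delta ** (-2) * z ** (-gamma)) / gamma + mu

def contract2_inv(y, alpha, beta, sigma, mu, nu):
    # Steps 1-5 of the docstring: solve delta^2 w - delta^-2 / w = A for w = z^gamma.
    gamma, delta = jnp.exp(alpha), jnp.exp(beta)
    sigma_mod = sigma + jnp.sqrt(1 + sigma**2)
    A = gamma / sigma_mod * (y - mu)
    w = (A + jnp.sqrt(A**2 + 4)) / (2 * delta**2)
    z = w ** (1 / gamma)
    return z - 1 / z + nu

x = jnp.linspace(-3.0, 3.0, 7)
args = (0.3, -0.2, 0.5, 1.0, -0.5)
assert jnp.allclose(contract2_inv(contract2(x, *args), *args), x, atol=1e-3)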
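Lastly, the cubic inversion in DipBij: for a = tanh(b) in (-1, 1) the forward map f(x) = x - a*x/(1 + x^2) is strictly increasing (its derivative 1 - a*(1 - x^2)/(1 + x^2)^2 never reaches zero), so f(x) = y has a single real root and Cardano's formula applies with a positive discriminant. A numpy-only sketch of the same algebra:

import numpy as np

a = np.tanh(0.5)  # the bijection's parameter after the tanh squash

def f(x):
    return x - a * x / (1 + x**2)

def f_inv(y):
    # Solve x^3 - y*x^2 + (1 - a)*x - y = 0 via the shift x = z + y/3
    # and Cardano's formula for the single real root.
    P = (1 - a) - y**2 / 3
    Q = -2 * y**3 / 27 - (a + 2) * y / 3
    disc = (Q / 2) ** 2 + (P / 3) ** 3  # > 0 for |a| < 1 (one real root)
    z = np.cbrt(-Q / 2 + np.sqrt(disc)) + np.cbrt(-Q / 2 - np.sqrt(disc))
    return z + y / 3

x = np.linspace(-4.0, 4.0, 9)
np.testing.assert_allclose(f_inv(f(x)), x, atol=1e-8)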