diff --git a/blackjax/__init__.py b/blackjax/__init__.py
index 89d78edd9..ff775f86c 100644
--- a/blackjax/__init__.py
+++ b/blackjax/__init__.py
@@ -12,6 +12,8 @@ from .base import SamplingAlgorithm, VIAlgorithm
 from .diagnostics import effective_sample_size as ess
 from .diagnostics import potential_scale_reduction as rhat
+from .ensemble import slice as _ensemble_slice
+from .ensemble import stretch as _ensemble
 from .mcmc import adjusted_mclmc as _adjusted_mclmc
 from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
 from .mcmc import barker
@@ -120,6 +122,8 @@ def generate_top_level_api_from(module):
 elliptical_slice = generate_top_level_api_from(_elliptical_slice)
 ghmc = generate_top_level_api_from(_ghmc)
 barker_proposal = generate_top_level_api_from(barker)
+ensemble = generate_top_level_api_from(_ensemble)
+ensemble_slice = generate_top_level_api_from(_ensemble_slice)
 
 hmc_family = [hmc, nuts]
diff --git a/blackjax/ensemble/__init__.py b/blackjax/ensemble/__init__.py
new file mode 100644
index 000000000..89d9d9360
--- /dev/null
+++ b/blackjax/ensemble/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ensemble sampling methods for BlackJAX."""
+from blackjax.ensemble import slice, stretch
+from blackjax.ensemble.base import EnsembleInfo, EnsembleState
+
+__all__ = ["EnsembleState", "EnsembleInfo", "stretch", "slice"]
diff --git a/blackjax/ensemble/base.py b/blackjax/ensemble/base.py
new file mode 100644
index 000000000..c897c5f6e
--- /dev/null
+++ b/blackjax/ensemble/base.py
@@ -0,0 +1,178 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Base classes and utilities for ensemble sampling methods."""
+from typing import NamedTuple, Optional
+
+import jax
+import jax.numpy as jnp
+
+from blackjax.types import Array, ArrayTree
+
+__all__ = [
+    "EnsembleState",
+    "EnsembleInfo",
+    "get_nwalkers",
+    "tree_take",
+    "shuffle_triple",
+    "unshuffle_triple",
+    "split_triple",
+    "concat_triple_groups",
+    "complementary_triple",
+    "masked_select",
+    "vmapped_logdensity",
+    "prepare_split",
+    "build_states_from_triples",
+]
+
+
+class EnsembleState(NamedTuple):
+    """State of an ensemble sampler.
+
+    coords
+        An array or PyTree of arrays of shape `(n_walkers, ...)` that
+        stores the current position of the walkers.
+    log_probs
+        An array of shape `(n_walkers,)` that stores the log-probability of
+        each walker.
+    blobs
+        An optional PyTree that stores metadata returned by the log-probability
+        function.
+    """
+
+    coords: ArrayTree
+    log_probs: Array
+    blobs: Optional[ArrayTree] = None
+
+
+class EnsembleInfo(NamedTuple):
+    """Additional information on the ensemble transition.
+
+    acceptance_rate
+        The acceptance rate of the ensemble.
+    is_accepted
+        A boolean array of shape `(n_walkers,)` indicating whether each walker's
+        proposal was accepted.
+    """
+
+    acceptance_rate: float
+    is_accepted: Array
+
+
+def get_nwalkers(coords: ArrayTree) -> int:
+    """Get the number of walkers from ensemble coordinates."""
+    return jax.tree_util.tree_flatten(coords)[0][0].shape[0]
+
+
+def tree_take(tree: ArrayTree, idx: jnp.ndarray) -> ArrayTree:
+    """Index into a PyTree along the leading dimension."""
+    return jax.tree_util.tree_map(lambda a: a[idx], tree)
+
+
+def shuffle_triple(key, coords, log_probs, blobs):
+    """Shuffle ensemble coordinates, log_probs, and blobs."""
+    n = get_nwalkers(coords)
+    idx = jax.random.permutation(key, n)
+    coords_s = tree_take(coords, idx)
+    log_probs_s = log_probs[idx]
+    blobs_s = None if blobs is None else tree_take(blobs, idx)
+    return coords_s, log_probs_s, blobs_s, idx
+
+
+def unshuffle_triple(coords, log_probs, blobs, indices):
+    """Reverse a shuffle operation on ensemble coordinates, log_probs, and blobs."""
+    inv = jnp.argsort(indices)
+    coords_u = tree_take(coords, inv)
+    log_probs_u = log_probs[inv]
+    blobs_u = None if blobs is None else tree_take(blobs, inv)
+    return coords_u, log_probs_u, blobs_u
+
+
+def split_triple(coords, log_probs, blobs, nsplits):
+    """Split the ensemble into nsplits contiguous groups."""
+    n = get_nwalkers(coords)
+    group_size = n // nsplits
+    groups = []
+    for i in range(nsplits):
+        s = i * group_size
+        e = (i + 1) * group_size if i < nsplits - 1 else n
+        coords_i = jax.tree_util.tree_map(lambda a: a[s:e], coords)
+        log_probs_i = log_probs[s:e]
+        blobs_i = (
+            None if blobs is None else jax.tree_util.tree_map(lambda a: a[s:e], blobs)
+        )
+        groups.append((coords_i, log_probs_i, blobs_i))
+    return groups
+
+
+def concat_triple_groups(group_triples):
+    """Concatenate groups of (coords, log_probs, blobs) triples."""
+    coords_list, logp_list, blobs_list = zip(*group_triples)
+    coords = jax.tree_util.tree_map(
+        lambda *xs: jnp.concatenate(xs, axis=0), *coords_list
+    )
+    logp = jnp.concatenate(logp_list, axis=0)
+    blobs = (
+        None
+        if all(b is None for b in blobs_list)
+        else jax.tree_util.tree_map(
+            lambda *xs: jnp.concatenate(xs, axis=0), *blobs_list
+        )
+    )
+    return coords, logp, blobs
+
+
+def complementary_triple(groups, i):
+    """Build the complementary ensemble from all groups except group i."""
+    return concat_triple_groups([g for j, g in enumerate(groups) if j != i])
+
+
+def masked_select(mask, new_val, old_val):
+    """Select between new and old values based on a mask."""
+    expand_dims = (1,) * (new_val.ndim - 1)
+    mask_expanded = mask.reshape((mask.shape[0],) + expand_dims)
+    return jnp.where(mask_expanded, new_val, old_val)
+
+
+def vmapped_logdensity(logdensity_fn, coords):
+    """Evaluate the logdensity function on ensemble coordinates with vmap."""
+    outs = jax.vmap(logdensity_fn)(coords)
+    return outs if isinstance(outs, tuple) else (outs, None)
+
+
+def prepare_split(rng_key, coords, log_probs, blobs, randomize_split, nsplits):
+    """Prepare the ensemble for splitting into groups.
+
+    Handles optional randomization, splitting, and returns components needed
+    for group-wise updates and subsequent unshuffling.
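+
+    For example, with 10 walkers and ``nsplits=2`` the (optionally shuffled)
+    ensemble is split into two groups of five; the returned ``indices`` record
+    the permutation so that `unshuffle_triple` can restore the original walker
+    order afterwards.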
+    """
+    if randomize_split:
+        key_shuffle, key_update = jax.random.split(rng_key)
+        coords_s, logp_s, blobs_s, indices = shuffle_triple(
+            key_shuffle, coords, log_probs, blobs
+        )
+    else:
+        key_update = rng_key
+        coords_s, logp_s, blobs_s = coords, log_probs, blobs
+        indices = jnp.arange(get_nwalkers(coords))
+    group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits)
+    return key_update, group_triples, indices
+
+
+def build_states_from_triples(group_triples, state_ctor, extra_fields=()):
+    """Build state objects from triples with optional extra fields.
+
+    Handles both the base EnsembleState and algorithm-specific states like
+    SliceEnsembleState that have additional fields.
+    """
+    return [state_ctor(t[0], t[1], t[2], *extra_fields) for t in group_triples]
diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py
new file mode 100644
index 000000000..bb4afa6ee
--- /dev/null
+++ b/blackjax/ensemble/slice.py
@@ -0,0 +1,738 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Ensemble Slice Sampling (ESS) implementation."""
+from typing import Callable, NamedTuple, Optional
+
+import jax
+import jax.numpy as jnp
+
+from blackjax.base import SamplingAlgorithm
+from blackjax.ensemble.base import (
+    build_states_from_triples,
+    complementary_triple,
+    concat_triple_groups,
+    get_nwalkers,
+    prepare_split,
+    unshuffle_triple,
+    vmapped_logdensity,
+)
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
+
+__all__ = [
+    "SliceEnsembleState",
+    "SliceEnsembleInfo",
+    "init",
+    "build_kernel",
+    "as_top_level_api",
+    "differential_direction",
+    "random_direction",
+    "gaussian_direction",
+    "slice_along_direction",
+]
+
+
+class SliceEnsembleState(NamedTuple):
+    """State of the ensemble slice sampler.
+
+    coords
+        An array or PyTree of arrays of shape `(n_walkers, ...)` that
+        stores the current position of the walkers.
+    log_probs
+        An array of shape `(n_walkers,)` that stores the log-probability of
+        each walker.
+    blobs
+        An optional PyTree that stores metadata returned by the log-probability
+        function.
+    mu
+        The current scale parameter for the slice sampling directions.
+    tuning_active
+        Whether adaptive tuning of mu is currently active.
+    patience_count
+        Counter for determining when to stop tuning.
+    """
+
+    coords: ArrayTree
+    log_probs: jnp.ndarray
+    blobs: Optional[ArrayTree] = None
+    mu: float = 1.0
+    tuning_active: bool = True
+    patience_count: int = 0
+
+
+class SliceEnsembleInfo(NamedTuple):
+    """Additional information on the ensemble slice sampling transition.
+
+    acceptance_rate
+        Fraction of walkers that found valid slice points.
+    is_accepted
+        Boolean array of shape `(n_walkers,)` indicating successful slice updates.
+    expansions
+        Total number of slice expansions performed.
+    contractions
+        Total number of slice contractions performed.
+    nevals
+        Total number of log-density evaluations performed.
+    mu
+        The current value of the scale parameter.
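+        While tuning is active, this is updated once per step as
+        ``mu <- mu * 2 * nexp / (nexp + ncon)``, with ``nexp`` floored at 1.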
+    """
+
+    acceptance_rate: float
+    is_accepted: jnp.ndarray
+    expansions: int
+    contractions: int
+    nevals: int
+    mu: float
+
+
+def differential_direction(
+    rng_key: PRNGKey,
+    complementary_coords: ArrayTree,
+    n_update: int,
+    mu: float,
+) -> tuple[ArrayTree, bool]:
+    """Generate direction vectors using the differential move.
+
+    Directions are defined by the difference between two randomly selected
+    walkers from the complementary ensemble, scaled by 2*mu.
+
+    Parameters
+    ----------
+    rng_key
+        A PRNG key for random number generation.
+    complementary_coords
+        Coordinates of the complementary ensemble with shape (n_walkers, ...).
+    n_update
+        Number of walkers to update (number of directions needed).
+    mu
+        The scale parameter.
+
+    Returns
+    -------
+    A tuple (directions, tune_once) where directions is a PyTree matching
+    the structure of complementary_coords with leading dimension n_update,
+    and tune_once is True (indicating mu should be tuned).
+    """
+    comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords)
+    n_comp = comp_leaves[0].shape[0]
+
+    key1, key2 = jax.random.split(rng_key)
+    idx1 = jax.random.randint(key1, (n_update,), 0, n_comp)
+    j = jax.random.randint(key2, (n_update,), 0, n_comp - 1)
+    idx2 = j + (j >= idx1)  # Ensure idx2 != idx1
+
+    walker1 = jax.tree_util.tree_map(lambda x: x[idx1], complementary_coords)
+    walker2 = jax.tree_util.tree_map(lambda x: x[idx2], complementary_coords)
+
+    directions = jax.tree_util.tree_map(
+        lambda w1, w2: 2.0 * mu * (w1 - w2), walker1, walker2
+    )
+
+    return directions, True
+
+
+def random_direction(
+    rng_key: PRNGKey,
+    template_coords: ArrayTree,
+    n_update: int,
+    mu: float,
+) -> tuple[ArrayTree, bool]:
+    """Generate random isotropic direction vectors.
+
+    Directions are sampled from a standard normal distribution and scaled by 2*mu.
+    This corresponds to standard multivariate slice sampling without using
+    ensemble information.
+
+    Parameters
+    ----------
+    rng_key
+        A PRNG key for random number generation.
+    template_coords
+        Template coordinates whose PyTree structure the output matches. The
+        output's leading dimension will be n_update.
+    n_update
+        Number of walkers to update (number of directions needed).
+    mu
+        The scale parameter.
+
+    Returns
+    -------
+    A tuple (directions, tune_once) where directions is a PyTree matching
+    the structure of template_coords with leading dimension n_update,
+    and tune_once is True.
+    """
+    leaves, treedef = jax.tree_util.tree_flatten(template_coords)
+    n_leaves = len(leaves)
+
+    keys = jax.random.split(rng_key, n_leaves)
+
+    def sample_leaf(key, template_leaf):
+        template_shape = template_leaf.shape
+        if len(template_shape) > 0:
+            new_shape = (n_update,) + template_shape[1:]
+        else:
+            new_shape = (n_update,)
+        return jax.random.normal(key, new_shape)
+
+    direction_leaves = [
+        2.0 * mu * sample_leaf(key, leaf) for key, leaf in zip(keys, leaves)
+    ]
+
+    directions = jax.tree_util.tree_unflatten(treedef, direction_leaves)
+
+    return directions, True
+
+
+def gaussian_direction(
+    rng_key: PRNGKey,
+    complementary_coords: ArrayTree,
+    n_update: int,
+    mu: float,
+) -> tuple[ArrayTree, bool]:
+    """Generate direction vectors using the Gaussian move.
+
+    Directions are sampled from a multivariate normal distribution with covariance
+    estimated from the complementary ensemble. This move adapts to the local
+    geometry of the target distribution.
+
+    Parameters
+    ----------
+    rng_key
+        A PRNG key for random number generation.
+    complementary_coords
+        Coordinates of the complementary ensemble.
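+        The sample covariance of these coordinates, estimated leaf-wise,
+        defines the proposal distribution.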
+    n_update
+        Number of walkers to update.
+    mu
+        The scale parameter.
+
+    Returns
+    -------
+    A tuple (directions, tune_once) where directions is a PyTree matching
+    the structure of complementary_coords with leading dimension n_update,
+    and tune_once is True.
+    """
+    leaves, treedef = jax.tree_util.tree_flatten(complementary_coords)
+    n_leaves = len(leaves)
+
+    keys = jax.random.split(rng_key, n_leaves)
+
+    def sample_gaussian_leaf(key, leaf):
+        n_comp = leaf.shape[0]
+        leaf_flat = leaf.reshape(n_comp, -1)
+        d = leaf_flat.shape[1]
+
+        mean = jnp.mean(leaf_flat, axis=0)
+        centered = leaf_flat - mean
+
+        cov = jnp.dot(centered.T, centered) / (n_comp - 1)
+        jitter = 1e-6 * jnp.eye(d)
+        cov_reg = cov + jitter
+
+        directions_flat = jax.random.multivariate_normal(
+            key, jnp.zeros(d), cov_reg, (n_update,)
+        )
+
+        orig_shape = leaf.shape
+        new_shape = (n_update,) + orig_shape[1:]
+        return (2.0 * mu * directions_flat).reshape(new_shape)
+
+    direction_leaves = [
+        sample_gaussian_leaf(key, leaf) for key, leaf in zip(keys, leaves)
+    ]
+
+    directions = jax.tree_util.tree_unflatten(treedef, direction_leaves)
+
+    return directions, True
+
+
+def slice_along_direction(
+    rng_key: PRNGKey,
+    x0: ArrayTree,
+    logp0: float,
+    direction: ArrayTree,
+    logprob_fn: Callable,
+    maxsteps: int = 10000,
+    maxiter: int = 10000,
+) -> tuple[ArrayTree, float, bool, int, int, int]:
+    """Perform slice sampling along a given direction.
+
+    Implements the stepping-out and shrinking procedures for 1D slice sampling
+    along the specified direction vector.
+
+    Parameters
+    ----------
+    rng_key
+        A PRNG key for random number generation.
+    x0
+        Current position (PyTree).
+    logp0
+        Log-probability at the current position.
+    direction
+        Direction vector (PyTree with the same structure as x0).
+    logprob_fn
+        Function that computes the log-probability at a given position.
+    maxsteps
+        Maximum number of steps for the stepping-out procedure.
+    maxiter
+        Maximum total iterations to prevent infinite loops.
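+        This bounds the stepping-out and shrinking while-loops, which
+        otherwise have no fixed iteration count under `jax.lax.while_loop`.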
+
+    Returns
+    -------
+    A tuple (x1, logp1, accepted, nexp, ncon, neval) where:
+        x1: New position (PyTree), or x0 if not accepted
+        logp1: Log-probability at the new position, or logp0 if not accepted
+        accepted: Boolean indicating if a valid slice point was found
+        nexp: Number of expansions performed
+        ncon: Number of contractions performed
+        neval: Number of log-probability evaluations
+    """
+    key_z0, key_lr, key_j, key_shrink = jax.random.split(rng_key, 4)
+
+    z0 = logp0 - jax.random.exponential(key_z0)
+
+    l_init = -jax.random.uniform(key_lr)
+    r_init = l_init + 1.0
+
+    j = jax.random.randint(key_j, (), 0, maxsteps)
+    k = maxsteps - 1 - j
+
+    def eval_at_t(t):
+        xt = jax.tree_util.tree_map(lambda x, d: x + t * d, x0, direction)
+        return logprob_fn(xt)
+
+    logp_l_init = eval_at_t(l_init)
+
+    def left_expand_cond(carry):
+        l, logp_l, j_left, nexp, neval, iter_count = carry
+        return (j_left > 0) & (iter_count < maxiter) & (logp_l > z0)
+
+    def left_expand_body(carry):
+        l, logp_l, j_left, nexp, neval, iter_count = carry
+        l_new = l - 1.0
+        logp_new = eval_at_t(l_new)
+        return l_new, logp_new, j_left - 1, nexp + 1, neval + 1, iter_count + 1
+
+    l_final, _, _, nexp_left, neval_left, _ = jax.lax.while_loop(
+        left_expand_cond, left_expand_body, (l_init, logp_l_init, j, 0, 0, 0)
+    )
+
+    logp_r_init = eval_at_t(r_init)
+
+    def right_expand_cond(carry):
+        r, logp_r, k_right, nexp, neval, iter_count = carry
+        return (k_right > 0) & (iter_count < maxiter) & (logp_r > z0)
+
+    def right_expand_body(carry):
+        r, logp_r, k_right, nexp, neval, iter_count = carry
+        r_new = r + 1.0
+        logp_new = eval_at_t(r_new)
+        return r_new, logp_new, k_right - 1, nexp + 1, neval + 1, iter_count + 1
+
+    r_final, _, _, nexp_right, neval_right, _ = jax.lax.while_loop(
+        right_expand_cond, right_expand_body, (r_init, logp_r_init, k, 0, 0, 0)
+    )
+
+    nexp_total = nexp_left + nexp_right
+    neval_after_expand = neval_left + neval_right + 2
+
+    def shrink_cond(carry):
+        _, _, _, _, _, iter_count, accepted, _, _ = carry
+        return (~accepted) & (iter_count < maxiter)
+
+    def shrink_body(carry):
+        key, l, r, neval, ncon, iter_count, accepted, t_acc, logp_acc = carry
+        key, key_t = jax.random.split(key)
+        t = jax.random.uniform(key_t, minval=l, maxval=r)
+        logp_t = eval_at_t(t)
+        neval_new = neval + 1
+
+        inside_slice = logp_t >= z0
+        accepted_new = accepted | inside_slice
+
+        l_new = jnp.where(inside_slice, l, jnp.where(t < 0, t, l))
+        r_new = jnp.where(inside_slice, r, jnp.where(t >= 0, t, r))
+        ncon_new = jnp.where(inside_slice, ncon, ncon + 1)
+        t_acc_new = jnp.where(inside_slice & ~accepted, t, t_acc)
+        logp_acc_new = jnp.where(inside_slice & ~accepted, logp_t, logp_acc)
+
+        return (
+            key,
+            l_new,
+            r_new,
+            neval_new,
+            ncon_new,
+            iter_count + 1,
+            accepted_new,
+            t_acc_new,
+            logp_acc_new,
+        )
+
+    (
+        _,
+        _,
+        _,
+        neval_shrink,
+        ncon_total,
+        _,
+        accepted,
+        t_final,
+        logp_final,
+    ) = jax.lax.while_loop(
+        shrink_cond,
+        shrink_body,
+        (key_shrink, l_final, r_final, 0, 0, 0, False, 0.0, logp0),
+    )
+
+    x1 = jax.tree_util.tree_map(lambda x, d: x + t_final * d, x0, direction)
+
+    neval_total = neval_after_expand + neval_shrink
+
+    return x1, logp_final, accepted, nexp_total, ncon_total, neval_total
+
+
+def init(
+    position: ArrayLikeTree,
+    logdensity_fn: Callable,
+    has_blobs: bool = False,
+    mu: float = 1.0,
+) -> SliceEnsembleState:
+    """Initialize the ensemble slice sampling algorithm.
+
+    Parameters
+    ----------
+    position
+        Initial positions for all walkers, with shape (n_walkers, ...).
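+        For a PyTree position, every leaf must share the same leading
+        (walker) dimension.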
+    logdensity_fn
+        The log-density function to evaluate.
+    has_blobs
+        Whether the log-density function returns additional metadata (blobs).
+    mu
+        Initial value of the scale parameter.
+
+    Returns
+    -------
+    The initial SliceEnsembleState.
+    """
+    log_probs, blobs = vmapped_logdensity(logdensity_fn, position)
+    return SliceEnsembleState(position, log_probs, blobs, mu, True, 0)
+
+
+def build_kernel(
+    move: str = "differential",
+    move_fn=None,
+    randomize_split: bool = True,
+    nsplits: int = 2,
+    maxsteps: int = 10000,
+    maxiter: int = 10000,
+    tune: bool = True,
+    patience: int = 5,
+    tolerance: float = 0.05,
+) -> Callable:
+    """Build the ensemble slice sampling kernel.
+
+    Parameters
+    ----------
+    move
+        Type of move: "differential", "random", or "gaussian". Ignored if
+        move_fn is provided.
+    move_fn
+        Optional custom move function. If None, uses the specified move type.
+    randomize_split
+        If True, randomly shuffle walker indices before splitting into groups.
+    nsplits
+        Number of groups to split the ensemble into.
+    maxsteps
+        Maximum steps for the slice stepping-out procedure.
+    maxiter
+        Maximum total iterations for the stepping-out and shrinking loops.
+    tune
+        Whether to enable adaptive tuning of mu.
+    patience
+        Number of steps within tolerance before stopping tuning.
+    tolerance
+        Tolerance for the expansion/contraction ratio to stop tuning.
+
+    Returns
+    -------
+    A kernel function that performs one step of ensemble slice sampling.
+    """
+    if move_fn is None:
+        if move == "differential":
+            move_fn = differential_direction
+        elif move == "random":
+            move_fn = random_direction
+        elif move == "gaussian":
+            move_fn = gaussian_direction
+        else:
+            raise ValueError(
+                f"Unknown move '{move}'; expected 'differential', 'random' "
+                "or 'gaussian'."
+            )
+
+    def kernel(
+        rng_key: PRNGKey, state: SliceEnsembleState, logdensity_fn: Callable
+    ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]:
+        key_update, group_triples, indices = prepare_split(
+            rng_key,
+            state.coords,
+            state.log_probs,
+            state.blobs,
+            randomize_split,
+            nsplits,
+        )
+        groups = build_states_from_triples(
+            group_triples,
+            SliceEnsembleState,
+            (state.mu, state.tuning_active, state.patience_count),
+        )
+
+        updated_groups = list(groups)
+        accepted_groups = []
+        total_nexp = jnp.array(0, dtype=jnp.int32)
+        total_ncon = jnp.array(0, dtype=jnp.int32)
+        total_neval = jnp.array(0, dtype=jnp.int32)
+
+        keys = jax.random.split(key_update, nsplits)
+        for i in range(nsplits):
+            comp_triple = complementary_triple(
+                [(g.coords, g.log_probs, g.blobs) for g in updated_groups], i
+            )
+            complementary = SliceEnsembleState(
+                comp_triple[0],
+                comp_triple[1],
+                comp_triple[2],
+                state.mu,
+                state.tuning_active,
+                state.patience_count,
+            )
+
+            updated_group, accepted, nexp, ncon, neval = _update_half_slice(
+                keys[i],
+                groups[i],
+                complementary,
+                logdensity_fn,
+                move_fn,
+                maxsteps,
+                maxiter,
+            )
+            updated_groups[i] = updated_group
+            accepted_groups.append(accepted)
+            total_nexp = total_nexp + nexp
+            total_ncon = total_ncon + ncon
+            total_neval = total_neval + neval
+
+        coords_cat, logp_cat, blobs_cat = concat_triple_groups(
+            [(g.coords, g.log_probs, g.blobs) for g in updated_groups]
+        )
+        shuffled_accepted = jnp.concatenate(accepted_groups, axis=0)
+
+        if randomize_split:
+            new_coords, new_log_probs, new_blobs = unshuffle_triple(
+                coords_cat, logp_cat, blobs_cat, indices
+            )
+            accepted = shuffled_accepted[jnp.argsort(indices)]
+        else:
+            new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat
+            accepted = shuffled_accepted
+
+        should_tune = tune & state.tuning_active
+
+        nexp_eff = jnp.maximum(total_nexp, 1)
+        mu_tuned = state.mu * 2.0 * nexp_eff / (nexp_eff + total_ncon)
+
+        exp_ratio = total_nexp / jnp.maximum(total_nexp + total_ncon, 1)
+        within_tolerance = jnp.abs(exp_ratio - 0.5) < tolerance
+
+        patience_count_updated = jnp.where(
+            within_tolerance, state.patience_count + 1, 0
+        )
+        tuning_active_updated = patience_count_updated < patience
+
+        mu_new = jnp.where(should_tune, mu_tuned, state.mu)
+        patience_count_new = jnp.where(
+            should_tune, patience_count_updated, state.patience_count
+        )
+        tuning_active_new = jnp.where(
+            should_tune, tuning_active_updated, state.tuning_active
+        )
+
+        new_state = SliceEnsembleState(
+            new_coords,
+            new_log_probs,
+            new_blobs,
+            mu_new,
+            tuning_active_new,
+            patience_count_new,
+        )
+
+        acceptance_rate = jnp.mean(accepted)
+        info = SliceEnsembleInfo(
+            acceptance_rate=acceptance_rate,
+            is_accepted=accepted,
+            expansions=total_nexp,
+            contractions=total_ncon,
+            nevals=total_neval,
+            mu=mu_new,
+        )
+
+        return new_state, info
+
+    return kernel
+
+
+def _update_half_slice(
+    rng_key: PRNGKey,
+    walkers_to_update: SliceEnsembleState,
+    complementary_walkers: SliceEnsembleState,
+    logdensity_fn: Callable,
+    direction_fn: Callable,
+    maxsteps: int,
+    maxiter: int,
+) -> tuple[SliceEnsembleState, jnp.ndarray, int, int, int]:
+    """Update a group of walkers using ensemble slice sampling.
+
+    Parameters
+    ----------
+    rng_key
+        PRNG key for random number generation.
+    walkers_to_update
+        Group of walkers to update.
+    complementary_walkers
+        Complementary ensemble used for generating directions.
+    logdensity_fn
+        Log-density function.
+    direction_fn
+        Function to generate direction vectors.
+    maxsteps
+        Maximum steps for stepping-out.
+    maxiter
+        Maximum iterations.
+
+    Returns
+    -------
+    A tuple (updated_group_state, accepted_array, total_expansions,
+    total_contractions, total_evals).
+    """
+    n_update = get_nwalkers(walkers_to_update.coords)
+
+    key_dir, key_slice = jax.random.split(rng_key)
+    directions, _ = direction_fn(
+        key_dir, complementary_walkers.coords, n_update, walkers_to_update.mu
+    )
+
+    def logprob_only(x):
+        out = logdensity_fn(x)
+        return out[0] if isinstance(out, tuple) else out
+
+    keys = jax.random.split(key_slice, n_update)
+
+    def slice_one_walker(key, x0, logp0, direction):
+        return slice_along_direction(
+            key, x0, logp0, direction, logprob_only, maxsteps, maxiter
+        )
+
+    results = jax.vmap(slice_one_walker)(
+        keys, walkers_to_update.coords, walkers_to_update.log_probs, directions
+    )
+
+    (
+        new_coords,
+        new_log_probs,
+        accepted_array,
+        nexp_array,
+        ncon_array,
+        neval_array,
+    ) = results
+
+    total_nexp = jnp.sum(nexp_array)
+    total_ncon = jnp.sum(ncon_array)
+    total_neval = jnp.sum(neval_array)
+
+    if walkers_to_update.blobs is not None:
+        _, new_blobs = vmapped_logdensity(logdensity_fn, new_coords)
+    else:
+        new_blobs = None
+
+    updated_state = SliceEnsembleState(
+        new_coords,
+        new_log_probs,
+        new_blobs,
+        walkers_to_update.mu,
+        walkers_to_update.tuning_active,
+        walkers_to_update.patience_count,
+    )
+
+    return updated_state, accepted_array, total_nexp, total_ncon, total_neval
+
+
+def as_top_level_api(
+    logdensity_fn: Callable,
+    move: str = "differential",
+    move_fn=None,
+    mu: float = 1.0,
+    has_blobs: bool = False,
+    randomize_split: bool = True,
+    nsplits: int = 2,
+    maxsteps: int = 10000,
+    maxiter: int = 10000,
+    tune: bool = True,
+    patience: int = 5,
+    tolerance: float = 0.05,
+) -> SamplingAlgorithm:
+    """Ensemble slice sampling algorithm.
+
+    Parameters
+    ----------
+    logdensity_fn
+        Function that returns the log density at a given position.
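+        If it returns a ``(log_density, blobs)`` tuple, the second element is
+        stored as blobs in the state.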
+    move
+        Type of move: "differential", "random", or "gaussian". Ignored if
+        move_fn is provided.
+    move_fn
+        Optional custom move function. If None, uses the specified move type.
+    mu
+        Initial value of the scale parameter.
+    has_blobs
+        Whether the logdensity function returns additional information.
+    randomize_split
+        If True, randomly shuffle walker indices before splitting into groups.
+    nsplits
+        Number of groups to split the ensemble into.
+    maxsteps
+        Maximum steps for the slice stepping-out procedure.
+    maxiter
+        Maximum total iterations for the stepping-out and shrinking loops.
+    tune
+        Whether to enable adaptive tuning of mu.
+    patience
+        Number of steps within tolerance before stopping tuning.
+    tolerance
+        Tolerance for the expansion/contraction ratio to stop tuning.
+
+    Returns
+    -------
+    A `SamplingAlgorithm`.
+    """
+    kernel = build_kernel(
+        move=move,
+        move_fn=move_fn,
+        randomize_split=randomize_split,
+        nsplits=nsplits,
+        maxsteps=maxsteps,
+        maxiter=maxiter,
+        tune=tune,
+        patience=patience,
+        tolerance=tolerance,
+    )
+
+    def init_fn(position: ArrayTree, rng_key=None):
+        return init(position, logdensity_fn, has_blobs, mu)
+
+    def step_fn(
+        rng_key: PRNGKey, state
+    ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]:
+        return kernel(rng_key, state, logdensity_fn)
+
+    return SamplingAlgorithm(init_fn, step_fn)
diff --git a/blackjax/ensemble/stretch.py b/blackjax/ensemble/stretch.py
new file mode 100644
index 000000000..cbf7dde9b
--- /dev/null
+++ b/blackjax/ensemble/stretch.py
@@ -0,0 +1,283 @@
+# Copyright 2020- The Blackjax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Stretch move ensemble sampler (affine-invariant MCMC)."""
+from typing import Callable
+
+import jax
+import jax.numpy as jnp
+from jax.flatten_util import ravel_pytree
+
+from blackjax.base import SamplingAlgorithm
+from blackjax.ensemble.base import (
+    EnsembleInfo,
+    EnsembleState,
+    build_states_from_triples,
+    complementary_triple,
+    concat_triple_groups,
+    get_nwalkers,
+    masked_select,
+    prepare_split,
+    unshuffle_triple,
+    vmapped_logdensity,
+)
+from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey
+
+__all__ = [
+    "init",
+    "build_kernel",
+    "as_top_level_api",
+    "stretch_move",
+]
+
+
+def stretch_move(
+    rng_key: PRNGKey,
+    walker_coords: ArrayTree,
+    complementary_coords: ArrayTree,
+    a: float = 2.0,
+) -> tuple[ArrayTree, float]:
+    """Generate a proposal using the affine-invariant stretch move.
+
+    The stretch move selects a random walker from the complementary ensemble
+    and proposes a new position along the line connecting the two walkers,
+    scaled by a random factor z drawn from g(z) ∝ 1/√z on [1/a, a].
+
+    Parameters
+    ----------
+    rng_key
+        A PRNG key for random number generation.
+    walker_coords
+        The current walker's coordinates as an array or PyTree.
+    complementary_coords
+        The coordinates of the complementary ensemble with shape (n_walkers, ...)
+        where the leading dimension indexes walkers.
+    a
+        The stretch scale parameter. Must be > 1. Default is 2.0.
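+        Larger values propose more aggressive stretches, which typically
+        lowers the acceptance rate.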
+
+    Returns
+    -------
+    A tuple (proposal, log_hastings_ratio) where proposal is the proposed
+    position with the same structure as walker_coords, and log_hastings_ratio
+    is (ndim - 1) * log(z).
+    """
+    key_select, key_stretch = jax.random.split(rng_key)
+
+    walker_flat, unravel_fn = ravel_pytree(walker_coords)
+
+    comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords)
+    n_walkers_comp = comp_leaves[0].shape[0]
+
+    idx = jax.random.randint(key_select, (), 0, n_walkers_comp)
+    complementary_walker = jax.tree_util.tree_map(
+        lambda x: x[idx], complementary_coords
+    )
+
+    complementary_walker_flat, _ = ravel_pytree(complementary_walker)
+
+    z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a
+
+    proposal_flat = complementary_walker_flat + z * (
+        walker_flat - complementary_walker_flat
+    )
+
+    n_dims = walker_flat.shape[0]
+    log_hastings_ratio = (n_dims - 1.0) * jnp.log(z)
+
+    return unravel_fn(proposal_flat), log_hastings_ratio
+
+
+def build_kernel(
+    move_fn=None, a: float = 2.0, randomize_split: bool = True, nsplits: int = 2
+) -> Callable:
+    """Build the stretch move kernel.
+
+    Parameters
+    ----------
+    move_fn
+        Optional custom move function. If None, uses stretch_move with
+        parameter a.
+    a
+        The stretch parameter. Must be > 1. Default is 2.0.
+    randomize_split
+        If True, randomly shuffle walker indices before splitting into groups
+        each iteration. This improves mixing and matches emcee's default behavior.
+    nsplits
+        Number of groups to split the ensemble into. Default is 2 (red-blue).
+    """
+    if move_fn is None:
+        move_fn = lambda key, w, c: stretch_move(key, w, c, a)
+
+    def kernel(
+        rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable
+    ) -> tuple[EnsembleState, EnsembleInfo]:
+        key_update, group_triples, indices = prepare_split(
+            rng_key,
+            state.coords,
+            state.log_probs,
+            state.blobs,
+            randomize_split,
+            nsplits,
+        )
+        groups = build_states_from_triples(group_triples, EnsembleState)
+
+        updated_groups = list(groups)
+        accepted_groups = []
+
+        keys = jax.random.split(key_update, nsplits)
+        for i in range(nsplits):
+            comp_triple = complementary_triple(
+                [(g.coords, g.log_probs, g.blobs) for g in updated_groups], i
+            )
+            complementary = EnsembleState(*comp_triple)
+
+            updated_group, accepted = _update_half(
+                keys[i], groups[i], complementary, logdensity_fn, move_fn
+            )
+            updated_groups[i] = updated_group
+            accepted_groups.append(accepted)
+
+        coords_cat, logp_cat, blobs_cat = concat_triple_groups(
+            [(g.coords, g.log_probs, g.blobs) for g in updated_groups]
+        )
+        accepted_cat = jnp.concatenate(accepted_groups, axis=0)
+
+        if randomize_split:
+            new_coords, new_log_probs, new_blobs = unshuffle_triple(
+                coords_cat, logp_cat, blobs_cat, indices
+            )
+            accepted = accepted_cat[jnp.argsort(indices)]
+        else:
+            new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat
+            accepted = accepted_cat
+
+        new_state = EnsembleState(new_coords, new_log_probs, new_blobs)
+        acceptance_rate = jnp.mean(accepted)
+        info = EnsembleInfo(acceptance_rate, accepted)
+
+        return new_state, info
+
+    return kernel
+
+
+def _update_half(
+    rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn
+):
+    """Helper to update one half of the ensemble."""
+    n_update = get_nwalkers(walkers_to_update.coords)
+
+    key_moves, key_accept = jax.random.split(rng_key)
+    keys = jax.random.split(key_moves, n_update)
+
+    proposals, log_hastings_ratios = jax.vmap(
+        lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords)
+    )(keys, walkers_to_update.coords)
+
+    log_probs_proposal, blobs_proposal = vmapped_logdensity(logdensity_fn, proposals)
+
+    log_p_accept = (
+        log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs
+    )
+
+    is_curr_fin = jnp.isfinite(walkers_to_update.log_probs)
+    is_prop_fin = jnp.isfinite(log_probs_proposal)
+    log_p_accept = jnp.where(
+        ~is_curr_fin & is_prop_fin,
+        jnp.inf,
+        jnp.where(
+            is_curr_fin & ~is_prop_fin,
+            -jnp.inf,
+            jnp.where(~is_curr_fin & ~is_prop_fin, -jnp.inf, log_p_accept),
+        ),
+    )
+
+    u = jax.random.uniform(key_accept, shape=(n_update,))
+    accepted = jnp.log(u) < log_p_accept
+
+    new_coords = jax.tree_util.tree_map(
+        lambda prop, old: masked_select(accepted, prop, old),
+        proposals,
+        walkers_to_update.coords,
+    )
+    new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs)
+
+    if walkers_to_update.blobs is not None:
+        new_blobs = jax.tree_util.tree_map(
+            lambda prop, old: masked_select(accepted, prop, old),
+            blobs_proposal,
+            walkers_to_update.blobs,
+        )
+    else:
+        new_blobs = None
+
+    new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs)
+    return new_walkers, accepted
+
+
+def init(
+    position: ArrayLikeTree,
+    logdensity_fn: Callable,
+    has_blobs: bool = False,
+) -> EnsembleState:
+    """Initialize the stretch move algorithm.
+
+    Parameters
+    ----------
+    position
+        Initial positions for all walkers, with shape (n_walkers, ...).
+    logdensity_fn
+        The log-density function to evaluate.
+    has_blobs
+        Whether the log-density function returns additional metadata (blobs).
+    """
+    log_probs, blobs = vmapped_logdensity(logdensity_fn, position)
+    return EnsembleState(position, log_probs, blobs)
+
+
+def as_top_level_api(
+    logdensity_fn: Callable,
+    a: float = 2.0,
+    has_blobs: bool = False,
+    randomize_split: bool = True,
+    nsplits: int = 2,
+) -> SamplingAlgorithm:
+    """A user-facing API for the stretch move algorithm.
+
+    Parameters
+    ----------
+    logdensity_fn
+        A function that returns the log density of the model at a given position.
+    a
+        The stretch parameter. Must be > 1. Default is 2.0.
+    has_blobs
+        Whether the logdensity function returns additional information (blobs).
+    randomize_split
+        If True, randomly shuffle walker indices before splitting into groups
+        each iteration. This improves mixing and matches emcee's default behavior.
+    nsplits
+        Number of groups to split the ensemble into. Default is 2 (red-blue).
+        Each group is updated sequentially using all other groups as complementary.
+
+    Returns
+    -------
+    A `SamplingAlgorithm` that can be used to sample from the target distribution.
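+
+    Examples
+    --------
+    A minimal sketch, assuming a standard Gaussian target and 10 walkers:
+
+    .. code::
+
+        import jax
+        import blackjax
+
+        logdensity_fn = lambda x: -0.5 * (x**2).sum()
+        walkers = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
+
+        ensemble = blackjax.ensemble(logdensity_fn, a=2.0)
+        state = ensemble.init(walkers)
+        state, info = jax.jit(ensemble.step)(jax.random.PRNGKey(1), state)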
+    """
+    move_fn = lambda key, w, c: stretch_move(key, w, c, a)
+    kernel = build_kernel(move_fn, randomize_split=randomize_split, nsplits=nsplits)
+
+    def init_fn(position: ArrayTree, rng_key=None):
+        return init(position, logdensity_fn, has_blobs)
+
+    def step_fn(rng_key: PRNGKey, state) -> tuple[EnsembleState, EnsembleInfo]:
+        return kernel(rng_key, state, logdensity_fn)
+
+    return SamplingAlgorithm(init_fn, step_fn)
diff --git a/tests/ensemble/test_slice.py b/tests/ensemble/test_slice.py
new file mode 100644
index 000000000..3addc9517
--- /dev/null
+++ b/tests/ensemble/test_slice.py
@@ -0,0 +1,302 @@
+"""Test the ensemble slice sampling kernel."""
+
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+from absl.testing import absltest
+
+from blackjax.ensemble.slice import (
+    SliceEnsembleInfo,
+    SliceEnsembleState,
+    as_top_level_api,
+    differential_direction,
+    gaussian_direction,
+    init,
+    random_direction,
+    slice_along_direction,
+)
+
+
+class EnsembleSliceTest(chex.TestCase):
+    """Test the ensemble slice sampling algorithm."""
+
+    def test_differential_direction(self):
+        """Test that differential_direction produces valid directions."""
+        rng_key = jax.random.PRNGKey(0)
+
+        # Complementary ensemble
+        complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]])
+        n_update = 2
+        mu = 1.0
+
+        directions, tune_once = differential_direction(
+            rng_key, complementary_coords, n_update, mu
+        )
+
+        # Check shape and properties
+        self.assertEqual(directions.shape, (n_update, 2))
+        self.assertTrue(tune_once)
+        self.assertTrue(jnp.isfinite(directions).all())
+
+    def test_random_direction(self):
+        """Test that random_direction produces valid directions."""
+        rng_key = jax.random.PRNGKey(0)
+
+        # Template coordinates
+        template_coords = jnp.array([[1.0, 2.0], [3.0, 4.0]])
+        n_update = 2
+        mu = 1.0
+
+        directions, tune_once = random_direction(
+            rng_key, template_coords, n_update, mu
+        )
+
+        # Check shape and properties
+        self.assertEqual(directions.shape, template_coords.shape)
+        self.assertTrue(tune_once)
+        self.assertTrue(jnp.isfinite(directions).all())
+
+    def test_slice_along_direction_1d(self):
+        """Test slice sampling along a direction in 1D."""
+        rng_key = jax.random.PRNGKey(42)
+
+        # Simple 1D Gaussian
+        def logprob_fn(x):
+            return stats.norm.logpdf(x, 0.0, 1.0)
+
+        x0 = 0.0
+        logp0 = logprob_fn(x0)
+        direction = 1.0
+
+        x1, logp1, accepted, nexp, ncon, neval = slice_along_direction(
+            rng_key, x0, logp0, direction, logprob_fn, maxsteps=100, maxiter=1000
+        )
+
+        # Check that we got valid results
+        self.assertTrue(jnp.isfinite(x1))
+        self.assertTrue(jnp.isfinite(logp1))
+        self.assertTrue(accepted)  # Should have found a valid point
+        self.assertGreater(neval, 0)
+
+    def test_init(self):
+        """Test initialization of the ensemble slice sampler."""
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x, 0.0, 1.0).sum()
+
+        rng_key = jax.random.PRNGKey(0)
+        n_walkers = 10
+        initial_position = jax.random.normal(rng_key, (n_walkers, 2))
+
+        state = init(initial_position, logdensity_fn, has_blobs=False, mu=1.0)
+
+        # Check state properties
+        self.assertIsInstance(state, SliceEnsembleState)
+        self.assertEqual(state.coords.shape, (n_walkers, 2))
+        self.assertEqual(state.log_probs.shape, (n_walkers,))
+        self.assertIsNone(state.blobs)
+        self.assertEqual(state.mu, 1.0)
+        self.assertTrue(state.tuning_active)
+        self.assertEqual(state.patience_count, 0)
+
+    def test_ensemble_slice_1d_gaussian(self):
+        """Test ensemble slice sampling on a 1D Gaussian distribution."""
+
+        # Define the 1D Gaussian target
+        mu_true = 2.0
+        sigma_true = 1.5
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x.squeeze(), mu_true, sigma_true)
+
+        rng_key = jax.random.PRNGKey(123)
+        init_key, sample_key = jax.random.split(rng_key)
+
+        # Initialize with 20 walkers
+        n_walkers = 20
+        initial_position = jax.random.normal(init_key, (n_walkers, 1))
+
+        # Create the algorithm
+        algorithm = as_top_level_api(
+            logdensity_fn, move="differential", mu=1.0, maxsteps=100, maxiter=1000
+        )
+        initial_state = algorithm.init(initial_position)
+
+        # Run a few steps
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, (new_state, info)
+
+        keys = jax.random.split(sample_key, 100)
+        final_state, (states, infos) = jax.lax.scan(run_step, initial_state, keys)
+
+        # Check that we get valid states
+        self.assertIsInstance(final_state, SliceEnsembleState)
+        self.assertEqual(final_state.coords.shape, (n_walkers, 1))
+        self.assertEqual(final_state.log_probs.shape, (n_walkers,))
+
+        # Check info
+        self.assertIsInstance(infos, SliceEnsembleInfo)
+        # Slice sampling always accepts
+        self.assertTrue(jnp.all(infos.acceptance_rate == 1.0))
+        self.assertTrue(jnp.all(infos.is_accepted))
+
+        # Check that evaluations are happening
+        self.assertTrue(jnp.all(infos.nevals > 0))
+
+    def test_ensemble_slice_2d_gaussian(self):
+        """Test ensemble slice sampling on a 2D Gaussian distribution."""
+
+        # Define the 2D Gaussian target
+        mu = jnp.array([1.0, 2.0])
+        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
+
+        def logdensity_fn(x):
+            return stats.multivariate_normal.logpdf(x, mu, cov)
+
+        rng_key = jax.random.PRNGKey(42)
+        init_key, sample_key = jax.random.split(rng_key)
+
+        # Initialize the ensemble
+        n_walkers = 20
+        initial_position = jax.random.normal(init_key, (n_walkers, 2))
+
+        # Create the algorithm
+        algorithm = as_top_level_api(
+            logdensity_fn, move="differential", mu=1.0, maxsteps=100, maxiter=1000
+        )
+        initial_state = algorithm.init(initial_position)
+
+        # Run steps
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, new_state.coords
+
+        keys = jax.random.split(sample_key, 200)
+        final_state, samples = jax.lax.scan(run_step, initial_state, keys)
+
+        # Discard the first half as burn-in
+        samples = samples[100:]  # Shape: (100, n_walkers, 2)
+        samples = samples.reshape(-1, 2)  # Flatten to (100 * n_walkers, 2)
+
+        # Check convergence (loose tolerance for a quick test)
+        sample_mean = jnp.mean(samples, axis=0)
+        self.assertAlmostEqual(sample_mean[0].item(), mu[0], places=0)
+        self.assertAlmostEqual(sample_mean[1].item(), mu[1], places=0)
+
+    def test_jit_compilation(self):
+        """Test that the algorithm can be JIT compiled."""
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x.squeeze(), 0.0, 1.0)
+
+        rng_key = jax.random.PRNGKey(0)
+        n_walkers = 10
+        initial_position = jax.random.normal(rng_key, (n_walkers, 1))
+
+        algorithm = as_top_level_api(logdensity_fn, move="differential")
+        initial_state = algorithm.init(initial_position)
+
+        # JIT compile the step function
+        jitted_step = jax.jit(algorithm.step)
+
+        # Run one step
+        key = jax.random.PRNGKey(1)
+        new_state, info = jitted_step(key, initial_state)
+
+        # Check results are valid
+        self.assertIsInstance(new_state, SliceEnsembleState)
+        self.assertIsInstance(info, SliceEnsembleInfo)
+
+    def test_gaussian_direction_array(self):
+        """Test that gaussian_direction produces valid directions for arrays."""
+        rng_key = jax.random.PRNGKey(42)
+
+        complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]])
+        n_update = 2
+        mu = 1.0
+
+        directions, tune_once = gaussian_direction(
+            rng_key, complementary_coords, n_update, mu
+        )
+
+        self.assertEqual(directions.shape, (n_update, 2))
+        self.assertTrue(tune_once)
+        self.assertTrue(jnp.isfinite(directions).all())
+
+    def test_gaussian_direction_pytree(self):
+        """Test that gaussian_direction works with PyTree coordinates."""
+        rng_key = jax.random.PRNGKey(43)
+
+        complementary_coords = {
+            "x": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
+            "y": jnp.array([[0.5], [1.5], [2.5]]),
+        }
+        n_update = 2
+        mu = 1.0
+
+        directions, tune_once = gaussian_direction(
+            rng_key, complementary_coords, n_update, mu
+        )
+
+        self.assertIsInstance(directions, dict)
+        self.assertEqual(directions["x"].shape, (n_update, 2))
+        self.assertEqual(directions["y"].shape, (n_update, 1))
+        self.assertTrue(tune_once)
+        self.assertTrue(jnp.isfinite(directions["x"]).all())
+        self.assertTrue(jnp.isfinite(directions["y"]).all())
+
+    def test_random_move(self):
+        """Test ensemble slice sampling with the random move."""
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x, 0.0, 1.0).sum()
+
+        rng_key = jax.random.PRNGKey(99)
+        n_walkers = 10
+        initial_position = jax.random.normal(rng_key, (n_walkers, 2))
+
+        algorithm = as_top_level_api(
+            logdensity_fn, move="random", mu=1.0, maxsteps=100, maxiter=1000
+        )
+        initial_state = algorithm.init(initial_position)
+
+        keys = jax.random.split(rng_key, 10)
+
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, info
+
+        final_state, infos = jax.lax.scan(run_step, initial_state, keys)
+
+        self.assertIsInstance(final_state, SliceEnsembleState)
+        self.assertTrue(jnp.all(infos.acceptance_rate == 1.0))
+
+    def test_gaussian_move(self):
+        """Test ensemble slice sampling with the gaussian move."""
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x, 0.0, 1.0).sum()
+
+        rng_key = jax.random.PRNGKey(100)
+        n_walkers = 10
+        initial_position = jax.random.normal(rng_key, (n_walkers, 2))
+
+        algorithm = as_top_level_api(
+            logdensity_fn, move="gaussian", mu=1.0, maxsteps=100, maxiter=1000
+        )
+        initial_state = algorithm.init(initial_position)
+
+        keys = jax.random.split(rng_key, 10)
+
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, info
+
+        final_state, infos = jax.lax.scan(run_step, initial_state, keys)
+
+        self.assertIsInstance(final_state, SliceEnsembleState)
+        self.assertTrue(jnp.all(infos.acceptance_rate >= 0.0))
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/ensemble/test_stretch.py b/tests/ensemble/test_stretch.py
new file mode 100644
index 000000000..4e8166ca2
--- /dev/null
+++ b/tests/ensemble/test_stretch.py
@@ -0,0 +1,138 @@
+"""Test the ensemble MCMC kernels."""
+
+import chex
+import jax
+import jax.numpy as jnp
+import jax.scipy.stats as stats
+from absl.testing import absltest
+
+import blackjax
+from blackjax.ensemble.base import EnsembleState
+from blackjax.ensemble.stretch import stretch_move
+
+
+class EnsembleTest(chex.TestCase):
+    """Test the ensemble MCMC algorithms."""
+
+    def test_stretch_move(self):
+        """Test that stretch_move produces valid proposals."""
+        rng_key = jax.random.PRNGKey(0)
+
+        # Simple 2D case
+        walker_coords = jnp.array([1.0, 2.0])
+        complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]])
+
+        proposal, log_hastings_ratio = stretch_move(
+            rng_key, walker_coords, complementary_coords, a=2.0
+        )
+
+        # Check shapes
+        self.assertEqual(proposal.shape, walker_coords.shape)
+        self.assertEqual(log_hastings_ratio.shape, ())
+
+        # Check that the proposal is finite
+        self.assertTrue(jnp.isfinite(proposal).all())
+        self.assertTrue(jnp.isfinite(log_hastings_ratio))
+
+    def test_stretch_move_pytree(self):
+        """Test that stretch_move works with PyTree structures."""
+        rng_key = jax.random.PRNGKey(0)
+
+        # PyTree case
+        walker_coords = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
+        complementary_coords = {
+            "a": jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]),
+            "b": jnp.array([1.0, 2.0, 3.0]),
+        }
+
+        proposal, log_hastings_ratio = stretch_move(
+            rng_key, walker_coords, complementary_coords, a=2.0
+        )
+
+        # Check structure
+        self.assertEqual(set(proposal.keys()), {"a", "b"})
+        self.assertEqual(proposal["a"].shape, walker_coords["a"].shape)
+        self.assertEqual(proposal["b"].shape, walker_coords["b"].shape)
+        self.assertEqual(log_hastings_ratio.shape, ())
+
+    def test_stretch_algorithm_2d_gaussian(self):
+        """Test the stretch algorithm on a 2D Gaussian distribution."""
+
+        # Define a 2D Gaussian target
+        mu = jnp.array([1.0, 2.0])
+        cov = jnp.array([[1.0, 0.5], [0.5, 2.0]])
+
+        def logdensity_fn(x):
+            return stats.multivariate_normal.logpdf(x, mu, cov)
+
+        # Initialize an ensemble of 20 walkers
+        rng_key = jax.random.PRNGKey(42)
+        init_key, sample_key = jax.random.split(rng_key)
+
+        n_walkers = 20
+        initial_position = jax.random.normal(init_key, (n_walkers, 2))
+
+        # Create the algorithm
+        algorithm = blackjax.ensemble(logdensity_fn, a=2.0)
+        initial_state = algorithm.init(initial_position)
+
+        # Run a few steps
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, (new_state, info)
+
+        keys = jax.random.split(sample_key, 100)
+        final_state, (states, infos) = jax.lax.scan(run_step, initial_state, keys)
+
+        # Check that we get valid states
+        self.assertIsInstance(final_state, EnsembleState)
+        self.assertEqual(final_state.coords.shape, (n_walkers, 2))
+        self.assertEqual(final_state.log_probs.shape, (n_walkers,))
+
+        # Check that the acceptance rate is reasonable
+        mean_acceptance = jnp.mean(infos.acceptance_rate)
+        self.assertGreater(mean_acceptance, 0.1)  # Should accept some proposals
+        self.assertLess(mean_acceptance, 0.9)  # Should reject some proposals
+
+    def test_stretch_algorithm_convergence(self):
+        """Test that the stretch algorithm converges to the correct distribution."""
+
+        # Simple 1D Gaussian
+        mu = 2.0
+        sigma = 1.5
+
+        def logdensity_fn(x):
+            return stats.norm.logpdf(x.squeeze(), mu, sigma)
+
+        rng_key = jax.random.PRNGKey(123)
+        init_key, sample_key = jax.random.split(rng_key)
+
+        n_walkers = 50
+        initial_position = jax.random.normal(init_key, (n_walkers, 1))
+
+        # Run the algorithm
+        algorithm = blackjax.ensemble(logdensity_fn, a=2.0)
+        initial_state = algorithm.init(initial_position)
+
+        def run_step(state, key):
+            new_state, info = algorithm.step(key, state)
+            return new_state, new_state.coords
+
+        keys = jax.random.split(sample_key, 1000)
+        final_state, samples = jax.lax.scan(run_step, initial_state, keys)
+
+        # Discard the first half as burn-in
+        samples = samples[500:]  # Shape: (500, n_walkers, 1)
+        samples = samples.reshape(-1, 1)  # Flatten to (500 * n_walkers, 1)
+
+        # Check convergence
+        sample_mean = jnp.mean(samples)
+        sample_std = jnp.std(samples)
+
+        # Allow for some tolerance due to finite sampling
+        self.assertAlmostEqual(sample_mean.item(), mu, places=1)
+        self.assertAlmostEqual(sample_std.item(), sigma, places=1)
+
+
+if __name__ == "__main__":
+    absltest.main()
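
For reviewers, a minimal end-to-end usage sketch of the two new top-level entry points. The standard-Gaussian target and the walker/step counts below are illustrative and not part of the diff; `blackjax.ensemble` and `blackjax.ensemble_slice` are the names registered in the `blackjax/__init__.py` hunk above.

```python
import jax
import jax.numpy as jnp
import blackjax

# Illustrative target: standard 2D Gaussian (not part of the diff).
logdensity_fn = lambda x: -0.5 * jnp.sum(x**2)

n_walkers = 20
walkers = jax.random.normal(jax.random.PRNGKey(0), (n_walkers, 2))

# Affine-invariant stretch move (emcee-style).
stretch = blackjax.ensemble(logdensity_fn, a=2.0)
state = stretch.init(walkers)
state, info = jax.jit(stretch.step)(jax.random.PRNGKey(1), state)

# Ensemble slice sampling with adaptive tuning of mu (cf. the zeus sampler).
ess_algo = blackjax.ensemble_slice(logdensity_fn, move="differential", mu=1.0)
ess_state = ess_algo.init(walkers)

def one_step(s, key):
    s, step_info = ess_algo.step(key, s)
    return s, step_info.mu  # track the adaptively tuned scale

keys = jax.random.split(jax.random.PRNGKey(2), 200)
ess_state, mu_trace = jax.lax.scan(one_step, ess_state, keys)
```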