From c1b3d8fe6cf3d2aedc881877548dceda3cbc04d4 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Tue, 8 Jul 2025 21:17:58 +0100 Subject: [PATCH 01/14] Implement emcee stretch move ensemble sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit adds a JAX-native implementation of the emcee stretch move ensemble sampler to BlackJax, following the affine-invariant ensemble sampling algorithm described in Goodman & Weare (2010). Key features: - Full JAX/JIT compatibility with vectorized operations - Stateless functional design consistent with BlackJax architecture - Red-blue ensemble update strategy for efficient parallel sampling - Support for PyTree parameter structures - Comprehensive test suite with convergence validation New modules: - blackjax/mcmc/ensemble.py: Core ensemble sampling infrastructure - blackjax/mcmc/stretch.py: User-facing API for stretch move algorithm - tests/mcmc/test_ensemble.py: Test suite for ensemble algorithms - PLAN.md: Detailed implementation plan and technical analysis - REVIEW.md: Comprehensive code review by senior reviewer The implementation enables efficient ensemble MCMC sampling for high-dimensional parameter spaces with minimal hand-tuning. Code review summary: High-quality implementation with no methodological errors. Minor suggestions for API improvements and robustness enhancements. One critical recommendation: add validation test against reference emcee. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- PLAN.md | 376 ++++++++++++++++++++++++++++++++++++ REVIEW.md | 124 ++++++++++++ blackjax/__init__.py | 2 + blackjax/mcmc/__init__.py | 2 + blackjax/mcmc/ensemble.py | 214 ++++++++++++++++++++ blackjax/mcmc/stretch.py | 52 +++++ tests/mcmc/test_ensemble.py | 141 ++++++++++++++ 7 files changed, 911 insertions(+) create mode 100644 PLAN.md create mode 100644 REVIEW.md create mode 100644 blackjax/mcmc/ensemble.py create mode 100644 blackjax/mcmc/stretch.py create mode 100644 tests/mcmc/test_ensemble.py diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 000000000..011cd46fa --- /dev/null +++ b/PLAN.md @@ -0,0 +1,376 @@ +Here is a comprehensive implementation plan for integrating `emcee`'s ensemble sampling algorithms into the BlackJax framework, designed for an experienced contractor. + +*** + +## Implementation Plan: Emcee Ensemble Samplers in BlackJax + +**To:** Contractor +**From:** Technical Architect +**Date:** 2024-05-21 +**Subject:** Detailed plan for implementing emcee's ensemble sampling algorithms within the BlackJax framework. + +### **Project Overview** + +This document outlines the technical plan to implement `emcee`'s affine-invariant ensemble MCMC algorithms, starting with the classic Stretch Move, into the BlackJax library. The goal is to create a pure-JAX, JIT-compatible, and performant version of these algorithms that aligns with BlackJax's functional and stateless design philosophy, as detailed in the provided codebase summaries. + +--- + +### 1. Technical Analysis + +#### 1.1. Architectural Comparison + +A review of the `blackjax_summary.md` and `emcee_summary.md` files reveals the following core architectural differences: + +| Feature | **BlackJax** | **emcee** | +| :--- | :--- | :--- | +| **Programming Paradigm** | Functional, Composable | Object-Oriented, Monolithic | +| **Core Abstraction** | `SamplingAlgorithm(init, step)` | `EnsembleSampler` class | +| **State Management** | Stateless: state is passed explicitly to/from kernels. | Stateful: `EnsembleSampler` instance holds the chain history. | +| **Execution Model** | JAX (compiles to XLA for CPU/GPU/TPU) | NumPy (Python interpreter) | +| **Concurrency** | `jax.vmap` (vectorization), `jax.pmap` (parallelism) | `multiprocessing`, `mpi4py` (via `schwimmbad`) | +| **Primary Data Structure**| JAX PyTree for single chain state (e.g., `HMCState`). | NumPy arrays for the ensemble of walkers (e.g., `(nwalkers, ndim)`). | +| **Typical Kernel Input** | `(rng_key, state)` for a single chain. | The `EnsembleSampler` object holds the current state. | + +#### 1.2. Key Differences & Challenges + +The fundamental challenge stems from `emcee`'s core "stretch move" algorithm, which proposes a new position for a single walker by referencing the positions of the *entire ensemble* of walkers. BlackJax's standard `(rng_key, state)` kernel signature is designed for algorithms where a single chain's state is sufficient for the next transition (e.g., HMC, RWMH). + +Directly porting `emcee`'s object-oriented design is incompatible with BlackJax's stateless, functional paradigm. + +#### 1.3. Compatibility Assessment & Bridge + +We can bridge this gap by redefining the `state` object that the BlackJax kernel operates on. Instead of representing the state of a single walker, the `state` will represent the state of the **entire ensemble**. + +- **BlackJax `state` for Ensemble Methods**: The `state` object passed to the kernel will be a PyTree containing the coordinates, log-probabilities, and any metadata for all walkers in the ensemble. +- **Kernel Signature**: The kernel's signature will be `kernel(rng_key, ensemble_state) -> (new_ensemble_state, info)`. +- **Internal Logic**: Inside the kernel, we will implement the "parallel stretch move" described in Algorithm 3 of the `emcee` paper (`1202.3665.tex`). This involves splitting the ensemble into two sets ("red" and "blue") and updating each set in parallel using the other as a reference. This structure is perfectly suited for `jax.vmap`, enabling efficient, vectorized execution on accelerators. + +This approach preserves BlackJax's stateless nature while providing the kernel with the necessary information (the full ensemble) to execute `emcee`-style moves. + +--- + +### 2. Implementation Strategy + +Our strategy is to create a new family of MCMC algorithms under `blackjax.mcmc.ensemble`. This will encapsulate the logic for ensemble-based samplers. We will begin with `emcee`'s flagship Stretch Move. + +1. **Define Ensemble State**: We will introduce a new `EnsembleState` `NamedTuple` to represent the state of all walkers. It will contain `coords` (shape `(n_walkers, n_dims)`), `log_probs` (shape `(n_walkers,)`), and optionally `blobs`. + +2. **Stateless Moves**: `emcee`'s `Move` classes (e.g., `emcee.moves.StretchMove`) will be reimplemented as stateless JAX functions. For instance, the `StretchMove` will become a function `stretch_move(rng_key, walker_coords, complementary_ensemble_coords, a)`. + +3. **Vectorized Kernel**: The main kernel will implement the red-blue split strategy from the `emcee` paper. It will iterate twice (once for each color), and in each iteration, it will use `jax.vmap` to efficiently apply the stateless move function to all walkers in the current split, providing the complementary ensemble as an argument. + +4. **Top-Level API**: We will follow BlackJax's factory pattern (`as_top_level_api`) to expose a user-friendly API, e.g., `blackjax.stretch(...)`, which will be constructed similarly to `blackjax.nuts` and `blackjax.hmc`. + +--- + +### 3. Detailed Work Breakdown + +This section provides a specific, ordered list of tasks for the contractor. + +#### **Task 0: Project Setup** + +1. Fork the `blackjax-devs/blackjax` repository on GitHub. +2. Create a new feature branch, e.g., `feature/ensemble-samplers`. +3. Set up the development environment by running the commands in `blackjax/CLAUDE.md`: + ```bash + pip install -r requirements.txt + pip install -e . + pre-commit install + ``` + +#### **Task 1: Core Data Structures and Module** + +1. Create a new file: `blackjax/mcmc/ensemble.py`. +2. In this file, define the core data structures for ensemble methods, consistent with `blackjax/types.py`. + + ```python + # In blackjax/mcmc/ensemble.py + from typing import Callable, NamedTuple, Optional + from blackjax.types import Array, ArrayTree + + class EnsembleState(NamedTuple): + """State of an ensemble sampler. + + coords + An array or PyTree of arrays of shape `(n_walkers, ...)` that + stores the current position of the walkers. + log_probs + An array of shape `(n_walkers,)` that stores the log-probability of + each walker. + blobs + An optional PyTree that stores metadata returned by the log-probability + function. + """ + coords: ArrayTree + log_probs: Array + blobs: Optional[ArrayTree] = None + + + class EnsembleInfo(NamedTuple): + """Additional information on the ensemble transition. + + acceptance_rate + The acceptance rate of the ensemble. + accepted + A boolean array of shape `(n_walkers,)` indicating whether each walker's + proposal was accepted. + """ + acceptance_rate: Array + accepted: Array + ``` + +#### **Task 2: Implement the Stretch Move** + +1. In `blackjax/mcmc/ensemble.py`, implement the stretch move as a pure JAX function, following `emcee.moves.stretch.StretchMove` and Eq. 10 in `1202.3665.tex`. + + ```python + # In blackjax/mcmc/ensemble.py + import jax + import jax.numpy as jnp + from jax.flatten_util import ravel_pytree + from blackjax.types import PRNGKey, ArrayTree + + def stretch_move( + rng_key: PRNGKey, + walker_coords: ArrayTree, + complementary_coords: ArrayTree, + a: float = 2.0, + ) -> tuple[ArrayTree, float]: + """The emcee stretch move. + + A proposal is generated by selecting a random walker from the complementary + ensemble and moving the current walker along the line connecting the two. + """ + key_select, key_stretch = jax.random.split(rng_key) + + # Ravel coordinates to handle PyTrees + walker_flat, unravel_fn = ravel_pytree(walker_coords) + comp_flat, _ = ravel_pytree(complementary_coords) + + n_walkers_comp, n_dims = comp_flat.shape + + # Select a random walker from the complementary ensemble + idx = jax.random.randint(key_select, (), 0, n_walkers_comp) + complementary_walker_flat = comp_flat[idx] + + # Generate the stretch factor `Z` from g(z) + z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a + + # Generate the proposal (Eq. 10) + proposal_flat = complementary_walker_flat + z * (walker_flat - complementary_walker_flat) + + # The log of the Hastings ratio (Eq. 11) + log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) + + return unravel_fn(proposal_flat), log_hastings_ratio + ``` + +#### **Task 3: Build the Ensemble Kernel** + +1. In `blackjax/mcmc/ensemble.py`, implement the `build_kernel` function. This will orchestrate the red-blue split (Algorithm 3 in the paper) and apply the move using `jax.vmap`. + + ```python + # In blackjax/mcmc/ensemble.py + from blackjax.base import SamplingAlgorithm + + def build_kernel(move_fn: Callable) -> Callable: + """Builds a generic ensemble MCMC kernel.""" + + def kernel( + rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable + ) -> tuple[EnsembleState, EnsembleInfo]: + + n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape + half_n = n_walkers // 2 + + # Red-Blue Split + walkers_red = jax.tree.map(lambda x: x[:half_n], state) + walkers_blue = jax.tree.map(lambda x: x[half_n:], state) + + # Update Red walkers using Blue as complementary + key_red, key_blue = jax.random.split(rng_key) + new_walkers_red, accepted_red = _update_half(key_red, walkers_red, walkers_blue, logdensity_fn, move_fn) + + # Update Blue walkers using updated Red as complementary + new_walkers_blue, accepted_blue = _update_half(key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn) + + # Combine back + new_coords = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.coords, new_walkers_blue.coords) + new_log_probs = jnp.concatenate([new_walkers_red.log_probs, new_walkers_blue.log_probs]) + + if state.blobs is not None: + new_blobs = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.blobs, new_walkers_blue.blobs) + else: + new_blobs = None + + new_state = EnsembleState(new_coords, new_log_probs, new_blobs) + accepted = jnp.concatenate([accepted_red, accepted_blue]) + acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) + info = EnsembleInfo(acceptance_rate, accepted) + + return new_state, info + + return kernel + + def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn): + """Helper to update one half of the ensemble.""" + n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + keys = jax.random.split(rng_key, n_update) + + # Vectorize the move over the walkers to be updated + proposals, log_hastings_ratios = jax.vmap( + lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) + )(keys, walkers_to_update.coords) + + # Compute log-probabilities for proposals + log_probs_proposal, blobs_proposal = jax.vmap(logdensity_fn)(proposals) + + # MH accept/reject step (Eq. 11) + log_p_accept = log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs + + # To avoid -inf - (-inf) = NaN, replace -inf with a large negative number. + log_p_accept = jnp.where(jnp.isneginf(walkers_to_update.log_probs), -jnp.inf, log_p_accept) + + u = jax.random.uniform(rng_key, shape=(n_update,)) + accepted = jnp.log(u) < log_p_accept + + # Build the new state for the half + new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), proposals, walkers_to_update.coords) + new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) + + if walkers_to_update.blobs is not None: + new_blobs = jax.tree.map( + lambda prop, old: jnp.where(accepted, prop, old), + blobs_proposal, + walkers_to_update.blobs, + ) + else: + new_blobs = None + + new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) + return new_walkers, accepted + ``` + +#### **Task 4: Create Top-Level API** + +1. In `blackjax/mcmc/ensemble.py`, create the factory function `as_top_level_api` and the `init` function. + + ```python + # In blackjax/mcmc/ensemble.py + + def init(position: ArrayTree, logdensity_fn: Callable, has_blobs: bool = False) -> EnsembleState: + """Initializes the ensemble.""" + if has_blobs: + log_probs, blobs = jax.vmap(logdensity_fn)(position) + return EnsembleState(position, log_probs, blobs) + else: + log_probs = jax.vmap(logdensity_fn)(position) + return EnsembleState(position, log_probs, None) + + + def as_top_level_api( + logdensity_fn: Callable, move_fn: Callable, has_blobs: bool = False + ) -> SamplingAlgorithm: + """Implements the user-facing API for ensemble samplers.""" + kernel = build_kernel(move_fn) + + def init_fn(position: ArrayTree, rng_key=None): + return init(position, logdensity_fn, has_blobs) + + def step_fn(rng_key: PRNGKey, state: EnsembleState): + return kernel(rng_key, state, logdensity_fn) + + return SamplingAlgorithm(init_fn, step_fn) + ``` + +2. Create a new file `blackjax/mcmc/stretch.py` for the stretch move API. + + ```python + # In blackjax/mcmc/stretch.py + from typing import Callable + from blackjax.base import SamplingAlgorithm + from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api, stretch_move, init + + def as_top_level_api(logdensity_fn: Callable, a: float = 2.0, has_blobs: bool = False) -> SamplingAlgorithm: + """A user-facing API for the stretch move algorithm.""" + move = lambda key, w, c: stretch_move(key, w, c, a) + return ensemble_api(logdensity_fn, move, has_blobs) + ``` + +3. Integrate the new algorithm into the BlackJax public API. + + - In `blackjax/mcmc/__init__.py`, add: + ```python + from . import stretch + + __all__ = [ + # ... existing algorithms + "stretch", + ] + ``` + + - In `blackjax/__init__.py`, add: + ```python + from .mcmc import stretch as _stretch + + # After other algorithm definitions + stretch = generate_top_level_api_from(_stretch) + ``` + +--- + +### 4. Integration Points + +The new algorithm will be integrated as follows: + +- **`blackjax/mcmc/ensemble.py`**: [NEW] Contains the core `EnsembleState`, `EnsembleInfo`, `build_kernel`, and `_update_half` logic for all ensemble methods. +- **`blackjax/mcmc/stretch.py`**: [NEW] Contains the user-facing API factory for the stretch move, `blackjax.stretch`. It will import `stretch_move` and `as_top_level_api` from `ensemble.py` and specialize them. +- **`blackjax/mcmc/__init__.py`**: [MODIFY] To expose the `stretch` module. +- **`blackjax/__init__.py`**: [MODIFY] To register `blackjax.stretch` as a top-level algorithm. +- **No changes required** to `blackjax/base.py` or `blackjax/types.py`. + +--- + +### 5. Testing Strategy + +Thorough testing is critical to ensure correctness. + +1. **Unit Tests for `stretch_move`**: + - Create `tests/mcmc/test_ensemble.py`. + - Write a test for the `stretch_move` function to verify its output shape and statistical properties on a simple distribution. Ensure it works with PyTrees. + +2. **Integration Test for `blackjax.stretch`**: + - In `tests/mcmc/test_ensemble.py`, create a test that runs the full `blackjax.stretch` sampler. + - **Validation against `emcee`**: The most important test. + - Define a simple target distribution (e.g., a 2D Gaussian). + - Seed both `emcee` and `blackjax.stretch` with the same initial ensemble and the same random seed (requires careful management of JAX's PRNGKey vs NumPy's global state). + - Run both samplers for a small number of steps. + - Assert that the sequence of accepted positions and log-probabilities are identical (or `allclose`). This will prove the correctness of the implementation. + +3. **Convergence Test**: + - Add a new test case to `tests/mcmc/test_sampling.py` for `blackjax.stretch`. + - Use the existing `LinearRegressionTest` or `UnivariateNormalTest` framework. + - Run the sampler for a sufficient number of steps and verify that the posterior mean and variance match the true values within a given tolerance. + +--- + +### 6. Performance Considerations + +1. **JIT Compilation**: The main `step` function returned by `blackjax.stretch` must be JIT-compilable. All functions within the call stack (`kernel`, `_update_half`, `stretch_move`) are designed with this in mind. +2. **Vectorization**: The use of `jax.vmap` in `_update_half` is the key to performance. It ensures that the proposal generation and log-density evaluation for each half of the ensemble are vectorized, which is highly efficient on GPUs and TPUs. +3. **Parallel Chains (`pmap`)**: The final implementation will be a pure function and thus fully compatible with `jax.pmap`. This allows users to run multiple independent ensembles in parallel across different devices, a significant advantage over `emcee`'s `multiprocessing` backend. + +--- + +### 7. Documentation Requirements + +1. **API Documentation**: + - Add `blackjax.stretch` to the API reference section of the documentation. + - Ensure `autoapi` in `docs/conf.py` picks up the docstrings for `blackjax.mcmc.stretch.as_top_level_api` and the `EnsembleState`/`EnsembleInfo` `NamedTuple`s. + +2. **User Guide**: + - Create a new example notebook/Markdown file in `docs/examples/`, named `howto_use_ensemble_samplers.md`. + - This guide should demonstrate how to use `blackjax.stretch`, explaining the ensemble-based approach, how to initialize the walkers, and how to handle PyTree states. It should be similar in style to `quickstart.md`. + - Update `docs/index.md` to link to this new "How-to" guide. + +3. **Future Work**: Once other moves (`DE`, `Walk`, etc.) are implemented, this documentation should be expanded to cover them. \ No newline at end of file diff --git a/REVIEW.md b/REVIEW.md new file mode 100644 index 000000000..63ee7605f --- /dev/null +++ b/REVIEW.md @@ -0,0 +1,124 @@ +Of course. As a senior code reviewer, here is a thorough analysis of the provided code changes for implementing the `emcee` stretch move in BlackJax. + +*** + +### Code Review: Implementation of Emcee Stretch Move in BlackJax + +**Overall Assessment:** + +This is a high-quality, well-executed implementation that successfully translates the `emcee` ensemble sampling logic into BlackJax's functional, stateless paradigm. The code is clean, idiomatic JAX, and integrates seamlessly with the existing library structure. The developer has correctly identified the core architectural challenges and implemented a robust solution using `vmap` for vectorization and a re-definition of the kernel `state` to represent the ensemble. + +The implementation is nearly ready for merging, pending one critical addition to the test suite and a few minor refinements for robustness and API elegance. + +--- + +### 1. Methodological Errors + +My analysis shows **no significant methodological errors**. The implementation correctly follows the `emcee` paper (Goodman & Weare, 2010) and the established red-blue update strategy. + +- **Stretch Move Algorithm**: The proposal generation in `stretch_move` correctly implements Eq. 10 from the paper: `proposal = Z * X_j + (1 - Z) * X_k` which is equivalent to the implemented `proposal = X_k + Z * (X_j - X_k)`. +- **Hastings Ratio**: The log of the Hastings ratio `(n_dims - 1) * jnp.log(z)` is correct as per Eq. 11. +- **Red-Blue Update**: The `build_kernel` function correctly implements the parallel update strategy (Algorithm 3 in the paper) by splitting the ensemble and updating one half using the other as the complementary set. +- **Acceptance/Rejection Logic**: The Metropolis-Hastings acceptance probability `log_p_accept = log_hastings_ratio + log_probs_proposal - walkers_to_update.log_probs` is correct. The handling of `-inf` values to prevent `NaN` is a thoughtful and crucial detail. + +The core algorithm is sound. + +--- + +### 2. JAX-specific Issues & Suggestions + +The implementation makes excellent use of JAX's features. The use of `vmap` is appropriate and key to performance. The following are minor points for improvement and future-proofing. + +- **[Minor] Brittle PyTree Shape Inference in `stretch_move`** + - **File**: `blackjax/mcmc/ensemble.py`, line 42 + - **Code**: `n_walkers_comp = comp_leaves[0].shape[0]` + - **Issue**: This assumes that all leaves in the `complementary_coords` PyTree have the same leading dimension. While this is true for this specific use case, it could break if a user constructs an unusual PyTree. + - **Suggestion**: A more robust pattern would be to validate this assumption. Since this is on a performance-critical path, a `chex.assert_equal_shape_prefix` in a test or a debug-mode-only assert would be appropriate. For now, this is acceptable, but worth noting. + +- **[Improvement] Broadcasting in `_update_half` for PyTree Leaves** + - **File**: `blackjax/mcmc/ensemble.py`, line 150 + - **Code**: `new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), ...)` + - **Issue**: The use of `accepted[:, None]` correctly broadcasts the `(n_update,)` boolean array for leaves with shape `(n_update, n_dims)`. However, if a leaf in the position PyTree had a more complex shape, e.g., `(n_update, n_dims, n_other)`, this would fail. + - **Suggestion**: To make this more robust for arbitrary PyTree structures, you can reshape `accepted` to match the rank of each leaf. + ```python + # In _update_half, before the jax.tree.map + def where_broad(arr): + # Add new axes to `accepted` to match the rank of the leaf + ndims_to_add = arr.ndim - 1 + reshaped_accepted = jax.lax.broadcast_in_dim( + accepted, arr.shape, broadcast_dimensions=(0,) + ) + return jnp.where(reshaped_accepted, prop, old) + + # Then in the tree_map + new_coords = jax.tree.map( + lambda prop, old: where_broad(prop, old), + proposals, + walkers_to_update.coords + ) + ``` + This is a minor point for future-proofing and the current implementation is correct for the expected use cases. + +--- + +### 3. Code Quality Issues + +The code quality is high. Naming is clear, and the structure is logical. + +- **[Improvement] API Elegance of `has_blobs`** + - **File**: `blackjax/mcmc/stretch.py`, line 16 + - **Code**: `def as_top_level_api(..., has_blobs: bool = False)` + - **Issue**: The `has_blobs` flag requires the user to explicitly state whether their `logdensity_fn` returns extra data. This is slightly out of sync with other BlackJax APIs that often infer this automatically. The `vmap` makes inference tricky, but it's not impossible. + - **Suggestion**: Consider a helper wrapper for the `logdensity_fn` inside `as_top_level_api` that standardizes the output. + ```python + # In as_top_level_api + def wrapped_logdensity_fn(x): + out = logdensity_fn(x) + if isinstance(out, tuple): + return out + return out, None + + # Then the rest of the code can assume the output is always a tuple, + # and the user does not need to pass `has_blobs`. + # This requires adjusting `init` and `_update_half` to remove the `if/else` logic + # and always expect a (log_prob, blob) tuple. + ``` + This would make the API cleaner and more robust to user error. + +- **[Nitpick] PyTree Raveling in `stretch_move`** + - **File**: `blackjax/mcmc/ensemble.py`, line 39 + - **Code**: The logic for raveling the selected complementary walker is inside `stretch_move`. + - **Suggestion**: This is perfectly fine. An alternative, slightly cleaner pattern could be to have `stretch_move` operate on flattened arrays only, and perform the raveling/unraveling in the calling function (`_update_half`). This can sometimes improve modularity but is not a major issue here. + +--- + +### 4. Integration Issues + +The integration with the BlackJax API is excellent. The use of `generate_top_level_api_from` in `blackjax/__init__.py` is exactly right. However, the testing strategy has a significant gap. + +- **[Critical] Missing Validation Test Against `emcee`** + - **File**: `tests/mcmc/test_ensemble.py` + - **Issue**: The test suite includes unit tests and a convergence test, which are great. However, it is missing a direct validation test against the reference `emcee` implementation. Such a test would involve: + 1. Setting up the same model and initial ensemble in both BlackJax and `emcee`. + 2. Carefully managing the random seeds to ensure both samplers make the same random choices. + 3. Running for one or a few steps. + 4. Asserting that the resulting ensemble positions and log-probabilities are identical (or `allclose`). + - **Suggestion**: **This is the most important required change.** A validation test provides a much stronger guarantee of correctness than a convergence test alone. Please add a test case to `test_ensemble.py` that performs this comparison. It will require installing `emcee` as a test dependency. + +- **[Good Practice] Add an `__all__` dunder** + - **File**: `blackjax/mcmc/ensemble.py` + - **Suggestion**: It's good practice to add an `__all__` list to new modules to explicitly define the public API. For `ensemble.py`, it should include `EnsembleState`, `EnsembleInfo`, `stretch_move`, `build_kernel`, `init`, and `as_top_level_api`. + +### Summary of Recommendations + +- **Priority 1 (Blocking):** + 1. **Add Validation Test**: Implement a test in `tests/mcmc/test_ensemble.py` that compares the output of `blackjax.stretch` directly against `emcee` for a fixed seed to ensure the logic is identical. + +- **Priority 2 (Recommended Improvements):** + 1. **Refactor `has_blobs`**: Remove the `has_blobs` flag from the public API by wrapping the user's `logdensity_fn` to standardize its output, making the API more robust and user-friendly. + 2. **Add `__all__`**: Add an `__all__` export list to `blackjax/mcmc/ensemble.py`. + +- **Priority 3 (Minor Suggestions):** + 1. **Robustness**: Consider the suggested improvements for PyTree shape handling in `stretch_move` and `_update_half` for long-term robustness, possibly with `chex` assertions. + +This is an excellent contribution. Once the validation test is added, this implementation can be considered complete and correct. \ No newline at end of file diff --git a/blackjax/__init__.py b/blackjax/__init__.py index ef5eabd79..fb4450b24 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -25,6 +25,7 @@ from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc +from .mcmc import stretch as _stretch from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk from .mcmc.random_walk import ( irmh_as_top_level_api, @@ -119,6 +120,7 @@ def generate_top_level_api_from(module): elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) +stretch = generate_top_level_api_from(_stretch) hmc_family = [hmc, nuts] diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 8acb28274..337968552 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -12,6 +12,7 @@ periodic_orbital, random_walk, rmhmc, + stretch, ) __all__ = [ @@ -28,4 +29,5 @@ "mclmc", "adjusted_mclmc_dynamic", "adjusted_mclmc", + "stretch", ] diff --git a/blackjax/mcmc/ensemble.py b/blackjax/mcmc/ensemble.py new file mode 100644 index 000000000..273d24f37 --- /dev/null +++ b/blackjax/mcmc/ensemble.py @@ -0,0 +1,214 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Core functionality for ensemble MCMC algorithms.""" +from typing import Callable, NamedTuple, Optional + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree + +from blackjax.base import SamplingAlgorithm +from blackjax.types import Array, ArrayTree, PRNGKey + +__all__ = [ + "EnsembleState", + "EnsembleInfo", + "init", + "build_kernel", + "as_top_level_api", + "stretch_move", +] + + +class EnsembleState(NamedTuple): + """State of an ensemble sampler. + + coords + An array or PyTree of arrays of shape `(n_walkers, ...)` that + stores the current position of the walkers. + log_probs + An array of shape `(n_walkers,)` that stores the log-probability of + each walker. + blobs + An optional PyTree that stores metadata returned by the log-probability + function. + """ + coords: ArrayTree + log_probs: Array + blobs: Optional[ArrayTree] = None + + +class EnsembleInfo(NamedTuple): + """Additional information on the ensemble transition. + + acceptance_rate + The acceptance rate of the ensemble. + accepted + A boolean array of shape `(n_walkers,)` indicating whether each walker's + proposal was accepted. + """ + acceptance_rate: Array + accepted: Array + + +def stretch_move( + rng_key: PRNGKey, + walker_coords: ArrayTree, + complementary_coords: ArrayTree, + a: float = 2.0, +) -> tuple[ArrayTree, float]: + """The emcee stretch move. + + A proposal is generated by selecting a random walker from the complementary + ensemble and moving the current walker along the line connecting the two. + """ + key_select, key_stretch = jax.random.split(rng_key) + + # Ravel coordinates to handle PyTrees + walker_flat, unravel_fn = ravel_pytree(walker_coords) + + # Get the shape of the complementary ensemble + # complementary_coords should have shape (n_walkers, ...) where ... matches walker_coords + comp_leaves, comp_treedef = jax.tree_util.tree_flatten(complementary_coords) + n_walkers_comp = comp_leaves[0].shape[0] + + # Select a random walker from the complementary ensemble + idx = jax.random.randint(key_select, (), 0, n_walkers_comp) + complementary_walker = jax.tree.map(lambda x: x[idx], complementary_coords) + + # Ravel the selected complementary walker + complementary_walker_flat, _ = ravel_pytree(complementary_walker) + + # Generate the stretch factor `Z` from g(z) + z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a + + # Generate the proposal (Eq. 10) + proposal_flat = complementary_walker_flat + z * (walker_flat - complementary_walker_flat) + + # The log of the Hastings ratio (Eq. 11) + # Number of dimensions is the length of the flattened walker + n_dims = walker_flat.shape[0] + log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) + + return unravel_fn(proposal_flat), log_hastings_ratio + + +def build_kernel(move_fn: Callable) -> Callable: + """Builds a generic ensemble MCMC kernel.""" + + def kernel( + rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable + ) -> tuple[EnsembleState, EnsembleInfo]: + + n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape + half_n = n_walkers // 2 + + # Red-Blue Split + walkers_red = jax.tree.map(lambda x: x[:half_n], state) + walkers_blue = jax.tree.map(lambda x: x[half_n:], state) + + # Update Red walkers using Blue as complementary + key_red, key_blue = jax.random.split(rng_key) + new_walkers_red, accepted_red = _update_half(key_red, walkers_red, walkers_blue, logdensity_fn, move_fn) + + # Update Blue walkers using updated Red as complementary + new_walkers_blue, accepted_blue = _update_half(key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn) + + # Combine back + new_coords = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.coords, new_walkers_blue.coords) + new_log_probs = jnp.concatenate([new_walkers_red.log_probs, new_walkers_blue.log_probs]) + + if state.blobs is not None: + new_blobs = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.blobs, new_walkers_blue.blobs) + else: + new_blobs = None + + new_state = EnsembleState(new_coords, new_log_probs, new_blobs) + accepted = jnp.concatenate([accepted_red, accepted_blue]) + acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) + info = EnsembleInfo(acceptance_rate, accepted) + + return new_state, info + + return kernel + + +def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn): + """Helper to update one half of the ensemble.""" + n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + keys = jax.random.split(rng_key, n_update) + + # Vectorize the move over the walkers to be updated + proposals, log_hastings_ratios = jax.vmap( + lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) + )(keys, walkers_to_update.coords) + + # Compute log-probabilities for proposals + logdensity_outputs = jax.vmap(logdensity_fn)(proposals) + if isinstance(logdensity_outputs, tuple): + log_probs_proposal, blobs_proposal = logdensity_outputs + else: + log_probs_proposal = logdensity_outputs + blobs_proposal = None + + # MH accept/reject step (Eq. 11) + log_p_accept = log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs + + # To avoid -inf - (-inf) = NaN, replace -inf with a large negative number. + log_p_accept = jnp.where(jnp.isneginf(walkers_to_update.log_probs), -jnp.inf, log_p_accept) + + u = jax.random.uniform(rng_key, shape=(n_update,)) + accepted = jnp.log(u) < log_p_accept + + # Build the new state for the half + new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), proposals, walkers_to_update.coords) + new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) + + if walkers_to_update.blobs is not None: + new_blobs = jax.tree.map( + lambda prop, old: jnp.where(accepted, prop, old), + blobs_proposal, + walkers_to_update.blobs, + ) + else: + new_blobs = None + + new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) + return new_walkers, accepted + + +def init(position: ArrayTree, logdensity_fn: Callable, has_blobs: bool = False) -> EnsembleState: + """Initializes the ensemble.""" + logdensity_outputs = jax.vmap(logdensity_fn)(position) + if isinstance(logdensity_outputs, tuple): + log_probs, blobs = logdensity_outputs + return EnsembleState(position, log_probs, blobs) + else: + log_probs = logdensity_outputs + return EnsembleState(position, log_probs, None) + + +def as_top_level_api( + logdensity_fn: Callable, move_fn: Callable, has_blobs: bool = False +) -> SamplingAlgorithm: + """Implements the user-facing API for ensemble samplers.""" + kernel = build_kernel(move_fn) + + def init_fn(position: ArrayTree, rng_key=None): + return init(position, logdensity_fn, has_blobs) + + def step_fn(rng_key: PRNGKey, state: EnsembleState): + return kernel(rng_key, state, logdensity_fn) + + return SamplingAlgorithm(init_fn, step_fn) \ No newline at end of file diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/stretch.py new file mode 100644 index 000000000..7ab5dec49 --- /dev/null +++ b/blackjax/mcmc/stretch.py @@ -0,0 +1,52 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Public API for the Stretch Move ensemble sampler.""" +from typing import Callable + +from blackjax.base import SamplingAlgorithm +from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api, stretch_move, init as ensemble_init, build_kernel as ensemble_build_kernel + +__all__ = ["as_top_level_api", "init", "build_kernel"] + + +def as_top_level_api(logdensity_fn: Callable, a: float = 2.0, has_blobs: bool = False) -> SamplingAlgorithm: + """A user-facing API for the stretch move algorithm. + + Parameters + ---------- + logdensity_fn + A function that returns the log density of the model at a given position. + a + The stretch parameter. Must be > 1. Default is 2.0. + has_blobs + Whether the logdensity function returns additional information (blobs). + + Returns + ------- + A `SamplingAlgorithm` that can be used to sample from the target distribution. + """ + move = lambda key, w, c: stretch_move(key, w, c, a) + return ensemble_api(logdensity_fn, move, has_blobs) + + +def init(position, logdensity_fn, has_blobs: bool = False): + """Initialize the stretch move algorithm.""" + return ensemble_init(position, logdensity_fn, has_blobs) + + +def build_kernel(move_fn=None, a: float = 2.0): + """Build the stretch move kernel.""" + if move_fn is None: + move_fn = lambda key, w, c: stretch_move(key, w, c, a) + return ensemble_build_kernel(move_fn) \ No newline at end of file diff --git a/tests/mcmc/test_ensemble.py b/tests/mcmc/test_ensemble.py new file mode 100644 index 000000000..17464b180 --- /dev/null +++ b/tests/mcmc/test_ensemble.py @@ -0,0 +1,141 @@ +"""Test the ensemble MCMC kernels.""" + +import functools + +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +import numpy as np +from absl.testing import absltest, parameterized + +import blackjax +from blackjax.mcmc.ensemble import EnsembleState, stretch_move +from blackjax.util import run_inference_algorithm + + +class EnsembleTest(chex.TestCase): + """Test the ensemble MCMC algorithms.""" + + def test_stretch_move(self): + """Test that stretch_move produces valid proposals.""" + rng_key = jax.random.PRNGKey(0) + + # Simple 2D case + walker_coords = jnp.array([1.0, 2.0]) + complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]) + + proposal, log_hastings_ratio = stretch_move( + rng_key, walker_coords, complementary_coords, a=2.0 + ) + + # Check shapes + self.assertEqual(proposal.shape, walker_coords.shape) + self.assertEqual(log_hastings_ratio.shape, ()) + + # Check that proposal is finite + self.assertTrue(jnp.isfinite(proposal).all()) + self.assertTrue(jnp.isfinite(log_hastings_ratio)) + + def test_stretch_move_pytree(self): + """Test that stretch_move works with PyTree structures.""" + rng_key = jax.random.PRNGKey(0) + + # PyTree case + walker_coords = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)} + complementary_coords = { + "a": jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]), + "b": jnp.array([1.0, 2.0, 3.0]) + } + + proposal, log_hastings_ratio = stretch_move( + rng_key, walker_coords, complementary_coords, a=2.0 + ) + + # Check structure + self.assertEqual(set(proposal.keys()), {"a", "b"}) + self.assertEqual(proposal["a"].shape, walker_coords["a"].shape) + self.assertEqual(proposal["b"].shape, walker_coords["b"].shape) + self.assertEqual(log_hastings_ratio.shape, ()) + + def test_stretch_algorithm_2d_gaussian(self): + """Test the stretch algorithm on a 2D Gaussian distribution.""" + + # Define a 2D Gaussian target + mu = jnp.array([1.0, 2.0]) + cov = jnp.array([[1.0, 0.5], [0.5, 2.0]]) + + def logdensity_fn(x): + return stats.multivariate_normal.logpdf(x, mu, cov) + + # Initialize ensemble of 20 walkers + rng_key = jax.random.PRNGKey(42) + init_key, sample_key = jax.random.split(rng_key) + + n_walkers = 20 + initial_position = jax.random.normal(init_key, (n_walkers, 2)) + + # Create algorithm + algorithm = blackjax.stretch(logdensity_fn, a=2.0) + initial_state = algorithm.init(initial_position) + + # Run a few steps + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, (new_state, info) + + keys = jax.random.split(sample_key, 100) + final_state, (states, infos) = jax.lax.scan(run_step, initial_state, keys) + + # Check that we get valid states + self.assertIsInstance(final_state, EnsembleState) + self.assertEqual(final_state.coords.shape, (n_walkers, 2)) + self.assertEqual(final_state.log_probs.shape, (n_walkers,)) + + # Check that acceptance rate is reasonable + mean_acceptance = jnp.mean(infos.acceptance_rate) + self.assertGreater(mean_acceptance, 0.1) # Should accept some proposals + self.assertLess(mean_acceptance, 0.9) # Should reject some proposals + + def test_stretch_algorithm_convergence(self): + """Test that the stretch algorithm converges to the correct distribution.""" + + # Simple 1D Gaussian + mu = 2.0 + sigma = 1.5 + + def logdensity_fn(x): + return stats.norm.logpdf(x.squeeze(), mu, sigma) + + rng_key = jax.random.PRNGKey(123) + init_key, sample_key = jax.random.split(rng_key) + + n_walkers = 50 + initial_position = jax.random.normal(init_key, (n_walkers, 1)) + + # Run algorithm + algorithm = blackjax.stretch(logdensity_fn, a=2.0) + initial_state = algorithm.init(initial_position) + + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, new_state.coords + + keys = jax.random.split(sample_key, 1000) + final_state, samples = jax.lax.scan(run_step, initial_state, keys) + + # Take samples from the second half (burn-in) + samples = samples[500:] # Shape: (500, n_walkers, 1) + samples = samples.reshape(-1, 1) # Flatten to (500 * n_walkers, 1) + + # Check convergence + sample_mean = jnp.mean(samples) + sample_std = jnp.std(samples) + + # Allow for some tolerance due to finite sampling + self.assertAlmostEqual(sample_mean.item(), mu, places=1) + self.assertAlmostEqual(sample_std.item(), sigma, places=1) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 526e5f54cb5d114d8e12cee3ec090aa34313b30a Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 21:06:20 +0100 Subject: [PATCH 02/14] Remove PLAN.md and REVIEW.md --- PLAN.md | 376 ------------------------------------------------------ REVIEW.md | 124 ------------------ 2 files changed, 500 deletions(-) delete mode 100644 PLAN.md delete mode 100644 REVIEW.md diff --git a/PLAN.md b/PLAN.md deleted file mode 100644 index 011cd46fa..000000000 --- a/PLAN.md +++ /dev/null @@ -1,376 +0,0 @@ -Here is a comprehensive implementation plan for integrating `emcee`'s ensemble sampling algorithms into the BlackJax framework, designed for an experienced contractor. - -*** - -## Implementation Plan: Emcee Ensemble Samplers in BlackJax - -**To:** Contractor -**From:** Technical Architect -**Date:** 2024-05-21 -**Subject:** Detailed plan for implementing emcee's ensemble sampling algorithms within the BlackJax framework. - -### **Project Overview** - -This document outlines the technical plan to implement `emcee`'s affine-invariant ensemble MCMC algorithms, starting with the classic Stretch Move, into the BlackJax library. The goal is to create a pure-JAX, JIT-compatible, and performant version of these algorithms that aligns with BlackJax's functional and stateless design philosophy, as detailed in the provided codebase summaries. - ---- - -### 1. Technical Analysis - -#### 1.1. Architectural Comparison - -A review of the `blackjax_summary.md` and `emcee_summary.md` files reveals the following core architectural differences: - -| Feature | **BlackJax** | **emcee** | -| :--- | :--- | :--- | -| **Programming Paradigm** | Functional, Composable | Object-Oriented, Monolithic | -| **Core Abstraction** | `SamplingAlgorithm(init, step)` | `EnsembleSampler` class | -| **State Management** | Stateless: state is passed explicitly to/from kernels. | Stateful: `EnsembleSampler` instance holds the chain history. | -| **Execution Model** | JAX (compiles to XLA for CPU/GPU/TPU) | NumPy (Python interpreter) | -| **Concurrency** | `jax.vmap` (vectorization), `jax.pmap` (parallelism) | `multiprocessing`, `mpi4py` (via `schwimmbad`) | -| **Primary Data Structure**| JAX PyTree for single chain state (e.g., `HMCState`). | NumPy arrays for the ensemble of walkers (e.g., `(nwalkers, ndim)`). | -| **Typical Kernel Input** | `(rng_key, state)` for a single chain. | The `EnsembleSampler` object holds the current state. | - -#### 1.2. Key Differences & Challenges - -The fundamental challenge stems from `emcee`'s core "stretch move" algorithm, which proposes a new position for a single walker by referencing the positions of the *entire ensemble* of walkers. BlackJax's standard `(rng_key, state)` kernel signature is designed for algorithms where a single chain's state is sufficient for the next transition (e.g., HMC, RWMH). - -Directly porting `emcee`'s object-oriented design is incompatible with BlackJax's stateless, functional paradigm. - -#### 1.3. Compatibility Assessment & Bridge - -We can bridge this gap by redefining the `state` object that the BlackJax kernel operates on. Instead of representing the state of a single walker, the `state` will represent the state of the **entire ensemble**. - -- **BlackJax `state` for Ensemble Methods**: The `state` object passed to the kernel will be a PyTree containing the coordinates, log-probabilities, and any metadata for all walkers in the ensemble. -- **Kernel Signature**: The kernel's signature will be `kernel(rng_key, ensemble_state) -> (new_ensemble_state, info)`. -- **Internal Logic**: Inside the kernel, we will implement the "parallel stretch move" described in Algorithm 3 of the `emcee` paper (`1202.3665.tex`). This involves splitting the ensemble into two sets ("red" and "blue") and updating each set in parallel using the other as a reference. This structure is perfectly suited for `jax.vmap`, enabling efficient, vectorized execution on accelerators. - -This approach preserves BlackJax's stateless nature while providing the kernel with the necessary information (the full ensemble) to execute `emcee`-style moves. - ---- - -### 2. Implementation Strategy - -Our strategy is to create a new family of MCMC algorithms under `blackjax.mcmc.ensemble`. This will encapsulate the logic for ensemble-based samplers. We will begin with `emcee`'s flagship Stretch Move. - -1. **Define Ensemble State**: We will introduce a new `EnsembleState` `NamedTuple` to represent the state of all walkers. It will contain `coords` (shape `(n_walkers, n_dims)`), `log_probs` (shape `(n_walkers,)`), and optionally `blobs`. - -2. **Stateless Moves**: `emcee`'s `Move` classes (e.g., `emcee.moves.StretchMove`) will be reimplemented as stateless JAX functions. For instance, the `StretchMove` will become a function `stretch_move(rng_key, walker_coords, complementary_ensemble_coords, a)`. - -3. **Vectorized Kernel**: The main kernel will implement the red-blue split strategy from the `emcee` paper. It will iterate twice (once for each color), and in each iteration, it will use `jax.vmap` to efficiently apply the stateless move function to all walkers in the current split, providing the complementary ensemble as an argument. - -4. **Top-Level API**: We will follow BlackJax's factory pattern (`as_top_level_api`) to expose a user-friendly API, e.g., `blackjax.stretch(...)`, which will be constructed similarly to `blackjax.nuts` and `blackjax.hmc`. - ---- - -### 3. Detailed Work Breakdown - -This section provides a specific, ordered list of tasks for the contractor. - -#### **Task 0: Project Setup** - -1. Fork the `blackjax-devs/blackjax` repository on GitHub. -2. Create a new feature branch, e.g., `feature/ensemble-samplers`. -3. Set up the development environment by running the commands in `blackjax/CLAUDE.md`: - ```bash - pip install -r requirements.txt - pip install -e . - pre-commit install - ``` - -#### **Task 1: Core Data Structures and Module** - -1. Create a new file: `blackjax/mcmc/ensemble.py`. -2. In this file, define the core data structures for ensemble methods, consistent with `blackjax/types.py`. - - ```python - # In blackjax/mcmc/ensemble.py - from typing import Callable, NamedTuple, Optional - from blackjax.types import Array, ArrayTree - - class EnsembleState(NamedTuple): - """State of an ensemble sampler. - - coords - An array or PyTree of arrays of shape `(n_walkers, ...)` that - stores the current position of the walkers. - log_probs - An array of shape `(n_walkers,)` that stores the log-probability of - each walker. - blobs - An optional PyTree that stores metadata returned by the log-probability - function. - """ - coords: ArrayTree - log_probs: Array - blobs: Optional[ArrayTree] = None - - - class EnsembleInfo(NamedTuple): - """Additional information on the ensemble transition. - - acceptance_rate - The acceptance rate of the ensemble. - accepted - A boolean array of shape `(n_walkers,)` indicating whether each walker's - proposal was accepted. - """ - acceptance_rate: Array - accepted: Array - ``` - -#### **Task 2: Implement the Stretch Move** - -1. In `blackjax/mcmc/ensemble.py`, implement the stretch move as a pure JAX function, following `emcee.moves.stretch.StretchMove` and Eq. 10 in `1202.3665.tex`. - - ```python - # In blackjax/mcmc/ensemble.py - import jax - import jax.numpy as jnp - from jax.flatten_util import ravel_pytree - from blackjax.types import PRNGKey, ArrayTree - - def stretch_move( - rng_key: PRNGKey, - walker_coords: ArrayTree, - complementary_coords: ArrayTree, - a: float = 2.0, - ) -> tuple[ArrayTree, float]: - """The emcee stretch move. - - A proposal is generated by selecting a random walker from the complementary - ensemble and moving the current walker along the line connecting the two. - """ - key_select, key_stretch = jax.random.split(rng_key) - - # Ravel coordinates to handle PyTrees - walker_flat, unravel_fn = ravel_pytree(walker_coords) - comp_flat, _ = ravel_pytree(complementary_coords) - - n_walkers_comp, n_dims = comp_flat.shape - - # Select a random walker from the complementary ensemble - idx = jax.random.randint(key_select, (), 0, n_walkers_comp) - complementary_walker_flat = comp_flat[idx] - - # Generate the stretch factor `Z` from g(z) - z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a - - # Generate the proposal (Eq. 10) - proposal_flat = complementary_walker_flat + z * (walker_flat - complementary_walker_flat) - - # The log of the Hastings ratio (Eq. 11) - log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) - - return unravel_fn(proposal_flat), log_hastings_ratio - ``` - -#### **Task 3: Build the Ensemble Kernel** - -1. In `blackjax/mcmc/ensemble.py`, implement the `build_kernel` function. This will orchestrate the red-blue split (Algorithm 3 in the paper) and apply the move using `jax.vmap`. - - ```python - # In blackjax/mcmc/ensemble.py - from blackjax.base import SamplingAlgorithm - - def build_kernel(move_fn: Callable) -> Callable: - """Builds a generic ensemble MCMC kernel.""" - - def kernel( - rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable - ) -> tuple[EnsembleState, EnsembleInfo]: - - n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - half_n = n_walkers // 2 - - # Red-Blue Split - walkers_red = jax.tree.map(lambda x: x[:half_n], state) - walkers_blue = jax.tree.map(lambda x: x[half_n:], state) - - # Update Red walkers using Blue as complementary - key_red, key_blue = jax.random.split(rng_key) - new_walkers_red, accepted_red = _update_half(key_red, walkers_red, walkers_blue, logdensity_fn, move_fn) - - # Update Blue walkers using updated Red as complementary - new_walkers_blue, accepted_blue = _update_half(key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn) - - # Combine back - new_coords = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.coords, new_walkers_blue.coords) - new_log_probs = jnp.concatenate([new_walkers_red.log_probs, new_walkers_blue.log_probs]) - - if state.blobs is not None: - new_blobs = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.blobs, new_walkers_blue.blobs) - else: - new_blobs = None - - new_state = EnsembleState(new_coords, new_log_probs, new_blobs) - accepted = jnp.concatenate([accepted_red, accepted_blue]) - acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) - info = EnsembleInfo(acceptance_rate, accepted) - - return new_state, info - - return kernel - - def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn): - """Helper to update one half of the ensemble.""" - n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape - keys = jax.random.split(rng_key, n_update) - - # Vectorize the move over the walkers to be updated - proposals, log_hastings_ratios = jax.vmap( - lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) - )(keys, walkers_to_update.coords) - - # Compute log-probabilities for proposals - log_probs_proposal, blobs_proposal = jax.vmap(logdensity_fn)(proposals) - - # MH accept/reject step (Eq. 11) - log_p_accept = log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs - - # To avoid -inf - (-inf) = NaN, replace -inf with a large negative number. - log_p_accept = jnp.where(jnp.isneginf(walkers_to_update.log_probs), -jnp.inf, log_p_accept) - - u = jax.random.uniform(rng_key, shape=(n_update,)) - accepted = jnp.log(u) < log_p_accept - - # Build the new state for the half - new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), proposals, walkers_to_update.coords) - new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) - - if walkers_to_update.blobs is not None: - new_blobs = jax.tree.map( - lambda prop, old: jnp.where(accepted, prop, old), - blobs_proposal, - walkers_to_update.blobs, - ) - else: - new_blobs = None - - new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) - return new_walkers, accepted - ``` - -#### **Task 4: Create Top-Level API** - -1. In `blackjax/mcmc/ensemble.py`, create the factory function `as_top_level_api` and the `init` function. - - ```python - # In blackjax/mcmc/ensemble.py - - def init(position: ArrayTree, logdensity_fn: Callable, has_blobs: bool = False) -> EnsembleState: - """Initializes the ensemble.""" - if has_blobs: - log_probs, blobs = jax.vmap(logdensity_fn)(position) - return EnsembleState(position, log_probs, blobs) - else: - log_probs = jax.vmap(logdensity_fn)(position) - return EnsembleState(position, log_probs, None) - - - def as_top_level_api( - logdensity_fn: Callable, move_fn: Callable, has_blobs: bool = False - ) -> SamplingAlgorithm: - """Implements the user-facing API for ensemble samplers.""" - kernel = build_kernel(move_fn) - - def init_fn(position: ArrayTree, rng_key=None): - return init(position, logdensity_fn, has_blobs) - - def step_fn(rng_key: PRNGKey, state: EnsembleState): - return kernel(rng_key, state, logdensity_fn) - - return SamplingAlgorithm(init_fn, step_fn) - ``` - -2. Create a new file `blackjax/mcmc/stretch.py` for the stretch move API. - - ```python - # In blackjax/mcmc/stretch.py - from typing import Callable - from blackjax.base import SamplingAlgorithm - from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api, stretch_move, init - - def as_top_level_api(logdensity_fn: Callable, a: float = 2.0, has_blobs: bool = False) -> SamplingAlgorithm: - """A user-facing API for the stretch move algorithm.""" - move = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_api(logdensity_fn, move, has_blobs) - ``` - -3. Integrate the new algorithm into the BlackJax public API. - - - In `blackjax/mcmc/__init__.py`, add: - ```python - from . import stretch - - __all__ = [ - # ... existing algorithms - "stretch", - ] - ``` - - - In `blackjax/__init__.py`, add: - ```python - from .mcmc import stretch as _stretch - - # After other algorithm definitions - stretch = generate_top_level_api_from(_stretch) - ``` - ---- - -### 4. Integration Points - -The new algorithm will be integrated as follows: - -- **`blackjax/mcmc/ensemble.py`**: [NEW] Contains the core `EnsembleState`, `EnsembleInfo`, `build_kernel`, and `_update_half` logic for all ensemble methods. -- **`blackjax/mcmc/stretch.py`**: [NEW] Contains the user-facing API factory for the stretch move, `blackjax.stretch`. It will import `stretch_move` and `as_top_level_api` from `ensemble.py` and specialize them. -- **`blackjax/mcmc/__init__.py`**: [MODIFY] To expose the `stretch` module. -- **`blackjax/__init__.py`**: [MODIFY] To register `blackjax.stretch` as a top-level algorithm. -- **No changes required** to `blackjax/base.py` or `blackjax/types.py`. - ---- - -### 5. Testing Strategy - -Thorough testing is critical to ensure correctness. - -1. **Unit Tests for `stretch_move`**: - - Create `tests/mcmc/test_ensemble.py`. - - Write a test for the `stretch_move` function to verify its output shape and statistical properties on a simple distribution. Ensure it works with PyTrees. - -2. **Integration Test for `blackjax.stretch`**: - - In `tests/mcmc/test_ensemble.py`, create a test that runs the full `blackjax.stretch` sampler. - - **Validation against `emcee`**: The most important test. - - Define a simple target distribution (e.g., a 2D Gaussian). - - Seed both `emcee` and `blackjax.stretch` with the same initial ensemble and the same random seed (requires careful management of JAX's PRNGKey vs NumPy's global state). - - Run both samplers for a small number of steps. - - Assert that the sequence of accepted positions and log-probabilities are identical (or `allclose`). This will prove the correctness of the implementation. - -3. **Convergence Test**: - - Add a new test case to `tests/mcmc/test_sampling.py` for `blackjax.stretch`. - - Use the existing `LinearRegressionTest` or `UnivariateNormalTest` framework. - - Run the sampler for a sufficient number of steps and verify that the posterior mean and variance match the true values within a given tolerance. - ---- - -### 6. Performance Considerations - -1. **JIT Compilation**: The main `step` function returned by `blackjax.stretch` must be JIT-compilable. All functions within the call stack (`kernel`, `_update_half`, `stretch_move`) are designed with this in mind. -2. **Vectorization**: The use of `jax.vmap` in `_update_half` is the key to performance. It ensures that the proposal generation and log-density evaluation for each half of the ensemble are vectorized, which is highly efficient on GPUs and TPUs. -3. **Parallel Chains (`pmap`)**: The final implementation will be a pure function and thus fully compatible with `jax.pmap`. This allows users to run multiple independent ensembles in parallel across different devices, a significant advantage over `emcee`'s `multiprocessing` backend. - ---- - -### 7. Documentation Requirements - -1. **API Documentation**: - - Add `blackjax.stretch` to the API reference section of the documentation. - - Ensure `autoapi` in `docs/conf.py` picks up the docstrings for `blackjax.mcmc.stretch.as_top_level_api` and the `EnsembleState`/`EnsembleInfo` `NamedTuple`s. - -2. **User Guide**: - - Create a new example notebook/Markdown file in `docs/examples/`, named `howto_use_ensemble_samplers.md`. - - This guide should demonstrate how to use `blackjax.stretch`, explaining the ensemble-based approach, how to initialize the walkers, and how to handle PyTree states. It should be similar in style to `quickstart.md`. - - Update `docs/index.md` to link to this new "How-to" guide. - -3. **Future Work**: Once other moves (`DE`, `Walk`, etc.) are implemented, this documentation should be expanded to cover them. \ No newline at end of file diff --git a/REVIEW.md b/REVIEW.md deleted file mode 100644 index 63ee7605f..000000000 --- a/REVIEW.md +++ /dev/null @@ -1,124 +0,0 @@ -Of course. As a senior code reviewer, here is a thorough analysis of the provided code changes for implementing the `emcee` stretch move in BlackJax. - -*** - -### Code Review: Implementation of Emcee Stretch Move in BlackJax - -**Overall Assessment:** - -This is a high-quality, well-executed implementation that successfully translates the `emcee` ensemble sampling logic into BlackJax's functional, stateless paradigm. The code is clean, idiomatic JAX, and integrates seamlessly with the existing library structure. The developer has correctly identified the core architectural challenges and implemented a robust solution using `vmap` for vectorization and a re-definition of the kernel `state` to represent the ensemble. - -The implementation is nearly ready for merging, pending one critical addition to the test suite and a few minor refinements for robustness and API elegance. - ---- - -### 1. Methodological Errors - -My analysis shows **no significant methodological errors**. The implementation correctly follows the `emcee` paper (Goodman & Weare, 2010) and the established red-blue update strategy. - -- **Stretch Move Algorithm**: The proposal generation in `stretch_move` correctly implements Eq. 10 from the paper: `proposal = Z * X_j + (1 - Z) * X_k` which is equivalent to the implemented `proposal = X_k + Z * (X_j - X_k)`. -- **Hastings Ratio**: The log of the Hastings ratio `(n_dims - 1) * jnp.log(z)` is correct as per Eq. 11. -- **Red-Blue Update**: The `build_kernel` function correctly implements the parallel update strategy (Algorithm 3 in the paper) by splitting the ensemble and updating one half using the other as the complementary set. -- **Acceptance/Rejection Logic**: The Metropolis-Hastings acceptance probability `log_p_accept = log_hastings_ratio + log_probs_proposal - walkers_to_update.log_probs` is correct. The handling of `-inf` values to prevent `NaN` is a thoughtful and crucial detail. - -The core algorithm is sound. - ---- - -### 2. JAX-specific Issues & Suggestions - -The implementation makes excellent use of JAX's features. The use of `vmap` is appropriate and key to performance. The following are minor points for improvement and future-proofing. - -- **[Minor] Brittle PyTree Shape Inference in `stretch_move`** - - **File**: `blackjax/mcmc/ensemble.py`, line 42 - - **Code**: `n_walkers_comp = comp_leaves[0].shape[0]` - - **Issue**: This assumes that all leaves in the `complementary_coords` PyTree have the same leading dimension. While this is true for this specific use case, it could break if a user constructs an unusual PyTree. - - **Suggestion**: A more robust pattern would be to validate this assumption. Since this is on a performance-critical path, a `chex.assert_equal_shape_prefix` in a test or a debug-mode-only assert would be appropriate. For now, this is acceptable, but worth noting. - -- **[Improvement] Broadcasting in `_update_half` for PyTree Leaves** - - **File**: `blackjax/mcmc/ensemble.py`, line 150 - - **Code**: `new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), ...)` - - **Issue**: The use of `accepted[:, None]` correctly broadcasts the `(n_update,)` boolean array for leaves with shape `(n_update, n_dims)`. However, if a leaf in the position PyTree had a more complex shape, e.g., `(n_update, n_dims, n_other)`, this would fail. - - **Suggestion**: To make this more robust for arbitrary PyTree structures, you can reshape `accepted` to match the rank of each leaf. - ```python - # In _update_half, before the jax.tree.map - def where_broad(arr): - # Add new axes to `accepted` to match the rank of the leaf - ndims_to_add = arr.ndim - 1 - reshaped_accepted = jax.lax.broadcast_in_dim( - accepted, arr.shape, broadcast_dimensions=(0,) - ) - return jnp.where(reshaped_accepted, prop, old) - - # Then in the tree_map - new_coords = jax.tree.map( - lambda prop, old: where_broad(prop, old), - proposals, - walkers_to_update.coords - ) - ``` - This is a minor point for future-proofing and the current implementation is correct for the expected use cases. - ---- - -### 3. Code Quality Issues - -The code quality is high. Naming is clear, and the structure is logical. - -- **[Improvement] API Elegance of `has_blobs`** - - **File**: `blackjax/mcmc/stretch.py`, line 16 - - **Code**: `def as_top_level_api(..., has_blobs: bool = False)` - - **Issue**: The `has_blobs` flag requires the user to explicitly state whether their `logdensity_fn` returns extra data. This is slightly out of sync with other BlackJax APIs that often infer this automatically. The `vmap` makes inference tricky, but it's not impossible. - - **Suggestion**: Consider a helper wrapper for the `logdensity_fn` inside `as_top_level_api` that standardizes the output. - ```python - # In as_top_level_api - def wrapped_logdensity_fn(x): - out = logdensity_fn(x) - if isinstance(out, tuple): - return out - return out, None - - # Then the rest of the code can assume the output is always a tuple, - # and the user does not need to pass `has_blobs`. - # This requires adjusting `init` and `_update_half` to remove the `if/else` logic - # and always expect a (log_prob, blob) tuple. - ``` - This would make the API cleaner and more robust to user error. - -- **[Nitpick] PyTree Raveling in `stretch_move`** - - **File**: `blackjax/mcmc/ensemble.py`, line 39 - - **Code**: The logic for raveling the selected complementary walker is inside `stretch_move`. - - **Suggestion**: This is perfectly fine. An alternative, slightly cleaner pattern could be to have `stretch_move` operate on flattened arrays only, and perform the raveling/unraveling in the calling function (`_update_half`). This can sometimes improve modularity but is not a major issue here. - ---- - -### 4. Integration Issues - -The integration with the BlackJax API is excellent. The use of `generate_top_level_api_from` in `blackjax/__init__.py` is exactly right. However, the testing strategy has a significant gap. - -- **[Critical] Missing Validation Test Against `emcee`** - - **File**: `tests/mcmc/test_ensemble.py` - - **Issue**: The test suite includes unit tests and a convergence test, which are great. However, it is missing a direct validation test against the reference `emcee` implementation. Such a test would involve: - 1. Setting up the same model and initial ensemble in both BlackJax and `emcee`. - 2. Carefully managing the random seeds to ensure both samplers make the same random choices. - 3. Running for one or a few steps. - 4. Asserting that the resulting ensemble positions and log-probabilities are identical (or `allclose`). - - **Suggestion**: **This is the most important required change.** A validation test provides a much stronger guarantee of correctness than a convergence test alone. Please add a test case to `test_ensemble.py` that performs this comparison. It will require installing `emcee` as a test dependency. - -- **[Good Practice] Add an `__all__` dunder** - - **File**: `blackjax/mcmc/ensemble.py` - - **Suggestion**: It's good practice to add an `__all__` list to new modules to explicitly define the public API. For `ensemble.py`, it should include `EnsembleState`, `EnsembleInfo`, `stretch_move`, `build_kernel`, `init`, and `as_top_level_api`. - -### Summary of Recommendations - -- **Priority 1 (Blocking):** - 1. **Add Validation Test**: Implement a test in `tests/mcmc/test_ensemble.py` that compares the output of `blackjax.stretch` directly against `emcee` for a fixed seed to ensure the logic is identical. - -- **Priority 2 (Recommended Improvements):** - 1. **Refactor `has_blobs`**: Remove the `has_blobs` flag from the public API by wrapping the user's `logdensity_fn` to standardize its output, making the API more robust and user-friendly. - 2. **Add `__all__`**: Add an `__all__` export list to `blackjax/mcmc/ensemble.py`. - -- **Priority 3 (Minor Suggestions):** - 1. **Robustness**: Consider the suggested improvements for PyTree shape handling in `stretch_move` and `_update_half` for long-term robustness, possibly with `chex` assertions. - -This is an excellent contribution. Once the validation test is added, this implementation can be considered complete and correct. \ No newline at end of file From 923f6ceb3822ac3fe029d9c0ba9e6298d8713a27 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 21:19:22 +0100 Subject: [PATCH 03/14] Fix ensemble sampler correctness issues and add emcee compatibility features MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Correctness fixes: - Fix PRNG key reuse bug in _update_half function - Fix -∞ log probability handling to accept moves from -∞ to finite density Enhancements for emcee compatibility: - Add randomized red/blue split (randomize_split parameter, default True) - Add warning when nwalkers < 2*ndim (suppressible with live_dangerously) All tests pass and fixes verified with dedicated test cases. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/mcmc/ensemble.py | 222 ++++++++++++++++++++++++++++++-------- blackjax/mcmc/stretch.py | 67 ++++++++++-- 2 files changed, 232 insertions(+), 57 deletions(-) diff --git a/blackjax/mcmc/ensemble.py b/blackjax/mcmc/ensemble.py index 273d24f37..2d726f157 100644 --- a/blackjax/mcmc/ensemble.py +++ b/blackjax/mcmc/ensemble.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Core functionality for ensemble MCMC algorithms.""" +import warnings from typing import Callable, NamedTuple, Optional import jax @@ -44,6 +45,7 @@ class EnsembleState(NamedTuple): An optional PyTree that stores metadata returned by the log-probability function. """ + coords: ArrayTree log_probs: Array blobs: Optional[ArrayTree] = None @@ -58,6 +60,7 @@ class EnsembleInfo(NamedTuple): A boolean array of shape `(n_walkers,)` indicating whether each walker's proposal was accepted. """ + acceptance_rate: Array accepted: Array @@ -74,68 +77,118 @@ def stretch_move( ensemble and moving the current walker along the line connecting the two. """ key_select, key_stretch = jax.random.split(rng_key) - + # Ravel coordinates to handle PyTrees walker_flat, unravel_fn = ravel_pytree(walker_coords) - + # Get the shape of the complementary ensemble # complementary_coords should have shape (n_walkers, ...) where ... matches walker_coords comp_leaves, comp_treedef = jax.tree_util.tree_flatten(complementary_coords) n_walkers_comp = comp_leaves[0].shape[0] - + # Select a random walker from the complementary ensemble idx = jax.random.randint(key_select, (), 0, n_walkers_comp) complementary_walker = jax.tree.map(lambda x: x[idx], complementary_coords) - + # Ravel the selected complementary walker complementary_walker_flat, _ = ravel_pytree(complementary_walker) - + # Generate the stretch factor `Z` from g(z) z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a - + # Generate the proposal (Eq. 10) - proposal_flat = complementary_walker_flat + z * (walker_flat - complementary_walker_flat) - + proposal_flat = complementary_walker_flat + z * ( + walker_flat - complementary_walker_flat + ) + # The log of the Hastings ratio (Eq. 11) # Number of dimensions is the length of the flattened walker n_dims = walker_flat.shape[0] log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) - + return unravel_fn(proposal_flat), log_hastings_ratio -def build_kernel(move_fn: Callable) -> Callable: - """Builds a generic ensemble MCMC kernel.""" +def build_kernel(move_fn: Callable, randomize_split: bool = True) -> Callable: + """Builds a generic ensemble MCMC kernel. + + Parameters + ---------- + move_fn + The move function to use (e.g., stretch_move). + randomize_split + If True, randomly shuffle walker indices before splitting into red/blue sets + each iteration. This improves mixing and matches emcee's default behavior. + If False, uses a fixed contiguous split. + """ def kernel( rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable ) -> tuple[EnsembleState, EnsembleInfo]: - n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape half_n = n_walkers // 2 - + + # Optionally randomize the red-blue split + if randomize_split: + key_shuffle, key_red, key_blue = jax.random.split(rng_key, 3) + indices = jax.random.permutation(key_shuffle, n_walkers) + shuffled_state = jax.tree.map(lambda x: x[indices], state) + else: + key_red, key_blue = jax.random.split(rng_key) + shuffled_state = state + indices = jnp.arange(n_walkers) + # Red-Blue Split - walkers_red = jax.tree.map(lambda x: x[:half_n], state) - walkers_blue = jax.tree.map(lambda x: x[half_n:], state) + walkers_red = jax.tree.map(lambda x: x[:half_n], shuffled_state) + walkers_blue = jax.tree.map(lambda x: x[half_n:], shuffled_state) # Update Red walkers using Blue as complementary - key_red, key_blue = jax.random.split(rng_key) - new_walkers_red, accepted_red = _update_half(key_red, walkers_red, walkers_blue, logdensity_fn, move_fn) + new_walkers_red, accepted_red = _update_half( + key_red, walkers_red, walkers_blue, logdensity_fn, move_fn + ) # Update Blue walkers using updated Red as complementary - new_walkers_blue, accepted_blue = _update_half(key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn) - - # Combine back - new_coords = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.coords, new_walkers_blue.coords) - new_log_probs = jnp.concatenate([new_walkers_red.log_probs, new_walkers_blue.log_probs]) - + new_walkers_blue, accepted_blue = _update_half( + key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn + ) + + # Combine back in the shuffled order + shuffled_coords = jax.tree.map( + lambda r, b: jnp.concatenate([r, b], axis=0), + new_walkers_red.coords, + new_walkers_blue.coords, + ) + shuffled_log_probs = jnp.concatenate( + [new_walkers_red.log_probs, new_walkers_blue.log_probs] + ) + shuffled_accepted = jnp.concatenate([accepted_red, accepted_blue]) + if state.blobs is not None: - new_blobs = jax.tree.map(lambda r, b: jnp.concatenate([r, b], axis=0), new_walkers_red.blobs, new_walkers_blue.blobs) + shuffled_blobs = jax.tree.map( + lambda r, b: jnp.concatenate([r, b], axis=0), + new_walkers_red.blobs, + new_walkers_blue.blobs, + ) else: - new_blobs = None + shuffled_blobs = None + + # Unshuffle to restore original ordering + if randomize_split: + inverse_indices = jnp.argsort(indices) + new_coords = jax.tree.map(lambda x: x[inverse_indices], shuffled_coords) + new_log_probs = shuffled_log_probs[inverse_indices] + accepted = shuffled_accepted[inverse_indices] + if shuffled_blobs is not None: + new_blobs = jax.tree.map(lambda x: x[inverse_indices], shuffled_blobs) + else: + new_blobs = None + else: + new_coords = shuffled_coords + new_log_probs = shuffled_log_probs + accepted = shuffled_accepted + new_blobs = shuffled_blobs new_state = EnsembleState(new_coords, new_log_probs, new_blobs) - accepted = jnp.concatenate([accepted_red, accepted_blue]) acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) info = EnsembleInfo(acceptance_rate, accepted) @@ -144,16 +197,21 @@ def kernel( return kernel -def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn): +def _update_half( + rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn +): """Helper to update one half of the ensemble.""" n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape - keys = jax.random.split(rng_key, n_update) + + # Split key for moves and acceptance to avoid key reuse + key_moves, key_accept = jax.random.split(rng_key) + keys = jax.random.split(key_moves, n_update) # Vectorize the move over the walkers to be updated proposals, log_hastings_ratios = jax.vmap( lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) )(keys, walkers_to_update.coords) - + # Compute log-probabilities for proposals logdensity_outputs = jax.vmap(logdensity_fn)(proposals) if isinstance(logdensity_outputs, tuple): @@ -161,20 +219,39 @@ def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_f else: log_probs_proposal = logdensity_outputs blobs_proposal = None - + # MH accept/reject step (Eq. 11) - log_p_accept = log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs - - # To avoid -inf - (-inf) = NaN, replace -inf with a large negative number. - log_p_accept = jnp.where(jnp.isneginf(walkers_to_update.log_probs), -jnp.inf, log_p_accept) - - u = jax.random.uniform(rng_key, shape=(n_update,)) + log_p_accept = ( + log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs + ) + + # Handle -inf log probabilities correctly: + # - If current is -inf and proposal is finite: accept (log_p_accept = +inf) + # - If proposal is -inf: reject (log_p_accept = -inf) + # - If both are -inf: reject (log_p_accept = -inf) + is_curr_fin = jnp.isfinite(walkers_to_update.log_probs) + is_prop_fin = jnp.isfinite(log_probs_proposal) + log_p_accept = jnp.where( + ~is_curr_fin & is_prop_fin, + jnp.inf, + jnp.where( + is_curr_fin & ~is_prop_fin, + -jnp.inf, + jnp.where(~is_curr_fin & ~is_prop_fin, -jnp.inf, log_p_accept), + ), + ) + + u = jax.random.uniform(key_accept, shape=(n_update,)) accepted = jnp.log(u) < log_p_accept # Build the new state for the half - new_coords = jax.tree.map(lambda prop, old: jnp.where(accepted[:, None], prop, old), proposals, walkers_to_update.coords) + new_coords = jax.tree.map( + lambda prop, old: jnp.where(accepted[:, None], prop, old), + proposals, + walkers_to_update.coords, + ) new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) - + if walkers_to_update.blobs is not None: new_blobs = jax.tree.map( lambda prop, old: jnp.where(accepted, prop, old), @@ -183,13 +260,47 @@ def _update_half(rng_key, walkers_to_update, complementary_walkers, logdensity_f ) else: new_blobs = None - + new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) return new_walkers, accepted -def init(position: ArrayTree, logdensity_fn: Callable, has_blobs: bool = False) -> EnsembleState: - """Initializes the ensemble.""" +def init( + position: ArrayTree, + logdensity_fn: Callable, + has_blobs: bool = False, + live_dangerously: bool = False, +) -> EnsembleState: + """Initializes the ensemble. + + Parameters + ---------- + position + Initial positions for all walkers, with shape (n_walkers, ...). + logdensity_fn + The log-density function to evaluate. + has_blobs + Whether the log-density function returns additional metadata (blobs). + live_dangerously + If False (default), warns when n_walkers < 2*ndim, which can lead to poor + mixing. Set to True to suppress this warning. + """ + # Get number of walkers and dimensions + leaves, _ = jax.tree_util.tree_flatten(position) + n_walkers = leaves[0].shape[0] + flat_sample, _ = ravel_pytree(jax.tree.map(lambda x: x[0], position)) + ndim = flat_sample.shape[0] + + # Warn if n_walkers < 2*ndim (following emcee's recommendation) + if not live_dangerously and n_walkers < 2 * ndim: + warnings.warn( + f"Running ensemble sampler with {n_walkers} walkers for {ndim} dimensions. " + f"For optimal performance and mixing, emcee recommends at least {2 * ndim} walkers. " + f"Set live_dangerously=True to suppress this warning.", + UserWarning, + stacklevel=2, + ) + logdensity_outputs = jax.vmap(logdensity_fn)(position) if isinstance(logdensity_outputs, tuple): log_probs, blobs = logdensity_outputs @@ -200,15 +311,34 @@ def init(position: ArrayTree, logdensity_fn: Callable, has_blobs: bool = False) def as_top_level_api( - logdensity_fn: Callable, move_fn: Callable, has_blobs: bool = False + logdensity_fn: Callable, + move_fn: Callable, + has_blobs: bool = False, + randomize_split: bool = True, + live_dangerously: bool = False, ) -> SamplingAlgorithm: - """Implements the user-facing API for ensemble samplers.""" - kernel = build_kernel(move_fn) + """Implements the user-facing API for ensemble samplers. + + Parameters + ---------- + logdensity_fn + The log-density function to sample from. + move_fn + The move function to use (e.g., stretch_move). + has_blobs + Whether the log-density function returns additional metadata (blobs). + randomize_split + If True, randomly shuffle walker indices before splitting into red/blue sets + each iteration. This improves mixing and matches emcee's default behavior. + live_dangerously + If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. + """ + kernel = build_kernel(move_fn, randomize_split=randomize_split) def init_fn(position: ArrayTree, rng_key=None): - return init(position, logdensity_fn, has_blobs) + return init(position, logdensity_fn, has_blobs, live_dangerously) def step_fn(rng_key: PRNGKey, state: EnsembleState): return kernel(rng_key, state, logdensity_fn) - return SamplingAlgorithm(init_fn, step_fn) \ No newline at end of file + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/stretch.py index 7ab5dec49..75a22b585 100644 --- a/blackjax/mcmc/stretch.py +++ b/blackjax/mcmc/stretch.py @@ -15,14 +15,23 @@ from typing import Callable from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api, stretch_move, init as ensemble_init, build_kernel as ensemble_build_kernel +from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api +from blackjax.mcmc.ensemble import build_kernel as ensemble_build_kernel +from blackjax.mcmc.ensemble import init as ensemble_init +from blackjax.mcmc.ensemble import stretch_move __all__ = ["as_top_level_api", "init", "build_kernel"] -def as_top_level_api(logdensity_fn: Callable, a: float = 2.0, has_blobs: bool = False) -> SamplingAlgorithm: +def as_top_level_api( + logdensity_fn: Callable, + a: float = 2.0, + has_blobs: bool = False, + randomize_split: bool = True, + live_dangerously: bool = False, +) -> SamplingAlgorithm: """A user-facing API for the stretch move algorithm. - + Parameters ---------- logdensity_fn @@ -31,22 +40,58 @@ def as_top_level_api(logdensity_fn: Callable, a: float = 2.0, has_blobs: bool = The stretch parameter. Must be > 1. Default is 2.0. has_blobs Whether the logdensity function returns additional information (blobs). - + randomize_split + If True, randomly shuffle walker indices before splitting into red/blue sets + each iteration. This improves mixing and matches emcee's default behavior. + live_dangerously + If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. + Returns ------- A `SamplingAlgorithm` that can be used to sample from the target distribution. """ move = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_api(logdensity_fn, move, has_blobs) + return ensemble_api( + logdensity_fn, + move, + has_blobs, + randomize_split=randomize_split, + live_dangerously=live_dangerously, + ) + + +def init( + position, logdensity_fn, has_blobs: bool = False, live_dangerously: bool = False +): + """Initialize the stretch move algorithm. + Parameters + ---------- + position + Initial positions for all walkers, with shape (n_walkers, ...). + logdensity_fn + The log-density function to evaluate. + has_blobs + Whether the log-density function returns additional metadata (blobs). + live_dangerously + If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. + """ + return ensemble_init(position, logdensity_fn, has_blobs, live_dangerously) -def init(position, logdensity_fn, has_blobs: bool = False): - """Initialize the stretch move algorithm.""" - return ensemble_init(position, logdensity_fn, has_blobs) +def build_kernel(move_fn=None, a: float = 2.0, randomize_split: bool = True): + """Build the stretch move kernel. -def build_kernel(move_fn=None, a: float = 2.0): - """Build the stretch move kernel.""" + Parameters + ---------- + move_fn + Optional custom move function. If None, uses stretch_move with parameter a. + a + The stretch parameter. Must be > 1. Default is 2.0. + randomize_split + If True, randomly shuffle walker indices before splitting into red/blue sets + each iteration. This improves mixing and matches emcee's default behavior. + """ if move_fn is None: move_fn = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_build_kernel(move_fn) \ No newline at end of file + return ensemble_build_kernel(move_fn, randomize_split=randomize_split) From 1032cf91fe5061d77ed1fa5a29ab9c7af08783ef Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 21:31:20 +0100 Subject: [PATCH 04/14] Fix remaining PyTree bugs and implement nsplits > 2 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Critical bug fixes: - Fix PyTree handling with None blobs (avoid tree.map on None) - Fix mask broadcasting for 1D PyTree leaves (add _masked_select helper) Enhancements: - Implement nsplits > 2 support (generalize red-blue to n-way split) - Each group updated sequentially using all other groups as complementary - Default nsplits=2 maintains emcee compatibility All tests pass including: - PyTree coords with None blobs - 1D PyTree leaves - nsplits=2 and nsplits=4 - PyTree with blobs 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/mcmc/ensemble.py | 139 +++++++++++++++++++++++++++++--------- blackjax/mcmc/stretch.py | 19 ++++-- 2 files changed, 123 insertions(+), 35 deletions(-) diff --git a/blackjax/mcmc/ensemble.py b/blackjax/mcmc/ensemble.py index 2d726f157..26f2f7edc 100644 --- a/blackjax/mcmc/ensemble.py +++ b/blackjax/mcmc/ensemble.py @@ -109,7 +109,9 @@ def stretch_move( return unravel_fn(proposal_flat), log_hastings_ratio -def build_kernel(move_fn: Callable, randomize_split: bool = True) -> Callable: +def build_kernel( + move_fn: Callable, randomize_split: bool = True, nsplits: int = 2 +) -> Callable: """Builds a generic ensemble MCMC kernel. Parameters @@ -117,57 +119,106 @@ def build_kernel(move_fn: Callable, randomize_split: bool = True) -> Callable: move_fn The move function to use (e.g., stretch_move). randomize_split - If True, randomly shuffle walker indices before splitting into red/blue sets + If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. If False, uses a fixed contiguous split. + nsplits + Number of groups to split the ensemble into. Default is 2 (red-blue). + Each group is updated sequentially using all other groups as complementary. """ def kernel( rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable ) -> tuple[EnsembleState, EnsembleInfo]: n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - half_n = n_walkers // 2 - # Optionally randomize the red-blue split + # Optionally randomize the split if randomize_split: - key_shuffle, key_red, key_blue = jax.random.split(rng_key, 3) + key_shuffle, key_update = jax.random.split(rng_key) indices = jax.random.permutation(key_shuffle, n_walkers) - shuffled_state = jax.tree.map(lambda x: x[indices], state) + shuffled_coords = jax.tree.map(lambda x: x[indices], state.coords) + shuffled_log_probs = state.log_probs[indices] + shuffled_blobs = ( + None + if state.blobs is None + else jax.tree.map(lambda x: x[indices], state.blobs) + ) + shuffled_state = EnsembleState( + shuffled_coords, shuffled_log_probs, shuffled_blobs + ) else: - key_red, key_blue = jax.random.split(rng_key) + key_update = rng_key shuffled_state = state indices = jnp.arange(n_walkers) - # Red-Blue Split - walkers_red = jax.tree.map(lambda x: x[:half_n], shuffled_state) - walkers_blue = jax.tree.map(lambda x: x[half_n:], shuffled_state) + # Split into nsplits groups + group_size = n_walkers // nsplits + groups = [] + for i in range(nsplits): + start_idx = i * group_size + end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers - # Update Red walkers using Blue as complementary - new_walkers_red, accepted_red = _update_half( - key_red, walkers_red, walkers_blue, logdensity_fn, move_fn - ) + group_coords = jax.tree.map( + lambda x: x[start_idx:end_idx], shuffled_state.coords + ) + group_log_probs = shuffled_state.log_probs[start_idx:end_idx] + group_blobs = ( + None + if shuffled_state.blobs is None + else jax.tree.map(lambda x: x[start_idx:end_idx], shuffled_state.blobs) + ) + groups.append(EnsembleState(group_coords, group_log_probs, group_blobs)) + + # Update each group sequentially using all other groups as complementary + updated_groups = list(groups) + accepted_groups = [] + + keys = jax.random.split(key_update, nsplits) + for i in range(nsplits): + # Build complementary ensemble from all other groups + other_indices = [j for j in range(nsplits) if j != i] + comp_coords_list = [updated_groups[j].coords for j in other_indices] + comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] + comp_blobs_list = [updated_groups[j].blobs for j in other_indices] + + # Concatenate complementary groups + complementary_coords = jax.tree.map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list + ) + complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) - # Update Blue walkers using updated Red as complementary - new_walkers_blue, accepted_blue = _update_half( - key_blue, walkers_blue, new_walkers_red, logdensity_fn, move_fn - ) + if state.blobs is not None: + complementary_blobs = jax.tree.map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list + ) + else: + complementary_blobs = None + + complementary = EnsembleState( + complementary_coords, complementary_log_probs, complementary_blobs + ) + + # Update this group + updated_group, accepted = _update_half( + keys[i], groups[i], complementary, logdensity_fn, move_fn + ) + updated_groups[i] = updated_group + accepted_groups.append(accepted) - # Combine back in the shuffled order + # Combine all updated groups shuffled_coords = jax.tree.map( - lambda r, b: jnp.concatenate([r, b], axis=0), - new_walkers_red.coords, - new_walkers_blue.coords, + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.coords for g in updated_groups], ) shuffled_log_probs = jnp.concatenate( - [new_walkers_red.log_probs, new_walkers_blue.log_probs] + [g.log_probs for g in updated_groups], axis=0 ) - shuffled_accepted = jnp.concatenate([accepted_red, accepted_blue]) + shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) if state.blobs is not None: shuffled_blobs = jax.tree.map( - lambda r, b: jnp.concatenate([r, b], axis=0), - new_walkers_red.blobs, - new_walkers_blue.blobs, + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.blobs for g in updated_groups], ) else: shuffled_blobs = None @@ -197,6 +248,28 @@ def kernel( return kernel +def _masked_select(mask, new_val, old_val): + """Helper to broadcast mask to match array rank for jnp.where. + + Parameters + ---------- + mask + Boolean mask with shape (n_walkers,) + new_val + New values to select when mask is True + old_val + Old values to select when mask is False + + Returns + ------- + Array with same shape as new_val/old_val, with values selected per mask + """ + # Reshape mask to (n_walkers, 1, 1, ...) to match the rank of new_val + expand_dims = (1,) * (new_val.ndim - 1) + mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) + return jnp.where(mask_expanded, new_val, old_val) + + def _update_half( rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn ): @@ -246,7 +319,7 @@ def _update_half( # Build the new state for the half new_coords = jax.tree.map( - lambda prop, old: jnp.where(accepted[:, None], prop, old), + lambda prop, old: _masked_select(accepted, prop, old), proposals, walkers_to_update.coords, ) @@ -254,7 +327,7 @@ def _update_half( if walkers_to_update.blobs is not None: new_blobs = jax.tree.map( - lambda prop, old: jnp.where(accepted, prop, old), + lambda prop, old: _masked_select(accepted, prop, old), blobs_proposal, walkers_to_update.blobs, ) @@ -316,6 +389,7 @@ def as_top_level_api( has_blobs: bool = False, randomize_split: bool = True, live_dangerously: bool = False, + nsplits: int = 2, ) -> SamplingAlgorithm: """Implements the user-facing API for ensemble samplers. @@ -328,12 +402,15 @@ def as_top_level_api( has_blobs Whether the log-density function returns additional metadata (blobs). randomize_split - If True, randomly shuffle walker indices before splitting into red/blue sets + If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. live_dangerously If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. + nsplits + Number of groups to split the ensemble into. Default is 2 (red-blue). + Each group is updated sequentially using all other groups as complementary. """ - kernel = build_kernel(move_fn, randomize_split=randomize_split) + kernel = build_kernel(move_fn, randomize_split=randomize_split, nsplits=nsplits) def init_fn(position: ArrayTree, rng_key=None): return init(position, logdensity_fn, has_blobs, live_dangerously) diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/stretch.py index 75a22b585..ebecfb0a4 100644 --- a/blackjax/mcmc/stretch.py +++ b/blackjax/mcmc/stretch.py @@ -29,6 +29,7 @@ def as_top_level_api( has_blobs: bool = False, randomize_split: bool = True, live_dangerously: bool = False, + nsplits: int = 2, ) -> SamplingAlgorithm: """A user-facing API for the stretch move algorithm. @@ -41,10 +42,13 @@ def as_top_level_api( has_blobs Whether the logdensity function returns additional information (blobs). randomize_split - If True, randomly shuffle walker indices before splitting into red/blue sets + If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. live_dangerously If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. + nsplits + Number of groups to split the ensemble into. Default is 2 (red-blue). + Each group is updated sequentially using all other groups as complementary. Returns ------- @@ -57,6 +61,7 @@ def as_top_level_api( has_blobs, randomize_split=randomize_split, live_dangerously=live_dangerously, + nsplits=nsplits, ) @@ -79,7 +84,9 @@ def init( return ensemble_init(position, logdensity_fn, has_blobs, live_dangerously) -def build_kernel(move_fn=None, a: float = 2.0, randomize_split: bool = True): +def build_kernel( + move_fn=None, a: float = 2.0, randomize_split: bool = True, nsplits: int = 2 +): """Build the stretch move kernel. Parameters @@ -89,9 +96,13 @@ def build_kernel(move_fn=None, a: float = 2.0, randomize_split: bool = True): a The stretch parameter. Must be > 1. Default is 2.0. randomize_split - If True, randomly shuffle walker indices before splitting into red/blue sets + If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. + nsplits + Number of groups to split the ensemble into. Default is 2 (red-blue). """ if move_fn is None: move_fn = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_build_kernel(move_fn, randomize_split=randomize_split) + return ensemble_build_kernel( + move_fn, randomize_split=randomize_split, nsplits=nsplits + ) From 7b577611e46fc2c579cc565d2e9146869fbcad82 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 21:58:11 +0100 Subject: [PATCH 05/14] Refactor ensemble sampler for BlackJAX consistency MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove all validation and warnings for consistency with BlackJAX philosophy - Removed ValueError for nsplits and a parameter bounds - Removed nwalkers < 2*ndim warning - Removed live_dangerously parameter - Replace jax.tree.map with jax.tree_util.tree_map throughout (13 instances) - Remove all inline comments (~20 instances) - rely on docstrings and clear names - Expand stretch_move docstring with full Parameters and Returns sections - Improve variable naming (use _ for unused tree_flatten return value) - Fix acceptance_rate type: use jnp.mean() instead of float() to avoid concretization errors - Fix mypy type checking for step_fn signature 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/mcmc/ensemble.py | 128 ++++++++++++++++---------------------- blackjax/mcmc/stretch.py | 12 +--- 2 files changed, 54 insertions(+), 86 deletions(-) diff --git a/blackjax/mcmc/ensemble.py b/blackjax/mcmc/ensemble.py index 26f2f7edc..5dfd0f1b7 100644 --- a/blackjax/mcmc/ensemble.py +++ b/blackjax/mcmc/ensemble.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Core functionality for ensemble MCMC algorithms.""" -import warnings from typing import Callable, NamedTuple, Optional import jax @@ -20,7 +19,7 @@ from jax.flatten_util import ravel_pytree from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayTree, PRNGKey +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ "EnsembleState", @@ -56,13 +55,13 @@ class EnsembleInfo(NamedTuple): acceptance_rate The acceptance rate of the ensemble. - accepted + is_accepted A boolean array of shape `(n_walkers,)` indicating whether each walker's proposal was accepted. """ - acceptance_rate: Array - accepted: Array + acceptance_rate: float + is_accepted: Array def stretch_move( @@ -71,38 +70,50 @@ def stretch_move( complementary_coords: ArrayTree, a: float = 2.0, ) -> tuple[ArrayTree, float]: - """The emcee stretch move. + """Generate a proposal using the affine-invariant stretch move. - A proposal is generated by selecting a random walker from the complementary - ensemble and moving the current walker along the line connecting the two. + The stretch move selects a random walker from the complementary ensemble + and proposes a new position along the line connecting the two walkers, + scaled by a random factor z drawn from g(z) ∝ 1/√z on [1/a, a]. + + Parameters + ---------- + rng_key + A PRNG key for random number generation. + walker_coords + The current walker's coordinates as an array or PyTree. + complementary_coords + The coordinates of the complementary ensemble with shape (n_walkers, ...) + where the leading dimension indexes walkers. + a + The stretch scale parameter. Must be > 1. Default is 2.0. + + Returns + ------- + A tuple (proposal, log_hastings_ratio) where proposal is the proposed + position with the same structure as walker_coords, and log_hastings_ratio + is (ndim - 1) * log(z). """ key_select, key_stretch = jax.random.split(rng_key) - # Ravel coordinates to handle PyTrees walker_flat, unravel_fn = ravel_pytree(walker_coords) - # Get the shape of the complementary ensemble - # complementary_coords should have shape (n_walkers, ...) where ... matches walker_coords - comp_leaves, comp_treedef = jax.tree_util.tree_flatten(complementary_coords) + comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords) n_walkers_comp = comp_leaves[0].shape[0] - # Select a random walker from the complementary ensemble idx = jax.random.randint(key_select, (), 0, n_walkers_comp) - complementary_walker = jax.tree.map(lambda x: x[idx], complementary_coords) + complementary_walker = jax.tree_util.tree_map( + lambda x: x[idx], complementary_coords + ) - # Ravel the selected complementary walker complementary_walker_flat, _ = ravel_pytree(complementary_walker) - # Generate the stretch factor `Z` from g(z) z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a - # Generate the proposal (Eq. 10) proposal_flat = complementary_walker_flat + z * ( walker_flat - complementary_walker_flat ) - # The log of the Hastings ratio (Eq. 11) - # Number of dimensions is the length of the flattened walker n_dims = walker_flat.shape[0] log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) @@ -132,16 +143,15 @@ def kernel( ) -> tuple[EnsembleState, EnsembleInfo]: n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - # Optionally randomize the split if randomize_split: key_shuffle, key_update = jax.random.split(rng_key) indices = jax.random.permutation(key_shuffle, n_walkers) - shuffled_coords = jax.tree.map(lambda x: x[indices], state.coords) + shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) shuffled_log_probs = state.log_probs[indices] shuffled_blobs = ( None if state.blobs is None - else jax.tree.map(lambda x: x[indices], state.blobs) + else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) ) shuffled_state = EnsembleState( shuffled_coords, shuffled_log_probs, shuffled_blobs @@ -151,44 +161,42 @@ def kernel( shuffled_state = state indices = jnp.arange(n_walkers) - # Split into nsplits groups group_size = n_walkers // nsplits groups = [] for i in range(nsplits): start_idx = i * group_size end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers - group_coords = jax.tree.map( + group_coords = jax.tree_util.tree_map( lambda x: x[start_idx:end_idx], shuffled_state.coords ) group_log_probs = shuffled_state.log_probs[start_idx:end_idx] group_blobs = ( None if shuffled_state.blobs is None - else jax.tree.map(lambda x: x[start_idx:end_idx], shuffled_state.blobs) + else jax.tree_util.tree_map( + lambda x: x[start_idx:end_idx], shuffled_state.blobs + ) ) groups.append(EnsembleState(group_coords, group_log_probs, group_blobs)) - # Update each group sequentially using all other groups as complementary updated_groups = list(groups) accepted_groups = [] keys = jax.random.split(key_update, nsplits) for i in range(nsplits): - # Build complementary ensemble from all other groups other_indices = [j for j in range(nsplits) if j != i] comp_coords_list = [updated_groups[j].coords for j in other_indices] comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] comp_blobs_list = [updated_groups[j].blobs for j in other_indices] - # Concatenate complementary groups - complementary_coords = jax.tree.map( + complementary_coords = jax.tree_util.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list ) complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) if state.blobs is not None: - complementary_blobs = jax.tree.map( + complementary_blobs = jax.tree_util.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list ) else: @@ -198,15 +206,13 @@ def kernel( complementary_coords, complementary_log_probs, complementary_blobs ) - # Update this group updated_group, accepted = _update_half( keys[i], groups[i], complementary, logdensity_fn, move_fn ) updated_groups[i] = updated_group accepted_groups.append(accepted) - # Combine all updated groups - shuffled_coords = jax.tree.map( + shuffled_coords = jax.tree_util.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *[g.coords for g in updated_groups], ) @@ -216,21 +222,24 @@ def kernel( shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) if state.blobs is not None: - shuffled_blobs = jax.tree.map( + shuffled_blobs = jax.tree_util.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *[g.blobs for g in updated_groups], ) else: shuffled_blobs = None - # Unshuffle to restore original ordering if randomize_split: inverse_indices = jnp.argsort(indices) - new_coords = jax.tree.map(lambda x: x[inverse_indices], shuffled_coords) + new_coords = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_coords + ) new_log_probs = shuffled_log_probs[inverse_indices] accepted = shuffled_accepted[inverse_indices] if shuffled_blobs is not None: - new_blobs = jax.tree.map(lambda x: x[inverse_indices], shuffled_blobs) + new_blobs = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_blobs + ) else: new_blobs = None else: @@ -240,7 +249,7 @@ def kernel( new_blobs = shuffled_blobs new_state = EnsembleState(new_coords, new_log_probs, new_blobs) - acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) + acceptance_rate = jnp.mean(accepted) info = EnsembleInfo(acceptance_rate, accepted) return new_state, info @@ -264,7 +273,6 @@ def _masked_select(mask, new_val, old_val): ------- Array with same shape as new_val/old_val, with values selected per mask """ - # Reshape mask to (n_walkers, 1, 1, ...) to match the rank of new_val expand_dims = (1,) * (new_val.ndim - 1) mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) return jnp.where(mask_expanded, new_val, old_val) @@ -276,16 +284,13 @@ def _update_half( """Helper to update one half of the ensemble.""" n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape - # Split key for moves and acceptance to avoid key reuse key_moves, key_accept = jax.random.split(rng_key) keys = jax.random.split(key_moves, n_update) - # Vectorize the move over the walkers to be updated proposals, log_hastings_ratios = jax.vmap( lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) )(keys, walkers_to_update.coords) - # Compute log-probabilities for proposals logdensity_outputs = jax.vmap(logdensity_fn)(proposals) if isinstance(logdensity_outputs, tuple): log_probs_proposal, blobs_proposal = logdensity_outputs @@ -293,15 +298,10 @@ def _update_half( log_probs_proposal = logdensity_outputs blobs_proposal = None - # MH accept/reject step (Eq. 11) log_p_accept = ( log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs ) - # Handle -inf log probabilities correctly: - # - If current is -inf and proposal is finite: accept (log_p_accept = +inf) - # - If proposal is -inf: reject (log_p_accept = -inf) - # - If both are -inf: reject (log_p_accept = -inf) is_curr_fin = jnp.isfinite(walkers_to_update.log_probs) is_prop_fin = jnp.isfinite(log_probs_proposal) log_p_accept = jnp.where( @@ -317,8 +317,7 @@ def _update_half( u = jax.random.uniform(key_accept, shape=(n_update,)) accepted = jnp.log(u) < log_p_accept - # Build the new state for the half - new_coords = jax.tree.map( + new_coords = jax.tree_util.tree_map( lambda prop, old: _masked_select(accepted, prop, old), proposals, walkers_to_update.coords, @@ -326,7 +325,7 @@ def _update_half( new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) if walkers_to_update.blobs is not None: - new_blobs = jax.tree.map( + new_blobs = jax.tree_util.tree_map( lambda prop, old: _masked_select(accepted, prop, old), blobs_proposal, walkers_to_update.blobs, @@ -339,10 +338,9 @@ def _update_half( def init( - position: ArrayTree, + position: ArrayLikeTree, logdensity_fn: Callable, has_blobs: bool = False, - live_dangerously: bool = False, ) -> EnsembleState: """Initializes the ensemble. @@ -354,26 +352,7 @@ def init( The log-density function to evaluate. has_blobs Whether the log-density function returns additional metadata (blobs). - live_dangerously - If False (default), warns when n_walkers < 2*ndim, which can lead to poor - mixing. Set to True to suppress this warning. """ - # Get number of walkers and dimensions - leaves, _ = jax.tree_util.tree_flatten(position) - n_walkers = leaves[0].shape[0] - flat_sample, _ = ravel_pytree(jax.tree.map(lambda x: x[0], position)) - ndim = flat_sample.shape[0] - - # Warn if n_walkers < 2*ndim (following emcee's recommendation) - if not live_dangerously and n_walkers < 2 * ndim: - warnings.warn( - f"Running ensemble sampler with {n_walkers} walkers for {ndim} dimensions. " - f"For optimal performance and mixing, emcee recommends at least {2 * ndim} walkers. " - f"Set live_dangerously=True to suppress this warning.", - UserWarning, - stacklevel=2, - ) - logdensity_outputs = jax.vmap(logdensity_fn)(position) if isinstance(logdensity_outputs, tuple): log_probs, blobs = logdensity_outputs @@ -388,7 +367,6 @@ def as_top_level_api( move_fn: Callable, has_blobs: bool = False, randomize_split: bool = True, - live_dangerously: bool = False, nsplits: int = 2, ) -> SamplingAlgorithm: """Implements the user-facing API for ensemble samplers. @@ -404,8 +382,6 @@ def as_top_level_api( randomize_split If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. - live_dangerously - If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. nsplits Number of groups to split the ensemble into. Default is 2 (red-blue). Each group is updated sequentially using all other groups as complementary. @@ -413,9 +389,9 @@ def as_top_level_api( kernel = build_kernel(move_fn, randomize_split=randomize_split, nsplits=nsplits) def init_fn(position: ArrayTree, rng_key=None): - return init(position, logdensity_fn, has_blobs, live_dangerously) + return init(position, logdensity_fn, has_blobs) - def step_fn(rng_key: PRNGKey, state: EnsembleState): + def step_fn(rng_key: PRNGKey, state) -> tuple[EnsembleState, EnsembleInfo]: return kernel(rng_key, state, logdensity_fn) return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/stretch.py index ebecfb0a4..58482ea40 100644 --- a/blackjax/mcmc/stretch.py +++ b/blackjax/mcmc/stretch.py @@ -28,7 +28,6 @@ def as_top_level_api( a: float = 2.0, has_blobs: bool = False, randomize_split: bool = True, - live_dangerously: bool = False, nsplits: int = 2, ) -> SamplingAlgorithm: """A user-facing API for the stretch move algorithm. @@ -44,8 +43,6 @@ def as_top_level_api( randomize_split If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. - live_dangerously - If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. nsplits Number of groups to split the ensemble into. Default is 2 (red-blue). Each group is updated sequentially using all other groups as complementary. @@ -60,14 +57,11 @@ def as_top_level_api( move, has_blobs, randomize_split=randomize_split, - live_dangerously=live_dangerously, nsplits=nsplits, ) -def init( - position, logdensity_fn, has_blobs: bool = False, live_dangerously: bool = False -): +def init(position, logdensity_fn, has_blobs: bool = False): """Initialize the stretch move algorithm. Parameters @@ -78,10 +72,8 @@ def init( The log-density function to evaluate. has_blobs Whether the log-density function returns additional metadata (blobs). - live_dangerously - If False (default), warns when n_walkers < 2*ndim. Set to True to suppress. """ - return ensemble_init(position, logdensity_fn, has_blobs, live_dangerously) + return ensemble_init(position, logdensity_fn, has_blobs) def build_kernel( From d5b274545c29b862cfd709a820a0667a9da606eb Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 22:06:10 +0100 Subject: [PATCH 06/14] Merge stretch.py into ensemble.py and rename to stretch.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replaced generic ensemble API with stretch-specific API - as_top_level_api() now takes 'a' parameter instead of 'move_fn' - build_kernel() now takes optional 'a' parameter, defaults to stretch_move - Renamed ensemble.py to stretch.py (single file now) - Deleted redundant stretch.py wrapper - Updated test imports from ensemble to stretch - Removed unused imports from test file - Kept EnsembleState/EnsembleInfo names (describe multi-walker state) Simplifies codebase by eliminating unnecessary abstraction layer. The "ensemble framework" was only used for stretch move. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/mcmc/ensemble.py | 397 ------------------------------------ blackjax/mcmc/stretch.py | 382 ++++++++++++++++++++++++++++++---- tests/mcmc/test_ensemble.py | 66 +++--- 3 files changed, 374 insertions(+), 471 deletions(-) delete mode 100644 blackjax/mcmc/ensemble.py diff --git a/blackjax/mcmc/ensemble.py b/blackjax/mcmc/ensemble.py deleted file mode 100644 index 5dfd0f1b7..000000000 --- a/blackjax/mcmc/ensemble.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright 2020- The Blackjax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Core functionality for ensemble MCMC algorithms.""" -from typing import Callable, NamedTuple, Optional - -import jax -import jax.numpy as jnp -from jax.flatten_util import ravel_pytree - -from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey - -__all__ = [ - "EnsembleState", - "EnsembleInfo", - "init", - "build_kernel", - "as_top_level_api", - "stretch_move", -] - - -class EnsembleState(NamedTuple): - """State of an ensemble sampler. - - coords - An array or PyTree of arrays of shape `(n_walkers, ...)` that - stores the current position of the walkers. - log_probs - An array of shape `(n_walkers,)` that stores the log-probability of - each walker. - blobs - An optional PyTree that stores metadata returned by the log-probability - function. - """ - - coords: ArrayTree - log_probs: Array - blobs: Optional[ArrayTree] = None - - -class EnsembleInfo(NamedTuple): - """Additional information on the ensemble transition. - - acceptance_rate - The acceptance rate of the ensemble. - is_accepted - A boolean array of shape `(n_walkers,)` indicating whether each walker's - proposal was accepted. - """ - - acceptance_rate: float - is_accepted: Array - - -def stretch_move( - rng_key: PRNGKey, - walker_coords: ArrayTree, - complementary_coords: ArrayTree, - a: float = 2.0, -) -> tuple[ArrayTree, float]: - """Generate a proposal using the affine-invariant stretch move. - - The stretch move selects a random walker from the complementary ensemble - and proposes a new position along the line connecting the two walkers, - scaled by a random factor z drawn from g(z) ∝ 1/√z on [1/a, a]. - - Parameters - ---------- - rng_key - A PRNG key for random number generation. - walker_coords - The current walker's coordinates as an array or PyTree. - complementary_coords - The coordinates of the complementary ensemble with shape (n_walkers, ...) - where the leading dimension indexes walkers. - a - The stretch scale parameter. Must be > 1. Default is 2.0. - - Returns - ------- - A tuple (proposal, log_hastings_ratio) where proposal is the proposed - position with the same structure as walker_coords, and log_hastings_ratio - is (ndim - 1) * log(z). - """ - key_select, key_stretch = jax.random.split(rng_key) - - walker_flat, unravel_fn = ravel_pytree(walker_coords) - - comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords) - n_walkers_comp = comp_leaves[0].shape[0] - - idx = jax.random.randint(key_select, (), 0, n_walkers_comp) - complementary_walker = jax.tree_util.tree_map( - lambda x: x[idx], complementary_coords - ) - - complementary_walker_flat, _ = ravel_pytree(complementary_walker) - - z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a - - proposal_flat = complementary_walker_flat + z * ( - walker_flat - complementary_walker_flat - ) - - n_dims = walker_flat.shape[0] - log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) - - return unravel_fn(proposal_flat), log_hastings_ratio - - -def build_kernel( - move_fn: Callable, randomize_split: bool = True, nsplits: int = 2 -) -> Callable: - """Builds a generic ensemble MCMC kernel. - - Parameters - ---------- - move_fn - The move function to use (e.g., stretch_move). - randomize_split - If True, randomly shuffle walker indices before splitting into groups - each iteration. This improves mixing and matches emcee's default behavior. - If False, uses a fixed contiguous split. - nsplits - Number of groups to split the ensemble into. Default is 2 (red-blue). - Each group is updated sequentially using all other groups as complementary. - """ - - def kernel( - rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable - ) -> tuple[EnsembleState, EnsembleInfo]: - n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - - if randomize_split: - key_shuffle, key_update = jax.random.split(rng_key) - indices = jax.random.permutation(key_shuffle, n_walkers) - shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) - shuffled_log_probs = state.log_probs[indices] - shuffled_blobs = ( - None - if state.blobs is None - else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) - ) - shuffled_state = EnsembleState( - shuffled_coords, shuffled_log_probs, shuffled_blobs - ) - else: - key_update = rng_key - shuffled_state = state - indices = jnp.arange(n_walkers) - - group_size = n_walkers // nsplits - groups = [] - for i in range(nsplits): - start_idx = i * group_size - end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers - - group_coords = jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.coords - ) - group_log_probs = shuffled_state.log_probs[start_idx:end_idx] - group_blobs = ( - None - if shuffled_state.blobs is None - else jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.blobs - ) - ) - groups.append(EnsembleState(group_coords, group_log_probs, group_blobs)) - - updated_groups = list(groups) - accepted_groups = [] - - keys = jax.random.split(key_update, nsplits) - for i in range(nsplits): - other_indices = [j for j in range(nsplits) if j != i] - comp_coords_list = [updated_groups[j].coords for j in other_indices] - comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] - comp_blobs_list = [updated_groups[j].blobs for j in other_indices] - - complementary_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list - ) - complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) - - if state.blobs is not None: - complementary_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list - ) - else: - complementary_blobs = None - - complementary = EnsembleState( - complementary_coords, complementary_log_probs, complementary_blobs - ) - - updated_group, accepted = _update_half( - keys[i], groups[i], complementary, logdensity_fn, move_fn - ) - updated_groups[i] = updated_group - accepted_groups.append(accepted) - - shuffled_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.coords for g in updated_groups], - ) - shuffled_log_probs = jnp.concatenate( - [g.log_probs for g in updated_groups], axis=0 - ) - shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) - - if state.blobs is not None: - shuffled_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.blobs for g in updated_groups], - ) - else: - shuffled_blobs = None - - if randomize_split: - inverse_indices = jnp.argsort(indices) - new_coords = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_coords - ) - new_log_probs = shuffled_log_probs[inverse_indices] - accepted = shuffled_accepted[inverse_indices] - if shuffled_blobs is not None: - new_blobs = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_blobs - ) - else: - new_blobs = None - else: - new_coords = shuffled_coords - new_log_probs = shuffled_log_probs - accepted = shuffled_accepted - new_blobs = shuffled_blobs - - new_state = EnsembleState(new_coords, new_log_probs, new_blobs) - acceptance_rate = jnp.mean(accepted) - info = EnsembleInfo(acceptance_rate, accepted) - - return new_state, info - - return kernel - - -def _masked_select(mask, new_val, old_val): - """Helper to broadcast mask to match array rank for jnp.where. - - Parameters - ---------- - mask - Boolean mask with shape (n_walkers,) - new_val - New values to select when mask is True - old_val - Old values to select when mask is False - - Returns - ------- - Array with same shape as new_val/old_val, with values selected per mask - """ - expand_dims = (1,) * (new_val.ndim - 1) - mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) - return jnp.where(mask_expanded, new_val, old_val) - - -def _update_half( - rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn -): - """Helper to update one half of the ensemble.""" - n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape - - key_moves, key_accept = jax.random.split(rng_key) - keys = jax.random.split(key_moves, n_update) - - proposals, log_hastings_ratios = jax.vmap( - lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) - )(keys, walkers_to_update.coords) - - logdensity_outputs = jax.vmap(logdensity_fn)(proposals) - if isinstance(logdensity_outputs, tuple): - log_probs_proposal, blobs_proposal = logdensity_outputs - else: - log_probs_proposal = logdensity_outputs - blobs_proposal = None - - log_p_accept = ( - log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs - ) - - is_curr_fin = jnp.isfinite(walkers_to_update.log_probs) - is_prop_fin = jnp.isfinite(log_probs_proposal) - log_p_accept = jnp.where( - ~is_curr_fin & is_prop_fin, - jnp.inf, - jnp.where( - is_curr_fin & ~is_prop_fin, - -jnp.inf, - jnp.where(~is_curr_fin & ~is_prop_fin, -jnp.inf, log_p_accept), - ), - ) - - u = jax.random.uniform(key_accept, shape=(n_update,)) - accepted = jnp.log(u) < log_p_accept - - new_coords = jax.tree_util.tree_map( - lambda prop, old: _masked_select(accepted, prop, old), - proposals, - walkers_to_update.coords, - ) - new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) - - if walkers_to_update.blobs is not None: - new_blobs = jax.tree_util.tree_map( - lambda prop, old: _masked_select(accepted, prop, old), - blobs_proposal, - walkers_to_update.blobs, - ) - else: - new_blobs = None - - new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) - return new_walkers, accepted - - -def init( - position: ArrayLikeTree, - logdensity_fn: Callable, - has_blobs: bool = False, -) -> EnsembleState: - """Initializes the ensemble. - - Parameters - ---------- - position - Initial positions for all walkers, with shape (n_walkers, ...). - logdensity_fn - The log-density function to evaluate. - has_blobs - Whether the log-density function returns additional metadata (blobs). - """ - logdensity_outputs = jax.vmap(logdensity_fn)(position) - if isinstance(logdensity_outputs, tuple): - log_probs, blobs = logdensity_outputs - return EnsembleState(position, log_probs, blobs) - else: - log_probs = logdensity_outputs - return EnsembleState(position, log_probs, None) - - -def as_top_level_api( - logdensity_fn: Callable, - move_fn: Callable, - has_blobs: bool = False, - randomize_split: bool = True, - nsplits: int = 2, -) -> SamplingAlgorithm: - """Implements the user-facing API for ensemble samplers. - - Parameters - ---------- - logdensity_fn - The log-density function to sample from. - move_fn - The move function to use (e.g., stretch_move). - has_blobs - Whether the log-density function returns additional metadata (blobs). - randomize_split - If True, randomly shuffle walker indices before splitting into groups - each iteration. This improves mixing and matches emcee's default behavior. - nsplits - Number of groups to split the ensemble into. Default is 2 (red-blue). - Each group is updated sequentially using all other groups as complementary. - """ - kernel = build_kernel(move_fn, randomize_split=randomize_split, nsplits=nsplits) - - def init_fn(position: ArrayTree, rng_key=None): - return init(position, logdensity_fn, has_blobs) - - def step_fn(rng_key: PRNGKey, state) -> tuple[EnsembleState, EnsembleInfo]: - return kernel(rng_key, state, logdensity_fn) - - return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/stretch.py index 58482ea40..7afda54cb 100644 --- a/blackjax/mcmc/stretch.py +++ b/blackjax/mcmc/stretch.py @@ -12,56 +12,338 @@ # See the License for the specific language governing permissions and # limitations under the License. """Public API for the Stretch Move ensemble sampler.""" -from typing import Callable +from typing import Callable, NamedTuple, Optional + +import jax +import jax.numpy as jnp +from jax.flatten_util import ravel_pytree from blackjax.base import SamplingAlgorithm -from blackjax.mcmc.ensemble import as_top_level_api as ensemble_api -from blackjax.mcmc.ensemble import build_kernel as ensemble_build_kernel -from blackjax.mcmc.ensemble import init as ensemble_init -from blackjax.mcmc.ensemble import stretch_move +from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey -__all__ = ["as_top_level_api", "init", "build_kernel"] +__all__ = [ + "EnsembleState", + "EnsembleInfo", + "init", + "build_kernel", + "as_top_level_api", + "stretch_move", +] -def as_top_level_api( - logdensity_fn: Callable, +class EnsembleState(NamedTuple): + """State of an ensemble sampler. + + coords + An array or PyTree of arrays of shape `(n_walkers, ...)` that + stores the current position of the walkers. + log_probs + An array of shape `(n_walkers,)` that stores the log-probability of + each walker. + blobs + An optional PyTree that stores metadata returned by the log-probability + function. + """ + + coords: ArrayTree + log_probs: Array + blobs: Optional[ArrayTree] = None + + +class EnsembleInfo(NamedTuple): + """Additional information on the ensemble transition. + + acceptance_rate + The acceptance rate of the ensemble. + is_accepted + A boolean array of shape `(n_walkers,)` indicating whether each walker's + proposal was accepted. + """ + + acceptance_rate: float + is_accepted: Array + + +def stretch_move( + rng_key: PRNGKey, + walker_coords: ArrayTree, + complementary_coords: ArrayTree, a: float = 2.0, - has_blobs: bool = False, - randomize_split: bool = True, - nsplits: int = 2, -) -> SamplingAlgorithm: - """A user-facing API for the stretch move algorithm. +) -> tuple[ArrayTree, float]: + """Generate a proposal using the affine-invariant stretch move. + + The stretch move selects a random walker from the complementary ensemble + and proposes a new position along the line connecting the two walkers, + scaled by a random factor z drawn from g(z) ∝ 1/√z on [1/a, a]. Parameters ---------- - logdensity_fn - A function that returns the log density of the model at a given position. + rng_key + A PRNG key for random number generation. + walker_coords + The current walker's coordinates as an array or PyTree. + complementary_coords + The coordinates of the complementary ensemble with shape (n_walkers, ...) + where the leading dimension indexes walkers. + a + The stretch scale parameter. Must be > 1. Default is 2.0. + + Returns + ------- + A tuple (proposal, log_hastings_ratio) where proposal is the proposed + position with the same structure as walker_coords, and log_hastings_ratio + is (ndim - 1) * log(z). + """ + key_select, key_stretch = jax.random.split(rng_key) + + walker_flat, unravel_fn = ravel_pytree(walker_coords) + + comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords) + n_walkers_comp = comp_leaves[0].shape[0] + + idx = jax.random.randint(key_select, (), 0, n_walkers_comp) + complementary_walker = jax.tree_util.tree_map( + lambda x: x[idx], complementary_coords + ) + + complementary_walker_flat, _ = ravel_pytree(complementary_walker) + + z = ((a - 1.0) * jax.random.uniform(key_stretch) + 1) ** 2.0 / a + + proposal_flat = complementary_walker_flat + z * ( + walker_flat - complementary_walker_flat + ) + + n_dims = walker_flat.shape[0] + log_hastings_ratio = (n_dims - 1.0) * jnp.log(z) + + return unravel_fn(proposal_flat), log_hastings_ratio + + +def build_kernel( + move_fn=None, a: float = 2.0, randomize_split: bool = True, nsplits: int = 2 +) -> Callable: + """Build the stretch move kernel. + + Parameters + ---------- + move_fn + Optional custom move function. If None, uses stretch_move with parameter a. a The stretch parameter. Must be > 1. Default is 2.0. - has_blobs - Whether the logdensity function returns additional information (blobs). randomize_split If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. nsplits Number of groups to split the ensemble into. Default is 2 (red-blue). - Each group is updated sequentially using all other groups as complementary. + """ + if move_fn is None: + move_fn = lambda key, w, c: stretch_move(key, w, c, a) + + def kernel( + rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable + ) -> tuple[EnsembleState, EnsembleInfo]: + n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape + + if randomize_split: + key_shuffle, key_update = jax.random.split(rng_key) + indices = jax.random.permutation(key_shuffle, n_walkers) + shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) + shuffled_log_probs = state.log_probs[indices] + shuffled_blobs = ( + None + if state.blobs is None + else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) + ) + shuffled_state = EnsembleState( + shuffled_coords, shuffled_log_probs, shuffled_blobs + ) + else: + key_update = rng_key + shuffled_state = state + indices = jnp.arange(n_walkers) + + group_size = n_walkers // nsplits + groups = [] + for i in range(nsplits): + start_idx = i * group_size + end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers + + group_coords = jax.tree_util.tree_map( + lambda x: x[start_idx:end_idx], shuffled_state.coords + ) + group_log_probs = shuffled_state.log_probs[start_idx:end_idx] + group_blobs = ( + None + if shuffled_state.blobs is None + else jax.tree_util.tree_map( + lambda x: x[start_idx:end_idx], shuffled_state.blobs + ) + ) + groups.append(EnsembleState(group_coords, group_log_probs, group_blobs)) + + updated_groups = list(groups) + accepted_groups = [] + + keys = jax.random.split(key_update, nsplits) + for i in range(nsplits): + other_indices = [j for j in range(nsplits) if j != i] + comp_coords_list = [updated_groups[j].coords for j in other_indices] + comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] + comp_blobs_list = [updated_groups[j].blobs for j in other_indices] + + complementary_coords = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list + ) + complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) + + if state.blobs is not None: + complementary_blobs = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list + ) + else: + complementary_blobs = None + + complementary = EnsembleState( + complementary_coords, complementary_log_probs, complementary_blobs + ) + + updated_group, accepted = _update_half( + keys[i], groups[i], complementary, logdensity_fn, move_fn + ) + updated_groups[i] = updated_group + accepted_groups.append(accepted) + + shuffled_coords = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.coords for g in updated_groups], + ) + shuffled_log_probs = jnp.concatenate( + [g.log_probs for g in updated_groups], axis=0 + ) + shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) + + if state.blobs is not None: + shuffled_blobs = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.blobs for g in updated_groups], + ) + else: + shuffled_blobs = None + + if randomize_split: + inverse_indices = jnp.argsort(indices) + new_coords = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_coords + ) + new_log_probs = shuffled_log_probs[inverse_indices] + accepted = shuffled_accepted[inverse_indices] + if shuffled_blobs is not None: + new_blobs = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_blobs + ) + else: + new_blobs = None + else: + new_coords = shuffled_coords + new_log_probs = shuffled_log_probs + accepted = shuffled_accepted + new_blobs = shuffled_blobs + + new_state = EnsembleState(new_coords, new_log_probs, new_blobs) + acceptance_rate = jnp.mean(accepted) + info = EnsembleInfo(acceptance_rate, accepted) + + return new_state, info + + return kernel + + +def _masked_select(mask, new_val, old_val): + """Helper to broadcast mask to match array rank for jnp.where. + + Parameters + ---------- + mask + Boolean mask with shape (n_walkers,) + new_val + New values to select when mask is True + old_val + Old values to select when mask is False Returns ------- - A `SamplingAlgorithm` that can be used to sample from the target distribution. + Array with same shape as new_val/old_val, with values selected per mask """ - move = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_api( - logdensity_fn, - move, - has_blobs, - randomize_split=randomize_split, - nsplits=nsplits, + expand_dims = (1,) * (new_val.ndim - 1) + mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) + return jnp.where(mask_expanded, new_val, old_val) + + +def _update_half( + rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn +): + """Helper to update one half of the ensemble.""" + n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + + key_moves, key_accept = jax.random.split(rng_key) + keys = jax.random.split(key_moves, n_update) + + proposals, log_hastings_ratios = jax.vmap( + lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) + )(keys, walkers_to_update.coords) + + logdensity_outputs = jax.vmap(logdensity_fn)(proposals) + if isinstance(logdensity_outputs, tuple): + log_probs_proposal, blobs_proposal = logdensity_outputs + else: + log_probs_proposal = logdensity_outputs + blobs_proposal = None + + log_p_accept = ( + log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs + ) + + is_curr_fin = jnp.isfinite(walkers_to_update.log_probs) + is_prop_fin = jnp.isfinite(log_probs_proposal) + log_p_accept = jnp.where( + ~is_curr_fin & is_prop_fin, + jnp.inf, + jnp.where( + is_curr_fin & ~is_prop_fin, + -jnp.inf, + jnp.where(~is_curr_fin & ~is_prop_fin, -jnp.inf, log_p_accept), + ), ) + u = jax.random.uniform(key_accept, shape=(n_update,)) + accepted = jnp.log(u) < log_p_accept -def init(position, logdensity_fn, has_blobs: bool = False): + new_coords = jax.tree_util.tree_map( + lambda prop, old: _masked_select(accepted, prop, old), + proposals, + walkers_to_update.coords, + ) + new_log_probs = jnp.where(accepted, log_probs_proposal, walkers_to_update.log_probs) + + if walkers_to_update.blobs is not None: + new_blobs = jax.tree_util.tree_map( + lambda prop, old: _masked_select(accepted, prop, old), + blobs_proposal, + walkers_to_update.blobs, + ) + else: + new_blobs = None + + new_walkers = EnsembleState(new_coords, new_log_probs, new_blobs) + return new_walkers, accepted + + +def init( + position: ArrayLikeTree, + logdensity_fn: Callable, + has_blobs: bool = False, +) -> EnsembleState: """Initialize the stretch move algorithm. Parameters @@ -73,28 +355,50 @@ def init(position, logdensity_fn, has_blobs: bool = False): has_blobs Whether the log-density function returns additional metadata (blobs). """ - return ensemble_init(position, logdensity_fn, has_blobs) + logdensity_outputs = jax.vmap(logdensity_fn)(position) + if isinstance(logdensity_outputs, tuple): + log_probs, blobs = logdensity_outputs + return EnsembleState(position, log_probs, blobs) + else: + log_probs = logdensity_outputs + return EnsembleState(position, log_probs, None) -def build_kernel( - move_fn=None, a: float = 2.0, randomize_split: bool = True, nsplits: int = 2 -): - """Build the stretch move kernel. +def as_top_level_api( + logdensity_fn: Callable, + a: float = 2.0, + has_blobs: bool = False, + randomize_split: bool = True, + nsplits: int = 2, +) -> SamplingAlgorithm: + """A user-facing API for the stretch move algorithm. Parameters ---------- - move_fn - Optional custom move function. If None, uses stretch_move with parameter a. + logdensity_fn + A function that returns the log density of the model at a given position. a The stretch parameter. Must be > 1. Default is 2.0. + has_blobs + Whether the logdensity function returns additional information (blobs). randomize_split If True, randomly shuffle walker indices before splitting into groups each iteration. This improves mixing and matches emcee's default behavior. nsplits Number of groups to split the ensemble into. Default is 2 (red-blue). + Each group is updated sequentially using all other groups as complementary. + + Returns + ------- + A `SamplingAlgorithm` that can be used to sample from the target distribution. """ - if move_fn is None: - move_fn = lambda key, w, c: stretch_move(key, w, c, a) - return ensemble_build_kernel( - move_fn, randomize_split=randomize_split, nsplits=nsplits - ) + move_fn = lambda key, w, c: stretch_move(key, w, c, a) + kernel = build_kernel(move_fn, randomize_split=randomize_split, nsplits=nsplits) + + def init_fn(position: ArrayTree, rng_key=None): + return init(position, logdensity_fn, has_blobs) + + def step_fn(rng_key: PRNGKey, state) -> tuple[EnsembleState, EnsembleInfo]: + return kernel(rng_key, state, logdensity_fn) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/tests/mcmc/test_ensemble.py b/tests/mcmc/test_ensemble.py index 17464b180..c24abe3b1 100644 --- a/tests/mcmc/test_ensemble.py +++ b/tests/mcmc/test_ensemble.py @@ -1,17 +1,13 @@ """Test the ensemble MCMC kernels.""" -import functools - import chex import jax import jax.numpy as jnp import jax.scipy.stats as stats -import numpy as np -from absl.testing import absltest, parameterized +from absl.testing import absltest import blackjax -from blackjax.mcmc.ensemble import EnsembleState, stretch_move -from blackjax.util import run_inference_algorithm +from blackjax.mcmc.stretch import EnsembleState, stretch_move class EnsembleTest(chex.TestCase): @@ -20,19 +16,19 @@ class EnsembleTest(chex.TestCase): def test_stretch_move(self): """Test that stretch_move produces valid proposals.""" rng_key = jax.random.PRNGKey(0) - + # Simple 2D case walker_coords = jnp.array([1.0, 2.0]) complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]) - + proposal, log_hastings_ratio = stretch_move( rng_key, walker_coords, complementary_coords, a=2.0 ) - + # Check shapes self.assertEqual(proposal.shape, walker_coords.shape) self.assertEqual(log_hastings_ratio.shape, ()) - + # Check that proposal is finite self.assertTrue(jnp.isfinite(proposal).all()) self.assertTrue(jnp.isfinite(log_hastings_ratio)) @@ -40,18 +36,18 @@ def test_stretch_move(self): def test_stretch_move_pytree(self): """Test that stretch_move works with PyTree structures.""" rng_key = jax.random.PRNGKey(0) - + # PyTree case walker_coords = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)} complementary_coords = { "a": jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]), - "b": jnp.array([1.0, 2.0, 3.0]) + "b": jnp.array([1.0, 2.0, 3.0]), } - + proposal, log_hastings_ratio = stretch_move( rng_key, walker_coords, complementary_coords, a=2.0 ) - + # Check structure self.assertEqual(set(proposal.keys()), {"a", "b"}) self.assertEqual(proposal["a"].shape, walker_coords["a"].shape) @@ -60,82 +56,82 @@ def test_stretch_move_pytree(self): def test_stretch_algorithm_2d_gaussian(self): """Test the stretch algorithm on a 2D Gaussian distribution.""" - + # Define a 2D Gaussian target mu = jnp.array([1.0, 2.0]) cov = jnp.array([[1.0, 0.5], [0.5, 2.0]]) - + def logdensity_fn(x): return stats.multivariate_normal.logpdf(x, mu, cov) - + # Initialize ensemble of 20 walkers rng_key = jax.random.PRNGKey(42) init_key, sample_key = jax.random.split(rng_key) - + n_walkers = 20 initial_position = jax.random.normal(init_key, (n_walkers, 2)) - + # Create algorithm algorithm = blackjax.stretch(logdensity_fn, a=2.0) initial_state = algorithm.init(initial_position) - + # Run a few steps def run_step(state, key): new_state, info = algorithm.step(key, state) return new_state, (new_state, info) - + keys = jax.random.split(sample_key, 100) final_state, (states, infos) = jax.lax.scan(run_step, initial_state, keys) - + # Check that we get valid states self.assertIsInstance(final_state, EnsembleState) self.assertEqual(final_state.coords.shape, (n_walkers, 2)) self.assertEqual(final_state.log_probs.shape, (n_walkers,)) - + # Check that acceptance rate is reasonable mean_acceptance = jnp.mean(infos.acceptance_rate) self.assertGreater(mean_acceptance, 0.1) # Should accept some proposals - self.assertLess(mean_acceptance, 0.9) # Should reject some proposals + self.assertLess(mean_acceptance, 0.9) # Should reject some proposals def test_stretch_algorithm_convergence(self): """Test that the stretch algorithm converges to the correct distribution.""" - + # Simple 1D Gaussian mu = 2.0 sigma = 1.5 - + def logdensity_fn(x): return stats.norm.logpdf(x.squeeze(), mu, sigma) - + rng_key = jax.random.PRNGKey(123) init_key, sample_key = jax.random.split(rng_key) - + n_walkers = 50 initial_position = jax.random.normal(init_key, (n_walkers, 1)) - + # Run algorithm algorithm = blackjax.stretch(logdensity_fn, a=2.0) initial_state = algorithm.init(initial_position) - + def run_step(state, key): new_state, info = algorithm.step(key, state) return new_state, new_state.coords - + keys = jax.random.split(sample_key, 1000) final_state, samples = jax.lax.scan(run_step, initial_state, keys) - + # Take samples from the second half (burn-in) samples = samples[500:] # Shape: (500, n_walkers, 1) samples = samples.reshape(-1, 1) # Flatten to (500 * n_walkers, 1) - + # Check convergence sample_mean = jnp.mean(samples) sample_std = jnp.std(samples) - + # Allow for some tolerance due to finite sampling self.assertAlmostEqual(sample_mean.item(), mu, places=1) self.assertAlmostEqual(sample_std.item(), sigma, places=1) if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From d54b803f8923f02afce4eaebbb4711c385e372a6 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 22:09:26 +0100 Subject: [PATCH 07/14] Rename stretch.py back to ensemble.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - File now named after sampling approach (ensemble), not move (stretch) - Matches class names (EnsembleState, EnsembleInfo) - Public API remains blackjax.stretch() via alias in mcmc/__init__.py - Updated test imports to use blackjax.mcmc.ensemble Rationale: "Ensemble" describes the fundamental multi-walker approach, while "stretch" is just one possible proposal move. This naming better reflects the algorithm's core concept and allows for future moves. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/mcmc/__init__.py | 4 +++- blackjax/mcmc/{stretch.py => ensemble.py} | 0 tests/mcmc/test_ensemble.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) rename blackjax/mcmc/{stretch.py => ensemble.py} (100%) diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 337968552..c106c27fe 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -3,6 +3,7 @@ adjusted_mclmc_dynamic, barker, elliptical_slice, + ensemble, ghmc, hmc, mala, @@ -12,9 +13,10 @@ periodic_orbital, random_walk, rmhmc, - stretch, ) +stretch = ensemble + __all__ = [ "barker", "elliptical_slice", diff --git a/blackjax/mcmc/stretch.py b/blackjax/mcmc/ensemble.py similarity index 100% rename from blackjax/mcmc/stretch.py rename to blackjax/mcmc/ensemble.py diff --git a/tests/mcmc/test_ensemble.py b/tests/mcmc/test_ensemble.py index c24abe3b1..4736e6a35 100644 --- a/tests/mcmc/test_ensemble.py +++ b/tests/mcmc/test_ensemble.py @@ -7,7 +7,7 @@ from absl.testing import absltest import blackjax -from blackjax.mcmc.stretch import EnsembleState, stretch_move +from blackjax.mcmc.ensemble import EnsembleState, stretch_move class EnsembleTest(chex.TestCase): From 4503fede8a2c4755fef522caf5decbeed816fd1b Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 22:16:09 +0100 Subject: [PATCH 08/14] Remove stretch alias and expose ensemble as public API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Public API is now blackjax.ensemble() instead of blackjax.stretch() - Removed stretch alias from blackjax/mcmc/__init__.py - Updated blackjax/__init__.py to import and expose ensemble - Updated tests to use blackjax.ensemble() Rationale: Consistent naming throughout - the module is ensemble.py, the classes are EnsembleState/Info, and now the public API is blackjax.ensemble(). "Ensemble" describes the multi-walker sampling approach, while "stretch" is just the proposal mechanism. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- blackjax/__init__.py | 4 ++-- blackjax/mcmc/__init__.py | 4 +--- tests/mcmc/test_ensemble.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index fb4450b24..b31b115a8 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -17,6 +17,7 @@ from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice +from .mcmc import ensemble as _ensemble from .mcmc import ghmc as _ghmc from .mcmc import hmc as _hmc from .mcmc import mala as _mala @@ -25,7 +26,6 @@ from .mcmc import nuts as _nuts from .mcmc import periodic_orbital, random_walk from .mcmc import rmhmc as _rmhmc -from .mcmc import stretch as _stretch from .mcmc.random_walk import additive_step_random_walk as _additive_step_random_walk from .mcmc.random_walk import ( irmh_as_top_level_api, @@ -120,7 +120,7 @@ def generate_top_level_api_from(module): elliptical_slice = generate_top_level_api_from(_elliptical_slice) ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) -stretch = generate_top_level_api_from(_stretch) +ensemble = generate_top_level_api_from(_ensemble) hmc_family = [hmc, nuts] diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index c106c27fe..72e818a84 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -15,11 +15,10 @@ rmhmc, ) -stretch = ensemble - __all__ = [ "barker", "elliptical_slice", + "ensemble", "ghmc", "hmc", "rmhmc", @@ -31,5 +30,4 @@ "mclmc", "adjusted_mclmc_dynamic", "adjusted_mclmc", - "stretch", ] diff --git a/tests/mcmc/test_ensemble.py b/tests/mcmc/test_ensemble.py index 4736e6a35..1d6a118a2 100644 --- a/tests/mcmc/test_ensemble.py +++ b/tests/mcmc/test_ensemble.py @@ -72,7 +72,7 @@ def logdensity_fn(x): initial_position = jax.random.normal(init_key, (n_walkers, 2)) # Create algorithm - algorithm = blackjax.stretch(logdensity_fn, a=2.0) + algorithm = blackjax.ensemble(logdensity_fn, a=2.0) initial_state = algorithm.init(initial_position) # Run a few steps @@ -110,7 +110,7 @@ def logdensity_fn(x): initial_position = jax.random.normal(init_key, (n_walkers, 1)) # Run algorithm - algorithm = blackjax.stretch(logdensity_fn, a=2.0) + algorithm = blackjax.ensemble(logdensity_fn, a=2.0) initial_state = algorithm.init(initial_position) def run_step(state, key): From 86b10f734de84d93d754a3e9ec578b809491eaf6 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 22:50:27 +0100 Subject: [PATCH 09/14] Add Ensemble Slice Sampling (ESS) implementation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the zeus Ensemble Slice Sampling algorithm in BlackJAX's ensemble module, following the architecture established in the refactoring. Key features: - Slice sampling with stepping-out and shrinking procedures - DifferentialMove and RandomMove direction generators - Robbins-Monro adaptive tuning of scale parameter mu - No rejections (acceptance_rate = 1.0) - Full JAX compatibility with jit/vmap support - PyTree support for arbitrary position structures Module structure: - blackjax/ensemble/base.py: Shared EnsembleState/Info - blackjax/ensemble/stretch.py: Stretch move (Metropolis-based) - blackjax/ensemble/slice.py: Slice sampling (new) Tests reorganized: - tests/ensemble/test_stretch.py (4 tests) - tests/ensemble/test_slice.py (8 tests) Public API: blackjax.ensemble_slice(logdensity_fn, move="differential", ...) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/__init__.py | 4 +- blackjax/ensemble/__init__.py | 18 + blackjax/ensemble/base.py | 52 ++ blackjax/ensemble/slice.py | 732 ++++++++++++++++++ .../{mcmc/ensemble.py => ensemble/stretch.py} | 42 +- blackjax/mcmc/__init__.py | 2 - tests/ensemble/test_slice.py | 239 ++++++ .../test_stretch.py} | 3 +- 8 files changed, 1050 insertions(+), 42 deletions(-) create mode 100644 blackjax/ensemble/__init__.py create mode 100644 blackjax/ensemble/base.py create mode 100644 blackjax/ensemble/slice.py rename blackjax/{mcmc/ensemble.py => ensemble/stretch.py} (92%) create mode 100644 tests/ensemble/test_slice.py rename tests/{mcmc/test_ensemble.py => ensemble/test_stretch.py} (97%) diff --git a/blackjax/__init__.py b/blackjax/__init__.py index b31b115a8..b9043df81 100644 --- a/blackjax/__init__.py +++ b/blackjax/__init__.py @@ -12,12 +12,13 @@ from .base import SamplingAlgorithm, VIAlgorithm from .diagnostics import effective_sample_size as ess from .diagnostics import potential_scale_reduction as rhat +from .ensemble import slice as _ensemble_slice +from .ensemble import stretch as _ensemble from .mcmc import adjusted_mclmc as _adjusted_mclmc from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic from .mcmc import barker from .mcmc import dynamic_hmc as _dynamic_hmc from .mcmc import elliptical_slice as _elliptical_slice -from .mcmc import ensemble as _ensemble from .mcmc import ghmc as _ghmc from .mcmc import hmc as _hmc from .mcmc import mala as _mala @@ -121,6 +122,7 @@ def generate_top_level_api_from(module): ghmc = generate_top_level_api_from(_ghmc) barker_proposal = generate_top_level_api_from(barker) ensemble = generate_top_level_api_from(_ensemble) +ensemble_slice = generate_top_level_api_from(_ensemble_slice) hmc_family = [hmc, nuts] diff --git a/blackjax/ensemble/__init__.py b/blackjax/ensemble/__init__.py new file mode 100644 index 000000000..89d9d9360 --- /dev/null +++ b/blackjax/ensemble/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ensemble sampling methods for BlackJAX.""" +from blackjax.ensemble import slice, stretch +from blackjax.ensemble.base import EnsembleInfo, EnsembleState + +__all__ = ["EnsembleState", "EnsembleInfo", "stretch", "slice"] diff --git a/blackjax/ensemble/base.py b/blackjax/ensemble/base.py new file mode 100644 index 000000000..f1156cd53 --- /dev/null +++ b/blackjax/ensemble/base.py @@ -0,0 +1,52 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Base classes and utilities for ensemble sampling methods.""" +from typing import NamedTuple, Optional + +from blackjax.types import Array, ArrayTree + +__all__ = ["EnsembleState", "EnsembleInfo"] + + +class EnsembleState(NamedTuple): + """State of an ensemble sampler. + + coords + An array or PyTree of arrays of shape `(n_walkers, ...)` that + stores the current position of the walkers. + log_probs + An array of shape `(n_walkers,)` that stores the log-probability of + each walker. + blobs + An optional PyTree that stores metadata returned by the log-probability + function. + """ + + coords: ArrayTree + log_probs: Array + blobs: Optional[ArrayTree] = None + + +class EnsembleInfo(NamedTuple): + """Additional information on the ensemble transition. + + acceptance_rate + The acceptance rate of the ensemble. + is_accepted + A boolean array of shape `(n_walkers,)` indicating whether each walker's + proposal was accepted. + """ + + acceptance_rate: float + is_accepted: Array diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py new file mode 100644 index 000000000..b4b6f7a3c --- /dev/null +++ b/blackjax/ensemble/slice.py @@ -0,0 +1,732 @@ +# Copyright 2020- The Blackjax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Ensemble Slice Sampling (ESS) implementation.""" +from typing import Callable, NamedTuple, Optional + +import jax +import jax.numpy as jnp + +from blackjax.base import SamplingAlgorithm +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey + +__all__ = [ + "SliceEnsembleState", + "SliceEnsembleInfo", + "init", + "build_kernel", + "as_top_level_api", +] + + +class SliceEnsembleState(NamedTuple): + """State of the ensemble slice sampler. + + coords + An array or PyTree of arrays of shape `(n_walkers, ...)` that + stores the current position of the walkers. + log_probs + An array of shape `(n_walkers,)` that stores the log-probability of + each walker. + blobs + An optional PyTree that stores metadata returned by the log-probability + function. + mu + The current scale parameter for the slice sampling directions. + tuning_active + Whether adaptive tuning of mu is currently active. + patience_count + Counter for determining when to stop tuning. + """ + + coords: ArrayTree + log_probs: jnp.ndarray + blobs: Optional[ArrayTree] = None + mu: float = 1.0 + tuning_active: bool = True + patience_count: int = 0 + + +class SliceEnsembleInfo(NamedTuple): + """Additional information on the ensemble slice sampling transition. + + acceptance_rate + Always 1.0 for slice sampling (no rejections). + is_accepted + A boolean array of shape `(n_walkers,)` - always True for slice sampling. + expansions + Total number of slice expansions performed. + contractions + Total number of slice contractions performed. + nevals + Total number of log-density evaluations performed. + mu + The current value of the scale parameter mu. + """ + + acceptance_rate: float + is_accepted: jnp.ndarray + expansions: int + contractions: int + nevals: int + mu: float + + +def differential_direction( + rng_key: PRNGKey, + complementary_coords: ArrayTree, + n_update: int, + mu: float, +) -> tuple[ArrayTree, bool]: + """Generate direction vectors using the differential move. + + Directions are defined by the difference between two randomly selected + walkers from the complementary ensemble, scaled by 2*mu. + + Parameters + ---------- + rng_key + A PRNG key for random number generation. + complementary_coords + Coordinates of the complementary ensemble with shape (n_walkers, ...). + n_update + Number of walkers to update (number of directions needed). + mu + The scale parameter. + + Returns + ------- + A tuple (directions, tune_once) where directions is a PyTree matching + the structure of complementary_coords with leading dimension n_update, + and tune_once is True (indicating mu should be tuned). + """ + # Get the number of walkers in complementary ensemble + comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords) + n_comp = comp_leaves[0].shape[0] + + # Sample two different indices for each update walker + key1, key2 = jax.random.split(rng_key) + idx1 = jax.random.randint(key1, (n_update,), 0, n_comp) + j = jax.random.randint(key2, (n_update,), 0, n_comp - 1) + idx2 = j + (j >= idx1) # Ensure idx2 != idx1 + + # Compute directions as difference of pairs + walker1 = jax.tree_util.tree_map(lambda x: x[idx1], complementary_coords) + walker2 = jax.tree_util.tree_map(lambda x: x[idx2], complementary_coords) + + directions = jax.tree_util.tree_map( + lambda w1, w2: 2.0 * mu * (w1 - w2), walker1, walker2 + ) + + return directions, True + + +def random_direction( + rng_key: PRNGKey, + template_coords: ArrayTree, + n_update: int, + mu: float, +) -> tuple[ArrayTree, bool]: + """Generate random isotropic direction vectors. + + Directions are sampled from a standard normal distribution and scaled by 2*mu. + This corresponds to standard multivariate slice sampling without using + ensemble information. + + Parameters + ---------- + rng_key + A PRNG key for random number generation. + template_coords + Template coordinates to match structure and shape (n_update, ...). + n_update + Number of walkers to update (number of directions needed). + mu + The scale parameter. + + Returns + ------- + A tuple (directions, tune_once) where directions is a PyTree matching + the structure of template_coords, and tune_once is True. + """ + + def sample_leaf(shape): + return jax.random.normal(rng_key, shape) + + # Generate random directions with same structure as template + directions = jax.tree_util.tree_map( + lambda x: 2.0 * mu * sample_leaf(x.shape), template_coords + ) + + return directions, True + + +def slice_along_direction( + rng_key: PRNGKey, + x0: ArrayTree, + logp0: float, + direction: ArrayTree, + logprob_fn: Callable, + maxsteps: int = 10000, + maxiter: int = 10000, +) -> tuple[ArrayTree, float, int, int, int]: + """Perform slice sampling along a given direction. + + Implements the stepping-out and shrinking procedures for 1D slice sampling + along the specified direction vector. + + Parameters + ---------- + rng_key + A PRNG key for random number generation. + x0 + Current position (PyTree). + logp0 + Log-probability at current position. + direction + Direction vector (PyTree with same structure as x0). + logprob_fn + Function that computes log-probability given a position. + maxsteps + Maximum number of steps for stepping-out procedure. + maxiter + Maximum total iterations to prevent infinite loops. + + Returns + ------- + A tuple (x1, logp1, nexp, ncon, neval) where: + x1: New position (PyTree) + logp1: Log-probability at new position + nexp: Number of expansions performed + ncon: Number of contractions performed + neval: Number of log-probability evaluations + """ + key_z0, key_lr, key_j, key_shrink = jax.random.split(rng_key, 4) + + # Draw slice height: Z0 = logp0 - Exponential(1) + z0 = logp0 - jax.random.exponential(key_z0) + + # Initialize interval [L, R] around 0 + l_init = -jax.random.uniform(key_lr) + r_init = l_init + 1.0 + + # Random allocation of expansion steps + j = jax.random.randint(key_j, (), 0, maxsteps) + k = maxsteps - 1 - j + + # Helper function to evaluate log-prob at x0 + t*direction + def eval_at_t(t): + xt = jax.tree_util.tree_map(lambda x, d: x + t * d, x0, direction) + return logprob_fn(xt) + + # Stepping-out: expand left + def left_expand_cond(carry): + l, j_left, nexp, neval, iter_count = carry + logp_l = eval_at_t(l) + return (j_left > 0) & (iter_count < maxiter) & (logp_l > z0) + + def left_expand_body(carry): + l, j_left, nexp, neval, iter_count = carry + return l - 1.0, j_left - 1, nexp + 1, neval + 1, iter_count + 1 + + l_final, _, nexp_left, neval_left, _ = jax.lax.while_loop( + left_expand_cond, left_expand_body, (l_init, j, 0, 0, 0) + ) + + # Stepping-out: expand right + def right_expand_cond(carry): + r, k_right, nexp, neval, iter_count = carry + logp_r = eval_at_t(r) + return (k_right > 0) & (iter_count < maxiter) & (logp_r > z0) + + def right_expand_body(carry): + r, k_right, nexp, neval, iter_count = carry + return r + 1.0, k_right - 1, nexp + 1, neval + 1, iter_count + 1 + + r_final, _, nexp_right, neval_right, _ = jax.lax.while_loop( + right_expand_cond, right_expand_body, (r_init, k, 0, 0, 0) + ) + + nexp_total = nexp_left + nexp_right + neval_after_expand = neval_left + neval_right + + # Shrinking: sample uniformly from [L, R] until inside slice + def shrink_cond(carry): + _, _, _, _, _, iter_count, accepted, _, _ = carry + return (~accepted) & (iter_count < maxiter) + + def shrink_body(carry): + key, l, r, neval, ncon, iter_count, accepted, t_acc, logp_acc = carry + key, key_t = jax.random.split(key) + t = jax.random.uniform(key_t, minval=l, maxval=r) + logp_t = eval_at_t(t) + neval_new = neval + 1 + + # Check if inside slice + inside_slice = logp_t >= z0 + accepted_new = accepted | inside_slice + + # Update interval or accept + l_new = jnp.where(inside_slice, l, jnp.where(t < 0, t, l)) + r_new = jnp.where(inside_slice, r, jnp.where(t >= 0, t, r)) + ncon_new = jnp.where(inside_slice, ncon, ncon + 1) + t_acc_new = jnp.where(inside_slice & ~accepted, t, t_acc) + logp_acc_new = jnp.where(inside_slice & ~accepted, logp_t, logp_acc) + + return ( + key, + l_new, + r_new, + neval_new, + ncon_new, + iter_count + 1, + accepted_new, + t_acc_new, + logp_acc_new, + ) + + _, _, _, neval_shrink, ncon_total, _, _, t_final, logp_final = jax.lax.while_loop( + shrink_cond, + shrink_body, + (key_shrink, l_final, r_final, 0, 0, 0, False, 0.0, logp0), + ) + + # Compute final position + x1 = jax.tree_util.tree_map(lambda x, d: x + t_final * d, x0, direction) + + neval_total = neval_after_expand + neval_shrink + + return x1, logp_final, nexp_total, ncon_total, neval_total + + +def init( + position: ArrayLikeTree, + logdensity_fn: Callable, + has_blobs: bool = False, + mu: float = 1.0, +) -> SliceEnsembleState: + """Initialize the ensemble slice sampling algorithm. + + Parameters + ---------- + position + Initial positions for all walkers, with shape (n_walkers, ...). + logdensity_fn + The log-density function to evaluate. + has_blobs + Whether the log-density function returns additional metadata (blobs). + mu + Initial value of the scale parameter. + + Returns + ------- + Initial SliceEnsembleState. + """ + logdensity_outputs = jax.vmap(logdensity_fn)(position) + if isinstance(logdensity_outputs, tuple): + log_probs, blobs = logdensity_outputs + return SliceEnsembleState(position, log_probs, blobs, mu, True, 0) + else: + log_probs = logdensity_outputs + return SliceEnsembleState(position, log_probs, None, mu, True, 0) + + +def build_kernel( + move: str = "differential", + move_fn: Optional[Callable] = None, + randomize_split: bool = True, + nsplits: int = 2, + maxsteps: int = 10000, + maxiter: int = 10000, + tune: bool = True, + patience: int = 5, + tolerance: float = 0.05, +) -> Callable: + """Build the ensemble slice sampling kernel. + + Parameters + ---------- + move + Type of move to use: "differential" or "random". Ignored if move_fn provided. + move_fn + Optional custom move function. If None, uses the specified move type. + randomize_split + If True, randomly shuffle walker indices before splitting into groups. + nsplits + Number of groups to split the ensemble into. Default is 2. + maxsteps + Maximum steps for slice stepping-out procedure. + maxiter + Maximum iterations to prevent infinite loops. + tune + Whether to enable adaptive tuning of mu. + patience + Number of steps within tolerance before stopping tuning. + tolerance + Tolerance for expansion/contraction ratio to stop tuning. + + Returns + ------- + A kernel function that performs one step of ensemble slice sampling. + """ + # Select move function + if move_fn is None: + if move == "differential": + move_fn = differential_direction + elif move == "random": + move_fn = random_direction + else: + raise ValueError(f"Unknown move type: {move}") + + # At this point move_fn is guaranteed to be Callable + selected_move_fn: Callable = move_fn + + def kernel( + rng_key: PRNGKey, state: SliceEnsembleState, logdensity_fn: Callable + ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]: + n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape + + # Shuffle walkers if requested + if randomize_split: + key_shuffle, key_update = jax.random.split(rng_key) + indices = jax.random.permutation(key_shuffle, n_walkers) + shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) + shuffled_log_probs = state.log_probs[indices] + shuffled_blobs = ( + None + if state.blobs is None + else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) + ) + shuffled_state = SliceEnsembleState( + shuffled_coords, + shuffled_log_probs, + shuffled_blobs, + state.mu, + state.tuning_active, + state.patience_count, + ) + else: + key_update = rng_key + shuffled_state = state + indices = jnp.arange(n_walkers) + + # Split into groups + group_size = n_walkers // nsplits + groups = [] + for i in range(nsplits): + start_idx = i * group_size + end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers + + group_coords = jax.tree_util.tree_map( + lambda x: x[start_idx:end_idx], shuffled_state.coords + ) + group_log_probs = shuffled_state.log_probs[start_idx:end_idx] + group_blobs = ( + None + if shuffled_state.blobs is None + else jax.tree_util.tree_map( + lambda x: x[start_idx:end_idx], shuffled_state.blobs + ) + ) + groups.append( + SliceEnsembleState( + group_coords, + group_log_probs, + group_blobs, + state.mu, + state.tuning_active, + state.patience_count, + ) + ) + + # Update each group sequentially + updated_groups = list(groups) + total_nexp = 0 + total_ncon = 0 + total_neval = 0 + + keys = jax.random.split(key_update, nsplits) + for i in range(nsplits): + # Build complementary ensemble from other groups + other_indices = [j for j in range(nsplits) if j != i] + comp_coords_list = [updated_groups[j].coords for j in other_indices] + comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] + comp_blobs_list = [updated_groups[j].blobs for j in other_indices] + + complementary_coords = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list + ) + complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) + + if state.blobs is not None: + complementary_blobs = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list + ) + else: + complementary_blobs = None + + complementary = SliceEnsembleState( + complementary_coords, + complementary_log_probs, + complementary_blobs, + state.mu, + state.tuning_active, + state.patience_count, + ) + + # Update this group + updated_group, nexp, ncon, neval = _update_half_slice( + keys[i], + groups[i], + complementary, + logdensity_fn, + selected_move_fn, + maxsteps, + maxiter, + ) + updated_groups[i] = updated_group + total_nexp += nexp + total_ncon += ncon + total_neval += neval + + # Concatenate updated groups + shuffled_coords = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.coords for g in updated_groups], + ) + shuffled_log_probs = jnp.concatenate( + [g.log_probs for g in updated_groups], axis=0 + ) + + if state.blobs is not None: + shuffled_blobs = jax.tree_util.tree_map( + lambda *arrays: jnp.concatenate(arrays, axis=0), + *[g.blobs for g in updated_groups], + ) + else: + shuffled_blobs = None + + # Unshuffle if needed + if randomize_split: + inverse_indices = jnp.argsort(indices) + new_coords = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_coords + ) + new_log_probs = shuffled_log_probs[inverse_indices] + if shuffled_blobs is not None: + new_blobs = jax.tree_util.tree_map( + lambda x: x[inverse_indices], shuffled_blobs + ) + else: + new_blobs = None + else: + new_coords = shuffled_coords + new_log_probs = shuffled_log_probs + new_blobs = shuffled_blobs + + # Adaptive tuning of mu + should_tune = tune & state.tuning_active + + nexp_eff = jnp.maximum(total_nexp, 1) + mu_tuned = state.mu * 2.0 * nexp_eff / (nexp_eff + total_ncon) + + # Check convergence of tuning + exp_ratio = total_nexp / jnp.maximum(total_nexp + total_ncon, 1) + within_tolerance = jnp.abs(exp_ratio - 0.5) < tolerance + + patience_count_updated = jnp.where( + within_tolerance, state.patience_count + 1, 0 + ) + tuning_active_updated = patience_count_updated < patience + + # Apply tuning updates conditionally + mu_new = jnp.where(should_tune, mu_tuned, state.mu) + patience_count_new = jnp.where( + should_tune, patience_count_updated, state.patience_count + ) + tuning_active_new = jnp.where( + should_tune, tuning_active_updated, state.tuning_active + ) + + new_state = SliceEnsembleState( + new_coords, + new_log_probs, + new_blobs, + mu_new, + tuning_active_new, + patience_count_new, + ) + + # Build info (acceptance always 1.0 for slice sampling) + info = SliceEnsembleInfo( + acceptance_rate=1.0, + is_accepted=jnp.ones(n_walkers, dtype=bool), + expansions=total_nexp, + contractions=total_ncon, + nevals=total_neval, + mu=mu_new, + ) + + return new_state, info + + return kernel + + +def _update_half_slice( + rng_key: PRNGKey, + walkers_to_update: SliceEnsembleState, + complementary_walkers: SliceEnsembleState, + logdensity_fn: Callable, + direction_fn: Callable, + maxsteps: int, + maxiter: int, +) -> tuple[SliceEnsembleState, int, int, int]: + """Update a group of walkers using ensemble slice sampling. + + Parameters + ---------- + rng_key + PRNG key for random number generation. + walkers_to_update + Group of walkers to update. + complementary_walkers + Complementary ensemble used for generating directions. + logdensity_fn + Log-density function. + direction_fn + Function to generate direction vectors. + maxsteps + Maximum steps for stepping-out. + maxiter + Maximum iterations. + + Returns + ------- + Tuple of (updated_group_state, total_expansions, total_contractions, total_evals). + """ + n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + + # Generate directions + key_dir, key_slice = jax.random.split(rng_key) + directions, _ = direction_fn( + key_dir, complementary_walkers.coords, n_update, walkers_to_update.mu + ) + + # Define logprob-only function + def logprob_only(x): + out = logdensity_fn(x) + return out[0] if isinstance(out, tuple) else out + + # Perform slice sampling for each walker + keys = jax.random.split(key_slice, n_update) + + def slice_one_walker(key, x0, logp0, direction): + return slice_along_direction( + key, x0, logp0, direction, logprob_only, maxsteps, maxiter + ) + + results = jax.vmap(slice_one_walker)( + keys, walkers_to_update.coords, walkers_to_update.log_probs, directions + ) + + new_coords, new_log_probs, nexp_array, ncon_array, neval_array = results + + # Sum statistics + total_nexp = jnp.sum(nexp_array) + total_ncon = jnp.sum(ncon_array) + total_neval = jnp.sum(neval_array) + + # Handle blobs if needed + if walkers_to_update.blobs is not None: + # Re-evaluate at new positions to get blobs + logdensity_outputs = jax.vmap(logdensity_fn)(new_coords) + _, new_blobs = logdensity_outputs + else: + new_blobs = None + + updated_state = SliceEnsembleState( + new_coords, + new_log_probs, + new_blobs, + walkers_to_update.mu, + walkers_to_update.tuning_active, + walkers_to_update.patience_count, + ) + + return updated_state, total_nexp, total_ncon, total_neval + + +def as_top_level_api( + logdensity_fn: Callable, + move: str = "differential", + mu: float = 1.0, + has_blobs: bool = False, + randomize_split: bool = True, + nsplits: int = 2, + maxsteps: int = 10000, + maxiter: int = 10000, + tune: bool = True, + patience: int = 5, + tolerance: float = 0.05, +) -> SamplingAlgorithm: + """A user-facing API for the ensemble slice sampling algorithm. + + Parameters + ---------- + logdensity_fn + A function that returns the log density of the model at a given position. + move + Type of move: "differential" or "random". + mu + Initial value of the scale parameter. + has_blobs + Whether the logdensity function returns additional information (blobs). + randomize_split + If True, randomly shuffle walker indices before splitting into groups. + nsplits + Number of groups to split the ensemble into. Default is 2. + maxsteps + Maximum steps for slice stepping-out procedure. + maxiter + Maximum iterations to prevent infinite loops. + tune + Whether to enable adaptive tuning of mu. + patience + Number of steps within tolerance before stopping tuning. + tolerance + Tolerance for expansion/contraction ratio to stop tuning. + + Returns + ------- + A `SamplingAlgorithm` that can be used to sample from the target distribution. + """ + kernel = build_kernel( + move=move, + randomize_split=randomize_split, + nsplits=nsplits, + maxsteps=maxsteps, + maxiter=maxiter, + tune=tune, + patience=patience, + tolerance=tolerance, + ) + + def init_fn(position: ArrayTree, rng_key=None): + return init(position, logdensity_fn, has_blobs, mu) + + def step_fn( + rng_key: PRNGKey, state + ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]: + return kernel(rng_key, state, logdensity_fn) + + return SamplingAlgorithm(init_fn, step_fn) diff --git a/blackjax/mcmc/ensemble.py b/blackjax/ensemble/stretch.py similarity index 92% rename from blackjax/mcmc/ensemble.py rename to blackjax/ensemble/stretch.py index 7afda54cb..062910cec 100644 --- a/blackjax/mcmc/ensemble.py +++ b/blackjax/ensemble/stretch.py @@ -11,19 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Public API for the Stretch Move ensemble sampler.""" -from typing import Callable, NamedTuple, Optional +"""Stretch move ensemble sampler (affine-invariant MCMC).""" +from typing import Callable import jax import jax.numpy as jnp from jax.flatten_util import ravel_pytree from blackjax.base import SamplingAlgorithm -from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey +from blackjax.ensemble.base import EnsembleInfo, EnsembleState +from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ - "EnsembleState", - "EnsembleInfo", "init", "build_kernel", "as_top_level_api", @@ -31,39 +30,6 @@ ] -class EnsembleState(NamedTuple): - """State of an ensemble sampler. - - coords - An array or PyTree of arrays of shape `(n_walkers, ...)` that - stores the current position of the walkers. - log_probs - An array of shape `(n_walkers,)` that stores the log-probability of - each walker. - blobs - An optional PyTree that stores metadata returned by the log-probability - function. - """ - - coords: ArrayTree - log_probs: Array - blobs: Optional[ArrayTree] = None - - -class EnsembleInfo(NamedTuple): - """Additional information on the ensemble transition. - - acceptance_rate - The acceptance rate of the ensemble. - is_accepted - A boolean array of shape `(n_walkers,)` indicating whether each walker's - proposal was accepted. - """ - - acceptance_rate: float - is_accepted: Array - - def stretch_move( rng_key: PRNGKey, walker_coords: ArrayTree, diff --git a/blackjax/mcmc/__init__.py b/blackjax/mcmc/__init__.py index 72e818a84..8acb28274 100644 --- a/blackjax/mcmc/__init__.py +++ b/blackjax/mcmc/__init__.py @@ -3,7 +3,6 @@ adjusted_mclmc_dynamic, barker, elliptical_slice, - ensemble, ghmc, hmc, mala, @@ -18,7 +17,6 @@ __all__ = [ "barker", "elliptical_slice", - "ensemble", "ghmc", "hmc", "rmhmc", diff --git a/tests/ensemble/test_slice.py b/tests/ensemble/test_slice.py new file mode 100644 index 000000000..1739a71ab --- /dev/null +++ b/tests/ensemble/test_slice.py @@ -0,0 +1,239 @@ +"""Test the ensemble slice sampling kernel.""" + +import chex +import jax +import jax.numpy as jnp +import jax.scipy.stats as stats +from absl.testing import absltest + +from blackjax.ensemble.slice import ( + SliceEnsembleInfo, + SliceEnsembleState, + as_top_level_api, + differential_direction, + init, + random_direction, + slice_along_direction, +) + + +class EnsembleSliceTest(chex.TestCase): + """Test the ensemble slice sampling algorithm.""" + + def test_differential_direction(self): + """Test that differential_direction produces valid directions.""" + rng_key = jax.random.PRNGKey(0) + + # Complementary ensemble + complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]) + n_update = 2 + mu = 1.0 + + directions, tune_once = differential_direction( + rng_key, complementary_coords, n_update, mu + ) + + # Check shape and properties + self.assertEqual(directions.shape, (n_update, 2)) + self.assertTrue(tune_once) + self.assertTrue(jnp.isfinite(directions).all()) + + def test_random_direction(self): + """Test that random_direction produces valid directions.""" + rng_key = jax.random.PRNGKey(0) + + # Template coordinates + template_coords = jnp.array([[1.0, 2.0], [3.0, 4.0]]) + n_update = 2 + mu = 1.0 + + directions, tune_once = random_direction(rng_key, template_coords, n_update, mu) + + # Check shape and properties + self.assertEqual(directions.shape, template_coords.shape) + self.assertTrue(tune_once) + self.assertTrue(jnp.isfinite(directions).all()) + + def test_slice_along_direction_1d(self): + """Test slice sampling along a direction in 1D.""" + rng_key = jax.random.PRNGKey(42) + + # Simple 1D Gaussian + def logprob_fn(x): + return stats.norm.logpdf(x, 0.0, 1.0) + + x0 = 0.0 + logp0 = logprob_fn(x0) + direction = 1.0 + + x1, logp1, nexp, ncon, neval = slice_along_direction( + rng_key, x0, logp0, direction, logprob_fn, maxsteps=100, maxiter=1000 + ) + + # Check that we got valid results + self.assertTrue(jnp.isfinite(x1)) + self.assertTrue(jnp.isfinite(logp1)) + self.assertGreater(neval, 0) + + def test_init(self): + """Test initialization of ensemble slice sampler.""" + + def logdensity_fn(x): + return stats.norm.logpdf(x, 0.0, 1.0).sum() + + rng_key = jax.random.PRNGKey(0) + n_walkers = 10 + initial_position = jax.random.normal(rng_key, (n_walkers, 2)) + + state = init(initial_position, logdensity_fn, has_blobs=False, mu=1.0) + + # Check state properties + self.assertIsInstance(state, SliceEnsembleState) + self.assertEqual(state.coords.shape, (n_walkers, 2)) + self.assertEqual(state.log_probs.shape, (n_walkers,)) + self.assertIsNone(state.blobs) + self.assertEqual(state.mu, 1.0) + self.assertTrue(state.tuning_active) + self.assertEqual(state.patience_count, 0) + + def test_ensemble_slice_1d_gaussian(self): + """Test ensemble slice sampling on a 1D Gaussian distribution.""" + + # Define 1D Gaussian target + mu_true = 2.0 + sigma_true = 1.5 + + def logdensity_fn(x): + return stats.norm.logpdf(x.squeeze(), mu_true, sigma_true) + + rng_key = jax.random.PRNGKey(123) + init_key, sample_key = jax.random.split(rng_key) + + # Initialize with 20 walkers + n_walkers = 20 + initial_position = jax.random.normal(init_key, (n_walkers, 1)) + + # Create algorithm + algorithm = as_top_level_api( + logdensity_fn, move="differential", mu=1.0, maxsteps=100, maxiter=1000 + ) + initial_state = algorithm.init(initial_position) + + # Run a few steps + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, (new_state, info) + + keys = jax.random.split(sample_key, 100) + final_state, (states, infos) = jax.lax.scan(run_step, initial_state, keys) + + # Check that we get valid states + self.assertIsInstance(final_state, SliceEnsembleState) + self.assertEqual(final_state.coords.shape, (n_walkers, 1)) + self.assertEqual(final_state.log_probs.shape, (n_walkers,)) + + # Check info + self.assertIsInstance(infos, SliceEnsembleInfo) + # Slice sampling always accepts + self.assertTrue(jnp.all(infos.acceptance_rate == 1.0)) + self.assertTrue(jnp.all(infos.is_accepted)) + + # Check that evaluations are happening + self.assertTrue(jnp.all(infos.nevals > 0)) + + def test_ensemble_slice_2d_gaussian(self): + """Test ensemble slice sampling on a 2D Gaussian distribution.""" + + # Define 2D Gaussian target + mu = jnp.array([1.0, 2.0]) + cov = jnp.array([[1.0, 0.5], [0.5, 2.0]]) + + def logdensity_fn(x): + return stats.multivariate_normal.logpdf(x, mu, cov) + + rng_key = jax.random.PRNGKey(42) + init_key, sample_key = jax.random.split(rng_key) + + # Initialize ensemble + n_walkers = 20 + initial_position = jax.random.normal(init_key, (n_walkers, 2)) + + # Create algorithm + algorithm = as_top_level_api( + logdensity_fn, move="differential", mu=1.0, maxsteps=100, maxiter=1000 + ) + initial_state = algorithm.init(initial_position) + + # Run steps + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, new_state.coords + + keys = jax.random.split(sample_key, 200) + final_state, samples = jax.lax.scan(run_step, initial_state, keys) + + # Take second half as samples (burn-in) + samples = samples[100:] # Shape: (100, n_walkers, 2) + samples = samples.reshape(-1, 2) # Flatten to (100 * n_walkers, 2) + + # Check convergence (loose tolerance for quick test) + sample_mean = jnp.mean(samples, axis=0) + self.assertAlmostEqual(sample_mean[0].item(), mu[0], places=0) + self.assertAlmostEqual(sample_mean[1].item(), mu[1], places=0) + + def test_jit_compilation(self): + """Test that the algorithm can be JIT compiled.""" + + def logdensity_fn(x): + return stats.norm.logpdf(x.squeeze(), 0.0, 1.0) + + rng_key = jax.random.PRNGKey(0) + n_walkers = 10 + initial_position = jax.random.normal(rng_key, (n_walkers, 1)) + + algorithm = as_top_level_api(logdensity_fn, move="differential") + initial_state = algorithm.init(initial_position) + + # JIT compile step function + jitted_step = jax.jit(algorithm.step) + + # Run one step + key = jax.random.PRNGKey(1) + new_state, info = jitted_step(key, initial_state) + + # Check results are valid + self.assertIsInstance(new_state, SliceEnsembleState) + self.assertIsInstance(info, SliceEnsembleInfo) + + def test_random_move(self): + """Test ensemble slice sampling with random move.""" + + def logdensity_fn(x): + return stats.norm.logpdf(x, 0.0, 1.0).sum() + + rng_key = jax.random.PRNGKey(99) + n_walkers = 10 + initial_position = jax.random.normal(rng_key, (n_walkers, 2)) + + # Use random move instead of differential + algorithm = as_top_level_api( + logdensity_fn, move="random", mu=1.0, maxsteps=100, maxiter=1000 + ) + initial_state = algorithm.init(initial_position) + + # Run a few steps + keys = jax.random.split(rng_key, 10) + + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, info + + final_state, infos = jax.lax.scan(run_step, initial_state, keys) + + # Check valid results + self.assertIsInstance(final_state, SliceEnsembleState) + self.assertTrue(jnp.all(infos.acceptance_rate == 1.0)) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/mcmc/test_ensemble.py b/tests/ensemble/test_stretch.py similarity index 97% rename from tests/mcmc/test_ensemble.py rename to tests/ensemble/test_stretch.py index 1d6a118a2..4e8166ca2 100644 --- a/tests/mcmc/test_ensemble.py +++ b/tests/ensemble/test_stretch.py @@ -7,7 +7,8 @@ from absl.testing import absltest import blackjax -from blackjax.mcmc.ensemble import EnsembleState, stretch_move +from blackjax.ensemble.base import EnsembleState +from blackjax.ensemble.stretch import stretch_move class EnsembleTest(chex.TestCase): From 925f4683c726abba6345ce41f97598a6154540de Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 23:17:19 +0100 Subject: [PATCH 10/14] Refine ensemble slice implementation and follow house style MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement all medium-priority improvements from OpenAI review: - Expose move_fn parameter in as_top_level_api for custom moves - Optimize logprob evaluation (carry in while-loop state) - Add gaussian_direction with full PyTree support - Simplify docstrings to match BlackJAX house style - Remove all inline comments and defensive checks - Fix SliceEnsembleInfo docstring accuracy - Add tests for gaussian direction (arrays and PyTrees) - All 11 tests passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/ensemble/slice.py | 270 ++++++++++++++++++++++------------- tests/ensemble/test_slice.py | 71 ++++++++- 2 files changed, 236 insertions(+), 105 deletions(-) diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py index b4b6f7a3c..3d4bb2607 100644 --- a/blackjax/ensemble/slice.py +++ b/blackjax/ensemble/slice.py @@ -1,16 +1,6 @@ -# Copyright 2020- The Blackjax Authors. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """Ensemble Slice Sampling (ESS) implementation.""" from typing import Callable, NamedTuple, Optional @@ -26,6 +16,10 @@ "init", "build_kernel", "as_top_level_api", + "differential_direction", + "random_direction", + "gaussian_direction", + "slice_along_direction", ] @@ -61,9 +55,9 @@ class SliceEnsembleInfo(NamedTuple): """Additional information on the ensemble slice sampling transition. acceptance_rate - Always 1.0 for slice sampling (no rejections). + Fraction of walkers that found valid slice points. is_accepted - A boolean array of shape `(n_walkers,)` - always True for slice sampling. + Boolean array of shape `(n_walkers,)` indicating successful slice updates. expansions Total number of slice expansions performed. contractions @@ -71,7 +65,7 @@ class SliceEnsembleInfo(NamedTuple): nevals Total number of log-density evaluations performed. mu - The current value of the scale parameter mu. + The current value of the scale parameter. """ acceptance_rate: float @@ -110,17 +104,14 @@ def differential_direction( the structure of complementary_coords with leading dimension n_update, and tune_once is True (indicating mu should be tuned). """ - # Get the number of walkers in complementary ensemble comp_leaves, _ = jax.tree_util.tree_flatten(complementary_coords) n_comp = comp_leaves[0].shape[0] - # Sample two different indices for each update walker key1, key2 = jax.random.split(rng_key) idx1 = jax.random.randint(key1, (n_update,), 0, n_comp) j = jax.random.randint(key2, (n_update,), 0, n_comp - 1) idx2 = j + (j >= idx1) # Ensure idx2 != idx1 - # Compute directions as difference of pairs walker1 = jax.tree_util.tree_map(lambda x: x[idx1], complementary_coords) walker2 = jax.tree_util.tree_map(lambda x: x[idx2], complementary_coords) @@ -148,7 +139,7 @@ def random_direction( rng_key A PRNG key for random number generation. template_coords - Template coordinates to match structure and shape (n_update, ...). + Template coordinates to match structure. Leading dimension will be n_update. n_update Number of walkers to update (number of directions needed). mu @@ -157,16 +148,90 @@ def random_direction( Returns ------- A tuple (directions, tune_once) where directions is a PyTree matching - the structure of template_coords, and tune_once is True. + the structure of template_coords with leading dimension n_update, + and tune_once is True. """ + leaves, treedef = jax.tree_util.tree_flatten(template_coords) + n_leaves = len(leaves) - def sample_leaf(shape): - return jax.random.normal(rng_key, shape) + keys = jax.random.split(rng_key, n_leaves) - # Generate random directions with same structure as template - directions = jax.tree_util.tree_map( - lambda x: 2.0 * mu * sample_leaf(x.shape), template_coords - ) + def sample_leaf(key, template_leaf): + template_shape = template_leaf.shape + if len(template_shape) > 0: + new_shape = (n_update,) + template_shape[1:] + else: + new_shape = (n_update,) + return jax.random.normal(key, new_shape) + + direction_leaves = [ + 2.0 * mu * sample_leaf(key, leaf) for key, leaf in zip(keys, leaves) + ] + + directions = jax.tree_util.tree_unflatten(treedef, direction_leaves) + + return directions, True + + +def gaussian_direction( + rng_key: PRNGKey, + complementary_coords: ArrayTree, + n_update: int, + mu: float, +) -> tuple[ArrayTree, bool]: + """Generate direction vectors using the Gaussian move. + + Directions are sampled from a multivariate normal distribution with covariance + estimated from the complementary ensemble. This move adapts to the local + geometry of the target distribution. + + Parameters + ---------- + rng_key + A PRNG key for random number generation. + complementary_coords + Coordinates of the complementary ensemble. + n_update + Number of walkers to update. + mu + The scale parameter. + + Returns + ------- + A tuple (directions, tune_once) where directions is a PyTree matching + the structure of complementary_coords with leading dimension n_update, + and tune_once is True. + """ + leaves, treedef = jax.tree_util.tree_flatten(complementary_coords) + n_leaves = len(leaves) + + keys = jax.random.split(rng_key, n_leaves) + + def sample_gaussian_leaf(key, leaf): + n_comp = leaf.shape[0] + leaf_flat = leaf.reshape(n_comp, -1) + d = leaf_flat.shape[1] + + mean = jnp.mean(leaf_flat, axis=0) + centered = leaf_flat - mean + + cov = jnp.dot(centered.T, centered) / (n_comp - 1) + jitter = 1e-6 * jnp.eye(d) + cov_reg = cov + jitter + + directions_flat = jax.random.multivariate_normal( + key, jnp.zeros(d), cov_reg, (n_update,) + ) + + orig_shape = leaf.shape + new_shape = (n_update,) + orig_shape[1:] + return (2.0 * mu * directions_flat).reshape(new_shape) + + direction_leaves = [ + sample_gaussian_leaf(key, leaf) for key, leaf in zip(keys, leaves) + ] + + directions = jax.tree_util.tree_unflatten(treedef, direction_leaves) return directions, True @@ -179,7 +244,7 @@ def slice_along_direction( logprob_fn: Callable, maxsteps: int = 10000, maxiter: int = 10000, -) -> tuple[ArrayTree, float, int, int, int]: +) -> tuple[ArrayTree, float, bool, int, int, int]: """Perform slice sampling along a given direction. Implements the stepping-out and shrinking procedures for 1D slice sampling @@ -204,63 +269,63 @@ def slice_along_direction( Returns ------- - A tuple (x1, logp1, nexp, ncon, neval) where: - x1: New position (PyTree) - logp1: Log-probability at new position + A tuple (x1, logp1, accepted, nexp, ncon, neval) where: + x1: New position (PyTree), or x0 if not accepted + logp1: Log-probability at new position, or logp0 if not accepted + accepted: Boolean indicating if a valid slice point was found nexp: Number of expansions performed ncon: Number of contractions performed neval: Number of log-probability evaluations """ key_z0, key_lr, key_j, key_shrink = jax.random.split(rng_key, 4) - # Draw slice height: Z0 = logp0 - Exponential(1) z0 = logp0 - jax.random.exponential(key_z0) - # Initialize interval [L, R] around 0 l_init = -jax.random.uniform(key_lr) r_init = l_init + 1.0 - # Random allocation of expansion steps j = jax.random.randint(key_j, (), 0, maxsteps) k = maxsteps - 1 - j - # Helper function to evaluate log-prob at x0 + t*direction def eval_at_t(t): xt = jax.tree_util.tree_map(lambda x, d: x + t * d, x0, direction) return logprob_fn(xt) - # Stepping-out: expand left + logp_l_init = eval_at_t(l_init) + def left_expand_cond(carry): - l, j_left, nexp, neval, iter_count = carry - logp_l = eval_at_t(l) + l, logp_l, j_left, nexp, neval, iter_count = carry return (j_left > 0) & (iter_count < maxiter) & (logp_l > z0) def left_expand_body(carry): - l, j_left, nexp, neval, iter_count = carry - return l - 1.0, j_left - 1, nexp + 1, neval + 1, iter_count + 1 + l, logp_l, j_left, nexp, neval, iter_count = carry + l_new = l - 1.0 + logp_new = eval_at_t(l_new) + return l_new, logp_new, j_left - 1, nexp + 1, neval + 1, iter_count + 1 - l_final, _, nexp_left, neval_left, _ = jax.lax.while_loop( - left_expand_cond, left_expand_body, (l_init, j, 0, 0, 0) + l_final, _, _, nexp_left, neval_left, _ = jax.lax.while_loop( + left_expand_cond, left_expand_body, (l_init, logp_l_init, j, 0, 0, 0) ) - # Stepping-out: expand right + logp_r_init = eval_at_t(r_init) + def right_expand_cond(carry): - r, k_right, nexp, neval, iter_count = carry - logp_r = eval_at_t(r) + r, logp_r, k_right, nexp, neval, iter_count = carry return (k_right > 0) & (iter_count < maxiter) & (logp_r > z0) def right_expand_body(carry): - r, k_right, nexp, neval, iter_count = carry - return r + 1.0, k_right - 1, nexp + 1, neval + 1, iter_count + 1 + r, logp_r, k_right, nexp, neval, iter_count = carry + r_new = r + 1.0 + logp_new = eval_at_t(r_new) + return r_new, logp_new, k_right - 1, nexp + 1, neval + 1, iter_count + 1 - r_final, _, nexp_right, neval_right, _ = jax.lax.while_loop( - right_expand_cond, right_expand_body, (r_init, k, 0, 0, 0) + r_final, _, _, nexp_right, neval_right, _ = jax.lax.while_loop( + right_expand_cond, right_expand_body, (r_init, logp_r_init, k, 0, 0, 0) ) nexp_total = nexp_left + nexp_right - neval_after_expand = neval_left + neval_right + neval_after_expand = neval_left + neval_right + 2 - # Shrinking: sample uniformly from [L, R] until inside slice def shrink_cond(carry): _, _, _, _, _, iter_count, accepted, _, _ = carry return (~accepted) & (iter_count < maxiter) @@ -272,11 +337,9 @@ def shrink_body(carry): logp_t = eval_at_t(t) neval_new = neval + 1 - # Check if inside slice inside_slice = logp_t >= z0 accepted_new = accepted | inside_slice - # Update interval or accept l_new = jnp.where(inside_slice, l, jnp.where(t < 0, t, l)) r_new = jnp.where(inside_slice, r, jnp.where(t >= 0, t, r)) ncon_new = jnp.where(inside_slice, ncon, ncon + 1) @@ -295,18 +358,27 @@ def shrink_body(carry): logp_acc_new, ) - _, _, _, neval_shrink, ncon_total, _, _, t_final, logp_final = jax.lax.while_loop( + ( + _, + _, + _, + neval_shrink, + ncon_total, + _, + accepted, + t_final, + logp_final, + ) = jax.lax.while_loop( shrink_cond, shrink_body, (key_shrink, l_final, r_final, 0, 0, 0, False, 0.0, logp0), ) - # Compute final position x1 = jax.tree_util.tree_map(lambda x, d: x + t_final * d, x0, direction) neval_total = neval_after_expand + neval_shrink - return x1, logp_final, nexp_total, ncon_total, neval_total + return x1, logp_final, accepted, nexp_total, ncon_total, neval_total def init( @@ -343,7 +415,7 @@ def init( def build_kernel( move: str = "differential", - move_fn: Optional[Callable] = None, + move_fn=None, randomize_split: bool = True, nsplits: int = 2, maxsteps: int = 10000, @@ -357,17 +429,17 @@ def build_kernel( Parameters ---------- move - Type of move to use: "differential" or "random". Ignored if move_fn provided. + Type of move: "differential", "random", or "gaussian". Ignored if move_fn provided. move_fn Optional custom move function. If None, uses the specified move type. randomize_split If True, randomly shuffle walker indices before splitting into groups. nsplits - Number of groups to split the ensemble into. Default is 2. + Number of groups to split the ensemble into. maxsteps Maximum steps for slice stepping-out procedure. maxiter - Maximum iterations to prevent infinite loops. + Maximum iterations for shrinking procedure. tune Whether to enable adaptive tuning of mu. patience @@ -379,24 +451,19 @@ def build_kernel( ------- A kernel function that performs one step of ensemble slice sampling. """ - # Select move function if move_fn is None: if move == "differential": move_fn = differential_direction elif move == "random": move_fn = random_direction - else: - raise ValueError(f"Unknown move type: {move}") - - # At this point move_fn is guaranteed to be Callable - selected_move_fn: Callable = move_fn + elif move == "gaussian": + move_fn = gaussian_direction def kernel( rng_key: PRNGKey, state: SliceEnsembleState, logdensity_fn: Callable ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]: n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - # Shuffle walkers if requested if randomize_split: key_shuffle, key_update = jax.random.split(rng_key) indices = jax.random.permutation(key_shuffle, n_walkers) @@ -420,7 +487,6 @@ def kernel( shuffled_state = state indices = jnp.arange(n_walkers) - # Split into groups group_size = n_walkers // nsplits groups = [] for i in range(nsplits): @@ -449,15 +515,14 @@ def kernel( ) ) - # Update each group sequentially updated_groups = list(groups) - total_nexp = 0 - total_ncon = 0 - total_neval = 0 + accepted_groups = [] + total_nexp = jnp.array(0, dtype=jnp.int32) + total_ncon = jnp.array(0, dtype=jnp.int32) + total_neval = jnp.array(0, dtype=jnp.int32) keys = jax.random.split(key_update, nsplits) for i in range(nsplits): - # Build complementary ensemble from other groups other_indices = [j for j in range(nsplits) if j != i] comp_coords_list = [updated_groups[j].coords for j in other_indices] comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] @@ -484,22 +549,21 @@ def kernel( state.patience_count, ) - # Update this group - updated_group, nexp, ncon, neval = _update_half_slice( + updated_group, accepted, nexp, ncon, neval = _update_half_slice( keys[i], groups[i], complementary, logdensity_fn, - selected_move_fn, + move_fn, maxsteps, maxiter, ) updated_groups[i] = updated_group - total_nexp += nexp - total_ncon += ncon - total_neval += neval + accepted_groups.append(accepted) + total_nexp = total_nexp + nexp + total_ncon = total_ncon + ncon + total_neval = total_neval + neval - # Concatenate updated groups shuffled_coords = jax.tree_util.tree_map( lambda *arrays: jnp.concatenate(arrays, axis=0), *[g.coords for g in updated_groups], @@ -507,6 +571,7 @@ def kernel( shuffled_log_probs = jnp.concatenate( [g.log_probs for g in updated_groups], axis=0 ) + shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) if state.blobs is not None: shuffled_blobs = jax.tree_util.tree_map( @@ -516,13 +581,13 @@ def kernel( else: shuffled_blobs = None - # Unshuffle if needed if randomize_split: inverse_indices = jnp.argsort(indices) new_coords = jax.tree_util.tree_map( lambda x: x[inverse_indices], shuffled_coords ) new_log_probs = shuffled_log_probs[inverse_indices] + accepted = shuffled_accepted[inverse_indices] if shuffled_blobs is not None: new_blobs = jax.tree_util.tree_map( lambda x: x[inverse_indices], shuffled_blobs @@ -532,15 +597,14 @@ def kernel( else: new_coords = shuffled_coords new_log_probs = shuffled_log_probs + accepted = shuffled_accepted new_blobs = shuffled_blobs - # Adaptive tuning of mu should_tune = tune & state.tuning_active nexp_eff = jnp.maximum(total_nexp, 1) mu_tuned = state.mu * 2.0 * nexp_eff / (nexp_eff + total_ncon) - # Check convergence of tuning exp_ratio = total_nexp / jnp.maximum(total_nexp + total_ncon, 1) within_tolerance = jnp.abs(exp_ratio - 0.5) < tolerance @@ -549,7 +613,6 @@ def kernel( ) tuning_active_updated = patience_count_updated < patience - # Apply tuning updates conditionally mu_new = jnp.where(should_tune, mu_tuned, state.mu) patience_count_new = jnp.where( should_tune, patience_count_updated, state.patience_count @@ -567,10 +630,10 @@ def kernel( patience_count_new, ) - # Build info (acceptance always 1.0 for slice sampling) + acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) info = SliceEnsembleInfo( - acceptance_rate=1.0, - is_accepted=jnp.ones(n_walkers, dtype=bool), + acceptance_rate=acceptance_rate, + is_accepted=accepted, expansions=total_nexp, contractions=total_ncon, nevals=total_neval, @@ -590,7 +653,7 @@ def _update_half_slice( direction_fn: Callable, maxsteps: int, maxiter: int, -) -> tuple[SliceEnsembleState, int, int, int]: +) -> tuple[SliceEnsembleState, jnp.ndarray, int, int, int]: """Update a group of walkers using ensemble slice sampling. Parameters @@ -612,22 +675,19 @@ def _update_half_slice( Returns ------- - Tuple of (updated_group_state, total_expansions, total_contractions, total_evals). + Tuple of (updated_group_state, accepted_array, total_expansions, total_contractions, total_evals). """ n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape - # Generate directions key_dir, key_slice = jax.random.split(rng_key) directions, _ = direction_fn( key_dir, complementary_walkers.coords, n_update, walkers_to_update.mu ) - # Define logprob-only function def logprob_only(x): out = logdensity_fn(x) return out[0] if isinstance(out, tuple) else out - # Perform slice sampling for each walker keys = jax.random.split(key_slice, n_update) def slice_one_walker(key, x0, logp0, direction): @@ -639,16 +699,20 @@ def slice_one_walker(key, x0, logp0, direction): keys, walkers_to_update.coords, walkers_to_update.log_probs, directions ) - new_coords, new_log_probs, nexp_array, ncon_array, neval_array = results + ( + new_coords, + new_log_probs, + accepted_array, + nexp_array, + ncon_array, + neval_array, + ) = results - # Sum statistics total_nexp = jnp.sum(nexp_array) total_ncon = jnp.sum(ncon_array) total_neval = jnp.sum(neval_array) - # Handle blobs if needed if walkers_to_update.blobs is not None: - # Re-evaluate at new positions to get blobs logdensity_outputs = jax.vmap(logdensity_fn)(new_coords) _, new_blobs = logdensity_outputs else: @@ -663,12 +727,13 @@ def slice_one_walker(key, x0, logp0, direction): walkers_to_update.patience_count, ) - return updated_state, total_nexp, total_ncon, total_neval + return updated_state, accepted_array, total_nexp, total_ncon, total_neval def as_top_level_api( logdensity_fn: Callable, move: str = "differential", + move_fn=None, mu: float = 1.0, has_blobs: bool = False, randomize_split: bool = True, @@ -679,26 +744,28 @@ def as_top_level_api( patience: int = 5, tolerance: float = 0.05, ) -> SamplingAlgorithm: - """A user-facing API for the ensemble slice sampling algorithm. + """Ensemble slice sampling algorithm. Parameters ---------- logdensity_fn - A function that returns the log density of the model at a given position. + Function that returns the log density at a given position. move - Type of move: "differential" or "random". + Type of move: "differential", "random", or "gaussian". Ignored if move_fn provided. + move_fn + Optional custom move function. If None, uses the specified move type. mu Initial value of the scale parameter. has_blobs - Whether the logdensity function returns additional information (blobs). + Whether the logdensity function returns additional information. randomize_split If True, randomly shuffle walker indices before splitting into groups. nsplits - Number of groups to split the ensemble into. Default is 2. + Number of groups to split the ensemble into. maxsteps Maximum steps for slice stepping-out procedure. maxiter - Maximum iterations to prevent infinite loops. + Maximum iterations for shrinking procedure. tune Whether to enable adaptive tuning of mu. patience @@ -708,10 +775,11 @@ def as_top_level_api( Returns ------- - A `SamplingAlgorithm` that can be used to sample from the target distribution. + A `SamplingAlgorithm`. """ kernel = build_kernel( move=move, + move_fn=move_fn, randomize_split=randomize_split, nsplits=nsplits, maxsteps=maxsteps, diff --git a/tests/ensemble/test_slice.py b/tests/ensemble/test_slice.py index 1739a71ab..3addc9517 100644 --- a/tests/ensemble/test_slice.py +++ b/tests/ensemble/test_slice.py @@ -11,6 +11,7 @@ SliceEnsembleState, as_top_level_api, differential_direction, + gaussian_direction, init, random_direction, slice_along_direction, @@ -66,13 +67,14 @@ def logprob_fn(x): logp0 = logprob_fn(x0) direction = 1.0 - x1, logp1, nexp, ncon, neval = slice_along_direction( + x1, logp1, accepted, nexp, ncon, neval = slice_along_direction( rng_key, x0, logp0, direction, logprob_fn, maxsteps=100, maxiter=1000 ) # Check that we got valid results self.assertTrue(jnp.isfinite(x1)) self.assertTrue(jnp.isfinite(logp1)) + self.assertTrue(accepted) # Should have found a valid point self.assertGreater(neval, 0) def test_init(self): @@ -205,6 +207,44 @@ def logdensity_fn(x): self.assertIsInstance(new_state, SliceEnsembleState) self.assertIsInstance(info, SliceEnsembleInfo) + def test_gaussian_direction_array(self): + """Test that gaussian_direction produces valid directions for arrays.""" + rng_key = jax.random.PRNGKey(42) + + complementary_coords = jnp.array([[0.0, 0.0], [2.0, 4.0], [3.0, 1.0]]) + n_update = 2 + mu = 1.0 + + directions, tune_once = gaussian_direction( + rng_key, complementary_coords, n_update, mu + ) + + self.assertEqual(directions.shape, (n_update, 2)) + self.assertTrue(tune_once) + self.assertTrue(jnp.isfinite(directions).all()) + + def test_gaussian_direction_pytree(self): + """Test that gaussian_direction works with PyTree coordinates.""" + rng_key = jax.random.PRNGKey(43) + + complementary_coords = { + "x": jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]), + "y": jnp.array([[0.5], [1.5], [2.5]]), + } + n_update = 2 + mu = 1.0 + + directions, tune_once = gaussian_direction( + rng_key, complementary_coords, n_update, mu + ) + + self.assertIsInstance(directions, dict) + self.assertEqual(directions["x"].shape, (n_update, 2)) + self.assertEqual(directions["y"].shape, (n_update, 1)) + self.assertTrue(tune_once) + self.assertTrue(jnp.isfinite(directions["x"]).all()) + self.assertTrue(jnp.isfinite(directions["y"]).all()) + def test_random_move(self): """Test ensemble slice sampling with random move.""" @@ -215,13 +255,11 @@ def logdensity_fn(x): n_walkers = 10 initial_position = jax.random.normal(rng_key, (n_walkers, 2)) - # Use random move instead of differential algorithm = as_top_level_api( logdensity_fn, move="random", mu=1.0, maxsteps=100, maxiter=1000 ) initial_state = algorithm.init(initial_position) - # Run a few steps keys = jax.random.split(rng_key, 10) def run_step(state, key): @@ -230,10 +268,35 @@ def run_step(state, key): final_state, infos = jax.lax.scan(run_step, initial_state, keys) - # Check valid results self.assertIsInstance(final_state, SliceEnsembleState) self.assertTrue(jnp.all(infos.acceptance_rate == 1.0)) + def test_gaussian_move(self): + """Test ensemble slice sampling with gaussian move.""" + + def logdensity_fn(x): + return stats.norm.logpdf(x, 0.0, 1.0).sum() + + rng_key = jax.random.PRNGKey(100) + n_walkers = 10 + initial_position = jax.random.normal(rng_key, (n_walkers, 2)) + + algorithm = as_top_level_api( + logdensity_fn, move="gaussian", mu=1.0, maxsteps=100, maxiter=1000 + ) + initial_state = algorithm.init(initial_position) + + keys = jax.random.split(rng_key, 10) + + def run_step(state, key): + new_state, info = algorithm.step(key, state) + return new_state, info + + final_state, infos = jax.lax.scan(run_step, initial_state, keys) + + self.assertIsInstance(final_state, SliceEnsembleState) + self.assertTrue(jnp.all(infos.acceptance_rate >= 0.0)) + if __name__ == "__main__": absltest.main() From 7921cea7e697ea956583e399e59963102813e63e Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 23:37:41 +0100 Subject: [PATCH 11/14] Extract shared ensemble utilities to reduce code duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move common ensemble manipulation patterns into shared utilities in base.py: - Shuffle/unshuffle operations for randomized splitting - Split/concatenate operations for group management - Complementary ensemble building - Masked selection and vmapped logdensity helpers Refactor both slice and stretch kernels to use these utilities, reducing duplication by ~120 lines total. Add Apache license header to slice.py and unify acceptance_rate dtype computation. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/ensemble/base.py | 98 ++++++++++++++++++++++- blackjax/ensemble/slice.py | 147 +++++++++++++---------------------- blackjax/ensemble/stretch.py | 109 +++++++------------------- 3 files changed, 178 insertions(+), 176 deletions(-) diff --git a/blackjax/ensemble/base.py b/blackjax/ensemble/base.py index f1156cd53..227683ce5 100644 --- a/blackjax/ensemble/base.py +++ b/blackjax/ensemble/base.py @@ -14,9 +14,24 @@ """Base classes and utilities for ensemble sampling methods.""" from typing import NamedTuple, Optional +import jax +import jax.numpy as jnp + from blackjax.types import Array, ArrayTree -__all__ = ["EnsembleState", "EnsembleInfo"] +__all__ = [ + "EnsembleState", + "EnsembleInfo", + "get_nwalkers", + "tree_take", + "shuffle_triple", + "unshuffle_triple", + "split_triple", + "concat_triple_groups", + "complementary_triple", + "masked_select", + "vmapped_logdensity", +] class EnsembleState(NamedTuple): @@ -50,3 +65,84 @@ class EnsembleInfo(NamedTuple): acceptance_rate: float is_accepted: Array + + +def get_nwalkers(coords: ArrayTree) -> int: + """Get the number of walkers from ensemble coordinates.""" + return jax.tree_util.tree_flatten(coords)[0][0].shape[0] + + +def tree_take(tree: ArrayTree, idx: jnp.ndarray) -> ArrayTree: + """Index into a PyTree along the leading dimension.""" + return jax.tree_util.tree_map(lambda a: a[idx], tree) + + +def shuffle_triple(key, coords, log_probs, blobs): + """Shuffle ensemble coordinates, log_probs, and blobs.""" + n = get_nwalkers(coords) + idx = jax.random.permutation(key, n) + coords_s = tree_take(coords, idx) + log_probs_s = log_probs[idx] + blobs_s = None if blobs is None else tree_take(blobs, idx) + return coords_s, log_probs_s, blobs_s, idx + + +def unshuffle_triple(coords, log_probs, blobs, indices): + """Reverse a shuffle operation on ensemble coordinates, log_probs, and blobs.""" + inv = jnp.argsort(indices) + coords_u = tree_take(coords, inv) + log_probs_u = log_probs[inv] + blobs_u = None if blobs is None else tree_take(blobs, inv) + return coords_u, log_probs_u, blobs_u + + +def split_triple(coords, log_probs, blobs, nsplits): + """Split ensemble into nsplits contiguous groups.""" + n = get_nwalkers(coords) + group_size = n // nsplits + groups = [] + for i in range(nsplits): + s = i * group_size + e = (i + 1) * group_size if i < nsplits - 1 else n + coords_i = jax.tree_util.tree_map(lambda a: a[s:e], coords) + log_probs_i = log_probs[s:e] + blobs_i = ( + None if blobs is None else jax.tree_util.tree_map(lambda a: a[s:e], blobs) + ) + groups.append((coords_i, log_probs_i, blobs_i)) + return groups + + +def concat_triple_groups(group_triples): + """Concatenate groups of (coords, log_probs, blobs) triples.""" + coords_list, logp_list, blobs_list = zip(*group_triples) + coords = jax.tree_util.tree_map( + lambda *xs: jnp.concatenate(xs, axis=0), *coords_list + ) + logp = jnp.concatenate(logp_list, axis=0) + blobs = ( + None + if all(b is None for b in blobs_list) + else jax.tree_util.tree_map( + lambda *xs: jnp.concatenate(xs, axis=0), *blobs_list + ) + ) + return coords, logp, blobs + + +def complementary_triple(groups, i): + """Build complementary ensemble from all groups except group i.""" + return concat_triple_groups([g for j, g in enumerate(groups) if j != i]) + + +def masked_select(mask, new_val, old_val): + """Select between new and old values based on mask.""" + expand_dims = (1,) * (new_val.ndim - 1) + mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) + return jnp.where(mask_expanded, new_val, old_val) + + +def vmapped_logdensity(logdensity_fn, coords): + """Evaluate logdensity function on ensemble coordinates with vmap.""" + outs = jax.vmap(logdensity_fn)(coords) + return outs if isinstance(outs, tuple) else (outs, None) diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py index 3d4bb2607..9069a1a3d 100644 --- a/blackjax/ensemble/slice.py +++ b/blackjax/ensemble/slice.py @@ -1,6 +1,16 @@ +# Copyright 2020- The Blackjax Authors. # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # +# http://www.apache.org/licenses/LICENSE-2.0 # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Ensemble Slice Sampling (ESS) implementation.""" from typing import Callable, NamedTuple, Optional @@ -8,6 +18,14 @@ import jax.numpy as jnp from blackjax.base import SamplingAlgorithm +from blackjax.ensemble.base import ( + complementary_triple, + concat_triple_groups, + get_nwalkers, + shuffle_triple, + split_triple, + unshuffle_triple, +) from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ @@ -462,58 +480,37 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: SliceEnsembleState, logdensity_fn: Callable ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]: - n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - if randomize_split: key_shuffle, key_update = jax.random.split(rng_key) - indices = jax.random.permutation(key_shuffle, n_walkers) - shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) - shuffled_log_probs = state.log_probs[indices] - shuffled_blobs = ( - None - if state.blobs is None - else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) - ) - shuffled_state = SliceEnsembleState( - shuffled_coords, - shuffled_log_probs, - shuffled_blobs, - state.mu, - state.tuning_active, - state.patience_count, + coords_s, logp_s, blobs_s, indices = shuffle_triple( + key_shuffle, state.coords, state.log_probs, state.blobs ) else: key_update = rng_key - shuffled_state = state - indices = jnp.arange(n_walkers) - - group_size = n_walkers // nsplits - groups = [] - for i in range(nsplits): - start_idx = i * group_size - end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers + coords_s, logp_s, blobs_s = state.coords, state.log_probs, state.blobs + indices = jnp.arange(get_nwalkers(state.coords)) + + shuffled_state = SliceEnsembleState( + coords_s, + logp_s, + blobs_s, + state.mu, + state.tuning_active, + state.patience_count, + ) - group_coords = jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.coords - ) - group_log_probs = shuffled_state.log_probs[start_idx:end_idx] - group_blobs = ( - None - if shuffled_state.blobs is None - else jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.blobs - ) - ) - groups.append( - SliceEnsembleState( - group_coords, - group_log_probs, - group_blobs, - state.mu, - state.tuning_active, - state.patience_count, - ) + group_triples = split_triple( + shuffled_state.coords, + shuffled_state.log_probs, + shuffled_state.blobs, + nsplits, + ) + groups = [ + SliceEnsembleState( + t[0], t[1], t[2], state.mu, state.tuning_active, state.patience_count ) + for t in group_triples + ] updated_groups = list(groups) accepted_groups = [] @@ -523,27 +520,13 @@ def kernel( keys = jax.random.split(key_update, nsplits) for i in range(nsplits): - other_indices = [j for j in range(nsplits) if j != i] - comp_coords_list = [updated_groups[j].coords for j in other_indices] - comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] - comp_blobs_list = [updated_groups[j].blobs for j in other_indices] - - complementary_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list + comp_triple = complementary_triple( + [(g.coords, g.log_probs, g.blobs) for g in updated_groups], i ) - complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) - - if state.blobs is not None: - complementary_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list - ) - else: - complementary_blobs = None - complementary = SliceEnsembleState( - complementary_coords, - complementary_log_probs, - complementary_blobs, + comp_triple[0], + comp_triple[1], + comp_triple[2], state.mu, state.tuning_active, state.patience_count, @@ -564,41 +547,19 @@ def kernel( total_ncon = total_ncon + ncon total_neval = total_neval + neval - shuffled_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.coords for g in updated_groups], - ) - shuffled_log_probs = jnp.concatenate( - [g.log_probs for g in updated_groups], axis=0 + coords_cat, logp_cat, blobs_cat = concat_triple_groups( + [(g.coords, g.log_probs, g.blobs) for g in updated_groups] ) shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) - if state.blobs is not None: - shuffled_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.blobs for g in updated_groups], - ) - else: - shuffled_blobs = None - if randomize_split: - inverse_indices = jnp.argsort(indices) - new_coords = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_coords + new_coords, new_log_probs, new_blobs = unshuffle_triple( + coords_cat, logp_cat, blobs_cat, indices ) - new_log_probs = shuffled_log_probs[inverse_indices] - accepted = shuffled_accepted[inverse_indices] - if shuffled_blobs is not None: - new_blobs = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_blobs - ) - else: - new_blobs = None + accepted = shuffled_accepted[jnp.argsort(indices)] else: - new_coords = shuffled_coords - new_log_probs = shuffled_log_probs + new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat accepted = shuffled_accepted - new_blobs = shuffled_blobs should_tune = tune & state.tuning_active @@ -630,7 +591,7 @@ def kernel( patience_count_new, ) - acceptance_rate = jnp.mean(accepted.astype(jnp.float32)) + acceptance_rate = jnp.mean(accepted) info = SliceEnsembleInfo( acceptance_rate=acceptance_rate, is_accepted=accepted, diff --git a/blackjax/ensemble/stretch.py b/blackjax/ensemble/stretch.py index 062910cec..303b6242b 100644 --- a/blackjax/ensemble/stretch.py +++ b/blackjax/ensemble/stretch.py @@ -19,7 +19,16 @@ from jax.flatten_util import ravel_pytree from blackjax.base import SamplingAlgorithm -from blackjax.ensemble.base import EnsembleInfo, EnsembleState +from blackjax.ensemble.base import ( + EnsembleInfo, + EnsembleState, + complementary_triple, + concat_triple_groups, + get_nwalkers, + shuffle_triple, + split_triple, + unshuffle_triple, +) from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey __all__ = [ @@ -109,70 +118,28 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable ) -> tuple[EnsembleState, EnsembleInfo]: - n_walkers, *_ = jax.tree_util.tree_flatten(state.coords)[0][0].shape - if randomize_split: key_shuffle, key_update = jax.random.split(rng_key) - indices = jax.random.permutation(key_shuffle, n_walkers) - shuffled_coords = jax.tree_util.tree_map(lambda x: x[indices], state.coords) - shuffled_log_probs = state.log_probs[indices] - shuffled_blobs = ( - None - if state.blobs is None - else jax.tree_util.tree_map(lambda x: x[indices], state.blobs) - ) - shuffled_state = EnsembleState( - shuffled_coords, shuffled_log_probs, shuffled_blobs + coords_s, logp_s, blobs_s, indices = shuffle_triple( + key_shuffle, state.coords, state.log_probs, state.blobs ) else: key_update = rng_key - shuffled_state = state - indices = jnp.arange(n_walkers) + coords_s, logp_s, blobs_s = state.coords, state.log_probs, state.blobs + indices = jnp.arange(get_nwalkers(state.coords)) - group_size = n_walkers // nsplits - groups = [] - for i in range(nsplits): - start_idx = i * group_size - end_idx = (i + 1) * group_size if i < nsplits - 1 else n_walkers - - group_coords = jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.coords - ) - group_log_probs = shuffled_state.log_probs[start_idx:end_idx] - group_blobs = ( - None - if shuffled_state.blobs is None - else jax.tree_util.tree_map( - lambda x: x[start_idx:end_idx], shuffled_state.blobs - ) - ) - groups.append(EnsembleState(group_coords, group_log_probs, group_blobs)) + group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits) + groups = [EnsembleState(*t) for t in group_triples] updated_groups = list(groups) accepted_groups = [] keys = jax.random.split(key_update, nsplits) for i in range(nsplits): - other_indices = [j for j in range(nsplits) if j != i] - comp_coords_list = [updated_groups[j].coords for j in other_indices] - comp_log_probs_list = [updated_groups[j].log_probs for j in other_indices] - comp_blobs_list = [updated_groups[j].blobs for j in other_indices] - - complementary_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_coords_list - ) - complementary_log_probs = jnp.concatenate(comp_log_probs_list, axis=0) - - if state.blobs is not None: - complementary_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), *comp_blobs_list - ) - else: - complementary_blobs = None - - complementary = EnsembleState( - complementary_coords, complementary_log_probs, complementary_blobs + comp_triple = complementary_triple( + [(g.coords, g.log_probs, g.blobs) for g in updated_groups], i ) + complementary = EnsembleState(*comp_triple) updated_group, accepted = _update_half( keys[i], groups[i], complementary, logdensity_fn, move_fn @@ -180,41 +147,19 @@ def kernel( updated_groups[i] = updated_group accepted_groups.append(accepted) - shuffled_coords = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.coords for g in updated_groups], - ) - shuffled_log_probs = jnp.concatenate( - [g.log_probs for g in updated_groups], axis=0 + coords_cat, logp_cat, blobs_cat = concat_triple_groups( + [(g.coords, g.log_probs, g.blobs) for g in updated_groups] ) - shuffled_accepted = jnp.concatenate(accepted_groups, axis=0) - - if state.blobs is not None: - shuffled_blobs = jax.tree_util.tree_map( - lambda *arrays: jnp.concatenate(arrays, axis=0), - *[g.blobs for g in updated_groups], - ) - else: - shuffled_blobs = None + accepted_cat = jnp.concatenate(accepted_groups, axis=0) if randomize_split: - inverse_indices = jnp.argsort(indices) - new_coords = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_coords + new_coords, new_log_probs, new_blobs = unshuffle_triple( + coords_cat, logp_cat, blobs_cat, indices ) - new_log_probs = shuffled_log_probs[inverse_indices] - accepted = shuffled_accepted[inverse_indices] - if shuffled_blobs is not None: - new_blobs = jax.tree_util.tree_map( - lambda x: x[inverse_indices], shuffled_blobs - ) - else: - new_blobs = None + accepted = accepted_cat[jnp.argsort(indices)] else: - new_coords = shuffled_coords - new_log_probs = shuffled_log_probs - accepted = shuffled_accepted - new_blobs = shuffled_blobs + new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat + accepted = accepted_cat new_state = EnsembleState(new_coords, new_log_probs, new_blobs) acceptance_rate = jnp.mean(accepted) From fc977e32184a3d8c0f609fe56d2c44d4e5fd7be7 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 23:46:28 +0100 Subject: [PATCH 12/14] Further reduce ensemble code duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add additional utilities to base.py and use throughout: - Use vmapped_logdensity in all init functions and blob re-evaluation - Replace n_update calculation with get_nwalkers calls - Add prepare_split helper to eliminate randomize-split boilerplate - Add unshuffle_1d helper for 1D array unshuffling - Add build_states_from_triples helper for state construction Remove stretch._masked_select in favor of base.masked_select. Reduces duplication by ~40 additional lines while maintaining full test coverage. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/ensemble/base.py | 36 ++++++++++++++++++++ blackjax/ensemble/slice.py | 62 +++++++++++---------------------- blackjax/ensemble/stretch.py | 66 +++++++++++------------------------- 3 files changed, 76 insertions(+), 88 deletions(-) diff --git a/blackjax/ensemble/base.py b/blackjax/ensemble/base.py index 227683ce5..d3cbe3fff 100644 --- a/blackjax/ensemble/base.py +++ b/blackjax/ensemble/base.py @@ -31,6 +31,9 @@ "complementary_triple", "masked_select", "vmapped_logdensity", + "prepare_split", + "unshuffle_1d", + "build_states_from_triples", ] @@ -146,3 +149,36 @@ def vmapped_logdensity(logdensity_fn, coords): """Evaluate logdensity function on ensemble coordinates with vmap.""" outs = jax.vmap(logdensity_fn)(coords) return outs if isinstance(outs, tuple) else (outs, None) + + +def prepare_split(rng_key, coords, log_probs, blobs, randomize_split, nsplits): + """Prepare ensemble for splitting into groups. + + Handles optional randomization, splitting, and returns components needed + for group-wise updates and subsequent unshuffling. + """ + if randomize_split: + key_shuffle, key_update = jax.random.split(rng_key) + coords_s, logp_s, blobs_s, indices = shuffle_triple( + key_shuffle, coords, log_probs, blobs + ) + else: + key_update = rng_key + coords_s, logp_s, blobs_s = coords, log_probs, blobs + indices = jnp.arange(get_nwalkers(coords)) + group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits) + return key_update, group_triples, indices + + +def unshuffle_1d(arr, indices): + """Reverse shuffle operation on a 1D per-walker array.""" + return arr[jnp.argsort(indices)] + + +def build_states_from_triples(group_triples, state_ctor, extra_fields=()): + """Build state objects from triples with optional extra fields. + + Handles both base EnsembleState and algorithm-specific states like + SliceEnsembleState that have additional fields. + """ + return [state_ctor(t[0], t[1], t[2], *extra_fields) for t in group_triples] diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py index 9069a1a3d..fb803159e 100644 --- a/blackjax/ensemble/slice.py +++ b/blackjax/ensemble/slice.py @@ -19,12 +19,14 @@ from blackjax.base import SamplingAlgorithm from blackjax.ensemble.base import ( + build_states_from_triples, complementary_triple, concat_triple_groups, get_nwalkers, - shuffle_triple, - split_triple, + prepare_split, + unshuffle_1d, unshuffle_triple, + vmapped_logdensity, ) from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey @@ -422,13 +424,8 @@ def init( ------- Initial SliceEnsembleState. """ - logdensity_outputs = jax.vmap(logdensity_fn)(position) - if isinstance(logdensity_outputs, tuple): - log_probs, blobs = logdensity_outputs - return SliceEnsembleState(position, log_probs, blobs, mu, True, 0) - else: - log_probs = logdensity_outputs - return SliceEnsembleState(position, log_probs, None, mu, True, 0) + log_probs, blobs = vmapped_logdensity(logdensity_fn, position) + return SliceEnsembleState(position, log_probs, blobs, mu, True, 0) def build_kernel( @@ -480,37 +477,19 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: SliceEnsembleState, logdensity_fn: Callable ) -> tuple[SliceEnsembleState, SliceEnsembleInfo]: - if randomize_split: - key_shuffle, key_update = jax.random.split(rng_key) - coords_s, logp_s, blobs_s, indices = shuffle_triple( - key_shuffle, state.coords, state.log_probs, state.blobs - ) - else: - key_update = rng_key - coords_s, logp_s, blobs_s = state.coords, state.log_probs, state.blobs - indices = jnp.arange(get_nwalkers(state.coords)) - - shuffled_state = SliceEnsembleState( - coords_s, - logp_s, - blobs_s, - state.mu, - state.tuning_active, - state.patience_count, - ) - - group_triples = split_triple( - shuffled_state.coords, - shuffled_state.log_probs, - shuffled_state.blobs, + key_update, group_triples, indices = prepare_split( + rng_key, + state.coords, + state.log_probs, + state.blobs, + randomize_split, nsplits, ) - groups = [ - SliceEnsembleState( - t[0], t[1], t[2], state.mu, state.tuning_active, state.patience_count - ) - for t in group_triples - ] + groups = build_states_from_triples( + group_triples, + SliceEnsembleState, + (state.mu, state.tuning_active, state.patience_count), + ) updated_groups = list(groups) accepted_groups = [] @@ -556,7 +535,7 @@ def kernel( new_coords, new_log_probs, new_blobs = unshuffle_triple( coords_cat, logp_cat, blobs_cat, indices ) - accepted = shuffled_accepted[jnp.argsort(indices)] + accepted = unshuffle_1d(shuffled_accepted, indices) else: new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat accepted = shuffled_accepted @@ -638,7 +617,7 @@ def _update_half_slice( ------- Tuple of (updated_group_state, accepted_array, total_expansions, total_contractions, total_evals). """ - n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + n_update = get_nwalkers(walkers_to_update.coords) key_dir, key_slice = jax.random.split(rng_key) directions, _ = direction_fn( @@ -674,8 +653,7 @@ def slice_one_walker(key, x0, logp0, direction): total_neval = jnp.sum(neval_array) if walkers_to_update.blobs is not None: - logdensity_outputs = jax.vmap(logdensity_fn)(new_coords) - _, new_blobs = logdensity_outputs + _, new_blobs = vmapped_logdensity(logdensity_fn, new_coords) else: new_blobs = None diff --git a/blackjax/ensemble/stretch.py b/blackjax/ensemble/stretch.py index 303b6242b..97c92a280 100644 --- a/blackjax/ensemble/stretch.py +++ b/blackjax/ensemble/stretch.py @@ -22,12 +22,15 @@ from blackjax.ensemble.base import ( EnsembleInfo, EnsembleState, + build_states_from_triples, complementary_triple, concat_triple_groups, get_nwalkers, - shuffle_triple, - split_triple, + masked_select, + prepare_split, + unshuffle_1d, unshuffle_triple, + vmapped_logdensity, ) from blackjax.types import ArrayLikeTree, ArrayTree, PRNGKey @@ -118,18 +121,15 @@ def build_kernel( def kernel( rng_key: PRNGKey, state: EnsembleState, logdensity_fn: Callable ) -> tuple[EnsembleState, EnsembleInfo]: - if randomize_split: - key_shuffle, key_update = jax.random.split(rng_key) - coords_s, logp_s, blobs_s, indices = shuffle_triple( - key_shuffle, state.coords, state.log_probs, state.blobs - ) - else: - key_update = rng_key - coords_s, logp_s, blobs_s = state.coords, state.log_probs, state.blobs - indices = jnp.arange(get_nwalkers(state.coords)) - - group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits) - groups = [EnsembleState(*t) for t in group_triples] + key_update, group_triples, indices = prepare_split( + rng_key, + state.coords, + state.log_probs, + state.blobs, + randomize_split, + nsplits, + ) + groups = build_states_from_triples(group_triples, EnsembleState) updated_groups = list(groups) accepted_groups = [] @@ -156,7 +156,7 @@ def kernel( new_coords, new_log_probs, new_blobs = unshuffle_triple( coords_cat, logp_cat, blobs_cat, indices ) - accepted = accepted_cat[jnp.argsort(indices)] + accepted = unshuffle_1d(accepted_cat, indices) else: new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat accepted = accepted_cat @@ -170,32 +170,11 @@ def kernel( return kernel -def _masked_select(mask, new_val, old_val): - """Helper to broadcast mask to match array rank for jnp.where. - - Parameters - ---------- - mask - Boolean mask with shape (n_walkers,) - new_val - New values to select when mask is True - old_val - Old values to select when mask is False - - Returns - ------- - Array with same shape as new_val/old_val, with values selected per mask - """ - expand_dims = (1,) * (new_val.ndim - 1) - mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) - return jnp.where(mask_expanded, new_val, old_val) - - def _update_half( rng_key, walkers_to_update, complementary_walkers, logdensity_fn, move_fn ): """Helper to update one half of the ensemble.""" - n_update, *_ = jax.tree_util.tree_flatten(walkers_to_update.coords)[0][0].shape + n_update = get_nwalkers(walkers_to_update.coords) key_moves, key_accept = jax.random.split(rng_key) keys = jax.random.split(key_moves, n_update) @@ -231,7 +210,7 @@ def _update_half( accepted = jnp.log(u) < log_p_accept new_coords = jax.tree_util.tree_map( - lambda prop, old: _masked_select(accepted, prop, old), + lambda prop, old: masked_select(accepted, prop, old), proposals, walkers_to_update.coords, ) @@ -239,7 +218,7 @@ def _update_half( if walkers_to_update.blobs is not None: new_blobs = jax.tree_util.tree_map( - lambda prop, old: _masked_select(accepted, prop, old), + lambda prop, old: masked_select(accepted, prop, old), blobs_proposal, walkers_to_update.blobs, ) @@ -266,13 +245,8 @@ def init( has_blobs Whether the log-density function returns additional metadata (blobs). """ - logdensity_outputs = jax.vmap(logdensity_fn)(position) - if isinstance(logdensity_outputs, tuple): - log_probs, blobs = logdensity_outputs - return EnsembleState(position, log_probs, blobs) - else: - log_probs = logdensity_outputs - return EnsembleState(position, log_probs, None) + log_probs, blobs = vmapped_logdensity(logdensity_fn, position) + return EnsembleState(position, log_probs, blobs) def as_top_level_api( From 4db587243de58ba348384590b83e03e480acd11c Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 23:48:17 +0100 Subject: [PATCH 13/14] Remove unnecessary unshuffle_1d helper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The one-liner arr[jnp.argsort(indices)] is already clear and doesn't benefit from wrapping in a helper function. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/ensemble/base.py | 6 ------ blackjax/ensemble/slice.py | 3 +-- blackjax/ensemble/stretch.py | 3 +-- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/blackjax/ensemble/base.py b/blackjax/ensemble/base.py index d3cbe3fff..c897c5f6e 100644 --- a/blackjax/ensemble/base.py +++ b/blackjax/ensemble/base.py @@ -32,7 +32,6 @@ "masked_select", "vmapped_logdensity", "prepare_split", - "unshuffle_1d", "build_states_from_triples", ] @@ -170,11 +169,6 @@ def prepare_split(rng_key, coords, log_probs, blobs, randomize_split, nsplits): return key_update, group_triples, indices -def unshuffle_1d(arr, indices): - """Reverse shuffle operation on a 1D per-walker array.""" - return arr[jnp.argsort(indices)] - - def build_states_from_triples(group_triples, state_ctor, extra_fields=()): """Build state objects from triples with optional extra fields. diff --git a/blackjax/ensemble/slice.py b/blackjax/ensemble/slice.py index fb803159e..bb4afa6ee 100644 --- a/blackjax/ensemble/slice.py +++ b/blackjax/ensemble/slice.py @@ -24,7 +24,6 @@ concat_triple_groups, get_nwalkers, prepare_split, - unshuffle_1d, unshuffle_triple, vmapped_logdensity, ) @@ -535,7 +534,7 @@ def kernel( new_coords, new_log_probs, new_blobs = unshuffle_triple( coords_cat, logp_cat, blobs_cat, indices ) - accepted = unshuffle_1d(shuffled_accepted, indices) + accepted = shuffled_accepted[jnp.argsort(indices)] else: new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat accepted = shuffled_accepted diff --git a/blackjax/ensemble/stretch.py b/blackjax/ensemble/stretch.py index 97c92a280..f503960d7 100644 --- a/blackjax/ensemble/stretch.py +++ b/blackjax/ensemble/stretch.py @@ -28,7 +28,6 @@ get_nwalkers, masked_select, prepare_split, - unshuffle_1d, unshuffle_triple, vmapped_logdensity, ) @@ -156,7 +155,7 @@ def kernel( new_coords, new_log_probs, new_blobs = unshuffle_triple( coords_cat, logp_cat, blobs_cat, indices ) - accepted = unshuffle_1d(accepted_cat, indices) + accepted = accepted_cat[jnp.argsort(indices)] else: new_coords, new_log_probs, new_blobs = coords_cat, logp_cat, blobs_cat accepted = accepted_cat From 8c247e86b8232039006b3695226d2b635a7b2ca0 Mon Sep 17 00:00:00 2001 From: Will Handley Date: Wed, 15 Oct 2025 23:52:40 +0100 Subject: [PATCH 14/14] Use vmapped_logdensity in stretch._update_half MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace manual tuple vs non-tuple handling with the shared helper for consistency with the rest of the codebase. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- blackjax/ensemble/stretch.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/blackjax/ensemble/stretch.py b/blackjax/ensemble/stretch.py index f503960d7..cbf7dde9b 100644 --- a/blackjax/ensemble/stretch.py +++ b/blackjax/ensemble/stretch.py @@ -182,12 +182,7 @@ def _update_half( lambda k, w_coords: move_fn(k, w_coords, complementary_walkers.coords) )(keys, walkers_to_update.coords) - logdensity_outputs = jax.vmap(logdensity_fn)(proposals) - if isinstance(logdensity_outputs, tuple): - log_probs_proposal, blobs_proposal = logdensity_outputs - else: - log_probs_proposal = logdensity_outputs - blobs_proposal = None + log_probs_proposal, blobs_proposal = vmapped_logdensity(logdensity_fn, proposals) log_p_accept = ( log_hastings_ratios + log_probs_proposal - walkers_to_update.log_probs