Skip to content

Conversation

@williamjameshandley
Copy link

This PR adds two affine-invariant ensemble MCMC samplers to BlackJAX: the stretch move sampler and ensemble slice sampling.

Overview

Ensemble samplers use multiple interacting walkers to explore the posterior distribution without requiring gradient information. They are popular in astronomy due to their self-tuning capability. Despite their inability to use gradient information, there is still a benefit to JAX if the probability density is appropriately vectorized over multiple calls.

Stretch Move Sampler

The stretch move is an affine-invariant ensemble sampler introduced by Goodman & Weare (2010) and popularized by the emcee package (Foreman-Mackey et al. 2013).

Key features:

  • Affine-invariant: performance independent of linear transformations
  • Uses complementary ensemble to propose moves along random directions
  • Single tuning parameter a (stretch scale)
  • Supports arbitrary nsplits for group splitting (default: 2)
  • Optional randomized splitting for improved mixing

Ensemble Slice Sampling

Ensemble slice sampling (ESS) extends the slice sampling framework to ensemble methods, as implemented in the Zeus package (Karamanis et al. 2021).

Key features:

  • Combines slice sampling with ensemble proposals
  • Three direction modes: differential, random, gaussian
  • Adaptive tuning of scale parameter mu via Robbins-Monro
  • Built-in convergence diagnostics through expansion/contraction ratios
  • Supports arbitrary nsplits for group splitting (default: 2)

API Design

Both samplers follow BlackJAX conventions and expose consistent APIs:

import blackjax

# Stretch move sampler
stretch = blackjax.ensemble.stretch.as_top_level_api(
    logdensity_fn,
    a=2.0,
    randomize_split=True,
    nsplits=2
)

# Ensemble slice sampling
slice_sampler = blackjax.ensemble.slice.as_top_level_api(
    logdensity_fn,
    move="differential",  # or "random", "gaussian"
    mu=1.0,
    tune=True,
    randomize_split=True,
    nsplits=2
)

Both samplers:

  • Support PyTree coordinates for structured parameters
  • Return EnsembleState with coords, log_probs, and optional blobs
  • Provide acceptance diagnostics via EnsembleInfo
  • Are fully compatible with JAX transformations (jit, vmap)

Implementation

The ensemble module is organized into three files:

  • base.py: Shared utilities for ensemble manipulation (shuffle, split, concatenate, complement operations)
  • stretch.py: Stretch move implementation (~250 lines)
  • slice.py: Ensemble slice sampling implementation (~700 lines including direction functions and slice procedure)

Extensive refactoring consolidated common patterns into base.py, eliminating ~165 lines of duplication while maintaining clean, readable implementations.

Testing

Comprehensive test coverage with 15 tests across both samplers:

Stretch sampler (4 tests):

  • Basic stretch move mechanics
  • 2D Gaussian sampling and convergence
  • PyTree coordinate support

Slice sampler (11 tests):

  • All three direction functions (differential, random, gaussian)
  • 1D and 2D Gaussian sampling
  • PyTree support for coordinates and directions
  • JIT compilation
  • Initialization with and without blobs

All tests pass with 100% coverage of core functionality.

References

  • Goodman & Weare (2010). "Ensemble samplers with affine invariance." Communications in Applied Mathematics and Computational Science, 5(1), 65-80. DOI: 10.2140/camcos.2010.5.65
  • Foreman-Mackey et al. (2013). "emcee: The MCMC Hammer." PASP, 125(925), 306. arXiv: 1202.3665
  • Karamanis et al. (2021). "zeus: A Python implementation of Ensemble Slice Sampling for efficient Bayesian parameter inference." arXiv: 2105.03468

Code references:

  • emcee: Stretch move reference implementation
  • Zeus: Ensemble slice sampling reference implementation

williamjameshandley and others added 14 commits July 8, 2025 21:21
This commit adds a JAX-native implementation of the emcee stretch move
ensemble sampler to BlackJax, following the affine-invariant ensemble
sampling algorithm described in Goodman & Weare (2010).

Key features:
- Full JAX/JIT compatibility with vectorized operations
- Stateless functional design consistent with BlackJax architecture
- Red-blue ensemble update strategy for efficient parallel sampling
- Support for PyTree parameter structures
- Comprehensive test suite with convergence validation

New modules:
- blackjax/mcmc/ensemble.py: Core ensemble sampling infrastructure
- blackjax/mcmc/stretch.py: User-facing API for stretch move algorithm
- tests/mcmc/test_ensemble.py: Test suite for ensemble algorithms
- PLAN.md: Detailed implementation plan and technical analysis
- REVIEW.md: Comprehensive code review by senior reviewer

The implementation enables efficient ensemble MCMC sampling for
high-dimensional parameter spaces with minimal hand-tuning.

Code review summary: High-quality implementation with no methodological
errors. Minor suggestions for API improvements and robustness enhancements.
One critical recommendation: add validation test against reference emcee.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
…eatures

Correctness fixes:
- Fix PRNG key reuse bug in _update_half function
- Fix -∞ log probability handling to accept moves from -∞ to finite density

Enhancements for emcee compatibility:
- Add randomized red/blue split (randomize_split parameter, default True)
- Add warning when nwalkers < 2*ndim (suppressible with live_dangerously)

All tests pass and fixes verified with dedicated test cases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Critical bug fixes:
- Fix PyTree handling with None blobs (avoid tree.map on None)
- Fix mask broadcasting for 1D PyTree leaves (add _masked_select helper)

Enhancements:
- Implement nsplits > 2 support (generalize red-blue to n-way split)
- Each group updated sequentially using all other groups as complementary
- Default nsplits=2 maintains emcee compatibility

All tests pass including:
- PyTree coords with None blobs
- 1D PyTree leaves
- nsplits=2 and nsplits=4
- PyTree with blobs

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
- Remove all validation and warnings for consistency with BlackJAX philosophy
  - Removed ValueError for nsplits and a parameter bounds
  - Removed nwalkers < 2*ndim warning
  - Removed live_dangerously parameter
- Replace jax.tree.map with jax.tree_util.tree_map throughout (13 instances)
- Remove all inline comments (~20 instances) - rely on docstrings and clear names
- Expand stretch_move docstring with full Parameters and Returns sections
- Improve variable naming (use _ for unused tree_flatten return value)
- Fix acceptance_rate type: use jnp.mean() instead of float() to avoid concretization errors
- Fix mypy type checking for step_fn signature

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
- Replaced generic ensemble API with stretch-specific API
  - as_top_level_api() now takes 'a' parameter instead of 'move_fn'
  - build_kernel() now takes optional 'a' parameter, defaults to stretch_move
- Renamed ensemble.py to stretch.py (single file now)
- Deleted redundant stretch.py wrapper
- Updated test imports from ensemble to stretch
- Removed unused imports from test file
- Kept EnsembleState/EnsembleInfo names (describe multi-walker state)

Simplifies codebase by eliminating unnecessary abstraction layer.
The "ensemble framework" was only used for stretch move.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
- File now named after sampling approach (ensemble), not move (stretch)
- Matches class names (EnsembleState, EnsembleInfo)
- Public API remains blackjax.stretch() via alias in mcmc/__init__.py
- Updated test imports to use blackjax.mcmc.ensemble

Rationale: "Ensemble" describes the fundamental multi-walker approach,
while "stretch" is just one possible proposal move. This naming better
reflects the algorithm's core concept and allows for future moves.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
- Public API is now blackjax.ensemble() instead of blackjax.stretch()
- Removed stretch alias from blackjax/mcmc/__init__.py
- Updated blackjax/__init__.py to import and expose ensemble
- Updated tests to use blackjax.ensemble()

Rationale: Consistent naming throughout - the module is ensemble.py,
the classes are EnsembleState/Info, and now the public API is
blackjax.ensemble(). "Ensemble" describes the multi-walker sampling
approach, while "stretch" is just the proposal mechanism.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
Implements the zeus Ensemble Slice Sampling algorithm in BlackJAX's
ensemble module, following the architecture established in the refactoring.

Key features:
- Slice sampling with stepping-out and shrinking procedures
- DifferentialMove and RandomMove direction generators
- Robbins-Monro adaptive tuning of scale parameter mu
- No rejections (acceptance_rate = 1.0)
- Full JAX compatibility with jit/vmap support
- PyTree support for arbitrary position structures

Module structure:
- blackjax/ensemble/base.py: Shared EnsembleState/Info
- blackjax/ensemble/stretch.py: Stretch move (Metropolis-based)
- blackjax/ensemble/slice.py: Slice sampling (new)

Tests reorganized:
- tests/ensemble/test_stretch.py (4 tests)
- tests/ensemble/test_slice.py (8 tests)

Public API: blackjax.ensemble_slice(logdensity_fn, move="differential", ...)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
- Implement all medium-priority improvements from OpenAI review:
  - Expose move_fn parameter in as_top_level_api for custom moves
  - Optimize logprob evaluation (carry in while-loop state)
  - Add gaussian_direction with full PyTree support
  - Simplify docstrings to match BlackJAX house style
- Remove all inline comments and defensive checks
- Fix SliceEnsembleInfo docstring accuracy
- Add tests for gaussian direction (arrays and PyTrees)
- All 11 tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Move common ensemble manipulation patterns into shared utilities in base.py:
- Shuffle/unshuffle operations for randomized splitting
- Split/concatenate operations for group management
- Complementary ensemble building
- Masked selection and vmapped logdensity helpers

Refactor both slice and stretch kernels to use these utilities, reducing
duplication by ~120 lines total. Add Apache license header to slice.py
and unify acceptance_rate dtype computation.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Add additional utilities to base.py and use throughout:
- Use vmapped_logdensity in all init functions and blob re-evaluation
- Replace n_update calculation with get_nwalkers calls
- Add prepare_split helper to eliminate randomize-split boilerplate
- Add unshuffle_1d helper for 1D array unshuffling
- Add build_states_from_triples helper for state construction

Remove stretch._masked_select in favor of base.masked_select.

Reduces duplication by ~40 additional lines while maintaining full
test coverage.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
The one-liner arr[jnp.argsort(indices)] is already clear and doesn't
benefit from wrapping in a helper function.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Replace manual tuple vs non-tuple handling with the shared helper
for consistency with the rest of the codebase.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
@junpenglao
Copy link
Member

Wow this looks great! Thanks for working on it!

Copy link
Member

@junpenglao junpenglao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment on lines +125 to +131
blobs = (
None
if all(b is None for b in blobs_list)
else jax.tree_util.tree_map(
lambda *xs: jnp.concatenate(xs, axis=0), *blobs_list
)
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: unnest and rewrite it as a regular if...else... block

@codecov
Copy link

codecov bot commented Oct 29, 2025

Codecov Report

❌ Patch coverage is 96.83908% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 98.44%. Comparing base (56df032) to head (6f8054d).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
blackjax/ensemble/slice.py 97.95% 4 Missing ⚠️
blackjax/ensemble/stretch.py 94.87% 4 Missing ⚠️
blackjax/ensemble/base.py 95.52% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #797      +/-   ##
==========================================
- Coverage   98.45%   98.44%   -0.02%     
==========================================
  Files          66       72       +6     
  Lines        3242     3719     +477     
==========================================
+ Hits         3192     3661     +469     
- Misses         50       58       +8     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants