Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
c1b3d8f
Implement emcee stretch move ensemble sampler
williamjameshandley Jul 8, 2025
526e5f5
Remove PLAN.md and REVIEW.md
williamjameshandley Oct 15, 2025
923f6ce
Fix ensemble sampler correctness issues and add emcee compatibility f…
williamjameshandley Oct 15, 2025
1032cf9
Fix remaining PyTree bugs and implement nsplits > 2 support
williamjameshandley Oct 15, 2025
7b57761
Refactor ensemble sampler for BlackJAX consistency
williamjameshandley Oct 15, 2025
d5b2745
Merge stretch.py into ensemble.py and rename to stretch.py
williamjameshandley Oct 15, 2025
d54b803
Rename stretch.py back to ensemble.py
williamjameshandley Oct 15, 2025
4503fed
Remove stretch alias and expose ensemble as public API
williamjameshandley Oct 15, 2025
86b10f7
Add Ensemble Slice Sampling (ESS) implementation
williamjameshandley Oct 15, 2025
925f468
Refine ensemble slice implementation and follow house style
williamjameshandley Oct 15, 2025
7921cea
Extract shared ensemble utilities to reduce code duplication
williamjameshandley Oct 15, 2025
fc977e3
Further reduce ensemble code duplication
williamjameshandley Oct 15, 2025
4db5872
Remove unnecessary unshuffle_1d helper
williamjameshandley Oct 15, 2025
8c247e8
Use vmapped_logdensity in stretch._update_half
williamjameshandley Oct 15, 2025
d299817
Merge branch 'main' into ensemble
junpenglao Oct 16, 2025
6f8054d
Merge branch 'main' into ensemble
junpenglao Oct 29, 2025
f0b9e09
Merge branch 'main' into ensemble
junpenglao Nov 15, 2025
18b0dd9
Merge branch 'main' into ensemble
junpenglao Nov 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
from .ensemble import slice as _ensemble_slice
from .ensemble import stretch as _ensemble
from .mcmc import adjusted_mclmc as _adjusted_mclmc
from .mcmc import adjusted_mclmc_dynamic as _adjusted_mclmc_dynamic
from .mcmc import barker
Expand Down Expand Up @@ -120,6 +122,8 @@ def generate_top_level_api_from(module):
elliptical_slice = generate_top_level_api_from(_elliptical_slice)
ghmc = generate_top_level_api_from(_ghmc)
barker_proposal = generate_top_level_api_from(barker)
ensemble = generate_top_level_api_from(_ensemble)
ensemble_slice = generate_top_level_api_from(_ensemble_slice)

hmc_family = [hmc, nuts]

Expand Down
18 changes: 18 additions & 0 deletions blackjax/ensemble/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ensemble sampling methods for BlackJAX."""
from blackjax.ensemble import slice, stretch
from blackjax.ensemble.base import EnsembleInfo, EnsembleState

__all__ = ["EnsembleState", "EnsembleInfo", "stretch", "slice"]
178 changes: 178 additions & 0 deletions blackjax/ensemble/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes and utilities for ensemble sampling methods."""
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp

from blackjax.types import Array, ArrayTree

__all__ = [
"EnsembleState",
"EnsembleInfo",
"get_nwalkers",
"tree_take",
"shuffle_triple",
"unshuffle_triple",
"split_triple",
"concat_triple_groups",
"complementary_triple",
"masked_select",
"vmapped_logdensity",
"prepare_split",
"build_states_from_triples",
]


class EnsembleState(NamedTuple):
"""State of an ensemble sampler.

coords
An array or PyTree of arrays of shape `(n_walkers, ...)` that
stores the current position of the walkers.
log_probs
An array of shape `(n_walkers,)` that stores the log-probability of
each walker.
blobs
An optional PyTree that stores metadata returned by the log-probability
function.
"""

coords: ArrayTree
log_probs: Array
blobs: Optional[ArrayTree] = None


class EnsembleInfo(NamedTuple):
"""Additional information on the ensemble transition.

acceptance_rate
The acceptance rate of the ensemble.
is_accepted
A boolean array of shape `(n_walkers,)` indicating whether each walker's
proposal was accepted.
"""

acceptance_rate: float
is_accepted: Array


def get_nwalkers(coords: ArrayTree) -> int:
"""Get the number of walkers from ensemble coordinates."""
return jax.tree_util.tree_flatten(coords)[0][0].shape[0]


def tree_take(tree: ArrayTree, idx: jnp.ndarray) -> ArrayTree:
"""Index into a PyTree along the leading dimension."""
return jax.tree_util.tree_map(lambda a: a[idx], tree)


def shuffle_triple(key, coords, log_probs, blobs):
"""Shuffle ensemble coordinates, log_probs, and blobs."""
n = get_nwalkers(coords)
idx = jax.random.permutation(key, n)
coords_s = tree_take(coords, idx)
log_probs_s = log_probs[idx]
blobs_s = None if blobs is None else tree_take(blobs, idx)
return coords_s, log_probs_s, blobs_s, idx


def unshuffle_triple(coords, log_probs, blobs, indices):
"""Reverse a shuffle operation on ensemble coordinates, log_probs, and blobs."""
inv = jnp.argsort(indices)
coords_u = tree_take(coords, inv)
log_probs_u = log_probs[inv]
blobs_u = None if blobs is None else tree_take(blobs, inv)
return coords_u, log_probs_u, blobs_u


def split_triple(coords, log_probs, blobs, nsplits):
"""Split ensemble into nsplits contiguous groups."""
n = get_nwalkers(coords)
group_size = n // nsplits
groups = []
for i in range(nsplits):
s = i * group_size
e = (i + 1) * group_size if i < nsplits - 1 else n
coords_i = jax.tree_util.tree_map(lambda a: a[s:e], coords)
log_probs_i = log_probs[s:e]
blobs_i = (
None if blobs is None else jax.tree_util.tree_map(lambda a: a[s:e], blobs)
)
groups.append((coords_i, log_probs_i, blobs_i))
return groups


def concat_triple_groups(group_triples):
"""Concatenate groups of (coords, log_probs, blobs) triples."""
coords_list, logp_list, blobs_list = zip(*group_triples)
coords = jax.tree_util.tree_map(
lambda *xs: jnp.concatenate(xs, axis=0), *coords_list
)
logp = jnp.concatenate(logp_list, axis=0)
blobs = (
None
if all(b is None for b in blobs_list)
else jax.tree_util.tree_map(
lambda *xs: jnp.concatenate(xs, axis=0), *blobs_list
)
)
Comment on lines +125 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: unnest and rewrite it as a regular if...else... block

return coords, logp, blobs


def complementary_triple(groups, i):
"""Build complementary ensemble from all groups except group i."""
return concat_triple_groups([g for j, g in enumerate(groups) if j != i])


def masked_select(mask, new_val, old_val):
"""Select between new and old values based on mask."""
expand_dims = (1,) * (new_val.ndim - 1)
mask_expanded = mask.reshape((mask.shape[0],) + expand_dims)
return jnp.where(mask_expanded, new_val, old_val)


def vmapped_logdensity(logdensity_fn, coords):
"""Evaluate logdensity function on ensemble coordinates with vmap."""
outs = jax.vmap(logdensity_fn)(coords)
return outs if isinstance(outs, tuple) else (outs, None)


def prepare_split(rng_key, coords, log_probs, blobs, randomize_split, nsplits):
"""Prepare ensemble for splitting into groups.

Handles optional randomization, splitting, and returns components needed
for group-wise updates and subsequent unshuffling.
"""
if randomize_split:
key_shuffle, key_update = jax.random.split(rng_key)
coords_s, logp_s, blobs_s, indices = shuffle_triple(
key_shuffle, coords, log_probs, blobs
)
else:
key_update = rng_key
coords_s, logp_s, blobs_s = coords, log_probs, blobs
indices = jnp.arange(get_nwalkers(coords))
group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits)
return key_update, group_triples, indices


def build_states_from_triples(group_triples, state_ctor, extra_fields=()):
"""Build state objects from triples with optional extra fields.

Handles both base EnsembleState and algorithm-specific states like
SliceEnsembleState that have additional fields.
"""
return [state_ctor(t[0], t[1], t[2], *extra_fields) for t in group_triples]
Loading