-
Notifications
You must be signed in to change notification settings - Fork 123
Add ensemble sampling methods to BlackJAX #797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
williamjameshandley
wants to merge
18
commits into
blackjax-devs:main
Choose a base branch
from
handley-lab:ensemble
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
c1b3d8f
Implement emcee stretch move ensemble sampler
williamjameshandley 526e5f5
Remove PLAN.md and REVIEW.md
williamjameshandley 923f6ce
Fix ensemble sampler correctness issues and add emcee compatibility f…
williamjameshandley 1032cf9
Fix remaining PyTree bugs and implement nsplits > 2 support
williamjameshandley 7b57761
Refactor ensemble sampler for BlackJAX consistency
williamjameshandley d5b2745
Merge stretch.py into ensemble.py and rename to stretch.py
williamjameshandley d54b803
Rename stretch.py back to ensemble.py
williamjameshandley 4503fed
Remove stretch alias and expose ensemble as public API
williamjameshandley 86b10f7
Add Ensemble Slice Sampling (ESS) implementation
williamjameshandley 925f468
Refine ensemble slice implementation and follow house style
williamjameshandley 7921cea
Extract shared ensemble utilities to reduce code duplication
williamjameshandley fc977e3
Further reduce ensemble code duplication
williamjameshandley 4db5872
Remove unnecessary unshuffle_1d helper
williamjameshandley 8c247e8
Use vmapped_logdensity in stretch._update_half
williamjameshandley d299817
Merge branch 'main' into ensemble
junpenglao 6f8054d
Merge branch 'main' into ensemble
junpenglao f0b9e09
Merge branch 'main' into ensemble
junpenglao 18b0dd9
Merge branch 'main' into ensemble
junpenglao File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,18 @@ | ||
| # Copyright 2020- The Blackjax Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Ensemble sampling methods for BlackJAX.""" | ||
| from blackjax.ensemble import slice, stretch | ||
| from blackjax.ensemble.base import EnsembleInfo, EnsembleState | ||
|
|
||
| __all__ = ["EnsembleState", "EnsembleInfo", "stretch", "slice"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,178 @@ | ||
| # Copyright 2020- The Blackjax Authors. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Base classes and utilities for ensemble sampling methods.""" | ||
| from typing import NamedTuple, Optional | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| from blackjax.types import Array, ArrayTree | ||
|
|
||
| __all__ = [ | ||
| "EnsembleState", | ||
| "EnsembleInfo", | ||
| "get_nwalkers", | ||
| "tree_take", | ||
| "shuffle_triple", | ||
| "unshuffle_triple", | ||
| "split_triple", | ||
| "concat_triple_groups", | ||
| "complementary_triple", | ||
| "masked_select", | ||
| "vmapped_logdensity", | ||
| "prepare_split", | ||
| "build_states_from_triples", | ||
| ] | ||
|
|
||
|
|
||
| class EnsembleState(NamedTuple): | ||
| """State of an ensemble sampler. | ||
|
|
||
| coords | ||
| An array or PyTree of arrays of shape `(n_walkers, ...)` that | ||
| stores the current position of the walkers. | ||
| log_probs | ||
| An array of shape `(n_walkers,)` that stores the log-probability of | ||
| each walker. | ||
| blobs | ||
| An optional PyTree that stores metadata returned by the log-probability | ||
| function. | ||
| """ | ||
|
|
||
| coords: ArrayTree | ||
| log_probs: Array | ||
| blobs: Optional[ArrayTree] = None | ||
|
|
||
|
|
||
| class EnsembleInfo(NamedTuple): | ||
| """Additional information on the ensemble transition. | ||
|
|
||
| acceptance_rate | ||
| The acceptance rate of the ensemble. | ||
| is_accepted | ||
| A boolean array of shape `(n_walkers,)` indicating whether each walker's | ||
| proposal was accepted. | ||
| """ | ||
|
|
||
| acceptance_rate: float | ||
| is_accepted: Array | ||
|
|
||
|
|
||
| def get_nwalkers(coords: ArrayTree) -> int: | ||
| """Get the number of walkers from ensemble coordinates.""" | ||
| return jax.tree_util.tree_flatten(coords)[0][0].shape[0] | ||
|
|
||
|
|
||
| def tree_take(tree: ArrayTree, idx: jnp.ndarray) -> ArrayTree: | ||
| """Index into a PyTree along the leading dimension.""" | ||
| return jax.tree_util.tree_map(lambda a: a[idx], tree) | ||
|
|
||
|
|
||
| def shuffle_triple(key, coords, log_probs, blobs): | ||
| """Shuffle ensemble coordinates, log_probs, and blobs.""" | ||
| n = get_nwalkers(coords) | ||
| idx = jax.random.permutation(key, n) | ||
| coords_s = tree_take(coords, idx) | ||
| log_probs_s = log_probs[idx] | ||
| blobs_s = None if blobs is None else tree_take(blobs, idx) | ||
| return coords_s, log_probs_s, blobs_s, idx | ||
|
|
||
|
|
||
| def unshuffle_triple(coords, log_probs, blobs, indices): | ||
| """Reverse a shuffle operation on ensemble coordinates, log_probs, and blobs.""" | ||
| inv = jnp.argsort(indices) | ||
| coords_u = tree_take(coords, inv) | ||
| log_probs_u = log_probs[inv] | ||
| blobs_u = None if blobs is None else tree_take(blobs, inv) | ||
| return coords_u, log_probs_u, blobs_u | ||
|
|
||
|
|
||
| def split_triple(coords, log_probs, blobs, nsplits): | ||
| """Split ensemble into nsplits contiguous groups.""" | ||
| n = get_nwalkers(coords) | ||
| group_size = n // nsplits | ||
| groups = [] | ||
| for i in range(nsplits): | ||
| s = i * group_size | ||
| e = (i + 1) * group_size if i < nsplits - 1 else n | ||
| coords_i = jax.tree_util.tree_map(lambda a: a[s:e], coords) | ||
| log_probs_i = log_probs[s:e] | ||
| blobs_i = ( | ||
| None if blobs is None else jax.tree_util.tree_map(lambda a: a[s:e], blobs) | ||
| ) | ||
| groups.append((coords_i, log_probs_i, blobs_i)) | ||
| return groups | ||
|
|
||
|
|
||
| def concat_triple_groups(group_triples): | ||
| """Concatenate groups of (coords, log_probs, blobs) triples.""" | ||
| coords_list, logp_list, blobs_list = zip(*group_triples) | ||
| coords = jax.tree_util.tree_map( | ||
| lambda *xs: jnp.concatenate(xs, axis=0), *coords_list | ||
| ) | ||
| logp = jnp.concatenate(logp_list, axis=0) | ||
| blobs = ( | ||
| None | ||
| if all(b is None for b in blobs_list) | ||
| else jax.tree_util.tree_map( | ||
| lambda *xs: jnp.concatenate(xs, axis=0), *blobs_list | ||
| ) | ||
| ) | ||
| return coords, logp, blobs | ||
|
|
||
|
|
||
| def complementary_triple(groups, i): | ||
| """Build complementary ensemble from all groups except group i.""" | ||
| return concat_triple_groups([g for j, g in enumerate(groups) if j != i]) | ||
|
|
||
|
|
||
| def masked_select(mask, new_val, old_val): | ||
| """Select between new and old values based on mask.""" | ||
| expand_dims = (1,) * (new_val.ndim - 1) | ||
| mask_expanded = mask.reshape((mask.shape[0],) + expand_dims) | ||
| return jnp.where(mask_expanded, new_val, old_val) | ||
|
|
||
|
|
||
| def vmapped_logdensity(logdensity_fn, coords): | ||
| """Evaluate logdensity function on ensemble coordinates with vmap.""" | ||
| outs = jax.vmap(logdensity_fn)(coords) | ||
| return outs if isinstance(outs, tuple) else (outs, None) | ||
|
|
||
|
|
||
| def prepare_split(rng_key, coords, log_probs, blobs, randomize_split, nsplits): | ||
| """Prepare ensemble for splitting into groups. | ||
|
|
||
| Handles optional randomization, splitting, and returns components needed | ||
| for group-wise updates and subsequent unshuffling. | ||
| """ | ||
| if randomize_split: | ||
| key_shuffle, key_update = jax.random.split(rng_key) | ||
| coords_s, logp_s, blobs_s, indices = shuffle_triple( | ||
| key_shuffle, coords, log_probs, blobs | ||
| ) | ||
| else: | ||
| key_update = rng_key | ||
| coords_s, logp_s, blobs_s = coords, log_probs, blobs | ||
| indices = jnp.arange(get_nwalkers(coords)) | ||
| group_triples = split_triple(coords_s, logp_s, blobs_s, nsplits) | ||
| return key_update, group_triples, indices | ||
|
|
||
|
|
||
| def build_states_from_triples(group_triples, state_ctor, extra_fields=()): | ||
| """Build state objects from triples with optional extra fields. | ||
|
|
||
| Handles both base EnsembleState and algorithm-specific states like | ||
| SliceEnsembleState that have additional fields. | ||
| """ | ||
| return [state_ctor(t[0], t[1], t[2], *extra_fields) for t in group_triples] | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: unnest and rewrite it as a regular if...else... block