Skip to content

Commit f0a3741

Browse files
committed
Use diverging draws for mass matrix adapt
1 parent 52e303a commit f0a3741

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "nuts-rs"
3-
version = "0.3.0"
3+
version = "0.4.0"
44
authors = ["Adrian Seyboldt <adrian.seyboldt@gmail.com>"]
55
edition = "2021"
66
license = "MIT"

src/mass_matrix.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,14 @@ impl DrawGradCollector {
143143
impl Collector for DrawGradCollector {
144144
type State = State;
145145

146-
fn register_draw(&mut self, state: &Self::State, _info: &crate::nuts::SampleInfo) {
146+
fn register_draw(&mut self, state: &Self::State, info: &crate::nuts::SampleInfo) {
147147
self.draw.copy_from_slice(&state.q);
148148
self.grad.copy_from_slice(&state.grad);
149149
let idx = state.index_in_trajectory();
150-
self.is_good = _info.divergence_info.is_none() & (idx != 0);
150+
if info.divergence_info.is_some() {
151+
self.is_good = idx.abs() > 4;
152+
} else {
153+
self.is_good = idx != 0;
154+
}
151155
}
152156
}

0 commit comments

Comments
 (0)