Skip to content

Commit 34d23e1

Browse files
committed
Work on early tuning schedule
1 parent 47b01f4 commit 34d23e1

File tree

2 files changed

+65
-44
lines changed

2 files changed

+65
-44
lines changed

src/adapt_strategy.rs

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,16 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
181181
potential: &mut Self::Potential,
182182
state: &<Self::Potential as Hamiltonian>::State,
183183
) {
184-
self.exp_variance_draw.set_variance(iter::repeat(1f64));
184+
self.exp_variance_draw.set_variance(iter::repeat(0f64));
185185
self.exp_variance_draw.set_mean(state.q.iter().copied());
186186
self.exp_variance_grad
187187
.set_variance(state.grad.iter().map(|&val| {
188188
let diag = if !self.settings.grad_init {
189189
1f64
190190
} else {
191191
let out = val * val;
192-
let out = out.clamp(LOWER_LIMIT * LOWER_LIMIT, UPPER_LIMIT * UPPER_LIMIT);
193-
if (out == 0f64) | (!out.is_finite()) {
192+
let out = out.clamp(LOWER_LIMIT, UPPER_LIMIT);
193+
if !out.is_finite() {
194194
1f64
195195
} else {
196196
out
@@ -205,12 +205,13 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
205205
self.exp_variance_draw.current(),
206206
self.exp_variance_grad.current(),
207207
)
208-
.map(|(draw, grad)| {
209-
let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
208+
.map(|(_draw, grad)| {
209+
//let val = (1f64 / grad).clamp(LOWER_LIMIT, UPPER_LIMIT);
210+
let val = (1f64 / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
210211
if val.is_finite() {
211-
val
212+
Some(val)
212213
} else {
213-
1f64
214+
Some(1f64)
214215
}
215216
}),
216217
);
@@ -227,25 +228,44 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
227228
return;
228229
}
229230

231+
let is_early = (draw as f64) < self.settings.early_ratio * (self.num_tune as f64);
232+
233+
230234
let count = self.exp_variance_draw_bg.count();
231235

232-
let early_switch = (count == self.settings.early_window_switch_freq)
233-
& (draw < self.settings.window_switch_freq);
236+
let switch_freq = if is_early {
237+
self.settings.early_window_switch_freq
238+
} else {
239+
self.settings.window_switch_freq
240+
};
241+
242+
let variance_decay = if is_early {
243+
self.settings.early_variance_decay
244+
} else {
245+
self.settings.variance_decay
246+
};
234247

235-
if early_switch | ((draw % self.settings.window_switch_freq == 0) & (count > 5)) {
248+
let switch = count >= switch_freq;
249+
250+
if switch {
251+
assert!(count == switch_freq);
236252
self.exp_variance_draw = std::mem::replace(
237253
&mut self.exp_variance_draw_bg,
238-
ExpWeightedVariance::new(self.dim, self.settings.variance_decay, true),
254+
ExpWeightedVariance::new(self.dim, variance_decay, true),
239255
);
240256
self.exp_variance_grad = std::mem::replace(
241257
&mut self.exp_variance_grad_bg,
242-
ExpWeightedVariance::new(self.dim, self.settings.variance_decay, true),
258+
ExpWeightedVariance::new(self.dim, variance_decay, true),
243259
);
244260

245261
self.exp_variance_draw_bg
246262
.set_mean(collector.draw.iter().copied());
263+
self.exp_variance_draw_bg
264+
.set_variance(iter::repeat(0f64));
247265
self.exp_variance_grad_bg
248-
.set_mean(collector.grad.iter().copied());
266+
.set_mean(iter::repeat(0f64));
267+
self.exp_variance_grad_bg
268+
.set_variance(collector.grad.iter().map(|&x| x * x));
249269
} else if collector.is_good {
250270
self.exp_variance_draw
251271
.add_sample(collector.draw.iter().copied());
@@ -257,23 +277,23 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
257277
.add_sample(collector.grad.iter().copied());
258278
}
259279

280+
//if (is_early & (self.exp_variance_draw.count() > 2)) | (!is_early & switch) {
260281
if self.exp_variance_draw.count() > 2 {
261282
assert!(self.exp_variance_draw.count() == self.exp_variance_grad.count());
262-
if (self.settings.grad_init) | (draw > self.settings.window_switch_freq) {
263-
potential.mass_matrix.update_diag(
264-
izip!(
265-
self.exp_variance_draw.current(),
266-
self.exp_variance_grad.current(),
267-
)
268-
.map(|(draw, grad)| {
269-
let mut val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
270-
if !val.is_finite() {
271-
val = 1f64;
272-
}
273-
val
274-
}),
275-
);
276-
}
283+
potential.mass_matrix.update_diag(
284+
izip!(
285+
self.exp_variance_draw.current(),
286+
self.exp_variance_grad.current(),
287+
)
288+
.map(|(draw, grad)| {
289+
let val = (draw / grad).sqrt().clamp(LOWER_LIMIT, UPPER_LIMIT);
290+
if !val.is_finite() {
291+
None
292+
} else {
293+
Some(val)
294+
}
295+
}),
296+
);
277297
}
278298
}
279299

src/mass_matrix.rs

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl DiagMassMatrix {
3333
}
3434
}
3535

36-
pub(crate) fn update_diag(&mut self, new_variance: impl Iterator<Item = f64>) {
36+
pub(crate) fn update_diag(&mut self, new_variance: impl Iterator<Item = Option<f64>>) {
3737
update_diag(&mut self.variance, &mut self.inv_stds, new_variance);
3838
}
3939
}
@@ -44,13 +44,16 @@ impl DiagMassMatrix {
4444
fn update_diag(
4545
variance_out: &mut [f64],
4646
inv_std_out: &mut [f64],
47-
new_variance: impl Iterator<Item = f64>,
47+
new_variance: impl Iterator<Item = Option<f64>>,
4848
) {
49-
izip!(variance_out, inv_std_out, new_variance,).for_each(|(var, inv_std, x)| {
50-
assert!(x.is_finite(), "Illegal value on mass matrix: {}", x);
51-
assert!(x > 0f64, "Illegal value on mass matrix: {}", x);
52-
*var = x;
53-
*inv_std = (1. / x).sqrt();
49+
izip!(variance_out, inv_std_out, new_variance).for_each(|(var, inv_std, x)| {
50+
if let Some(x) = x {
51+
assert!(x.is_finite(), "Illegal value on mass matrix: {}", x);
52+
assert!(x > 0f64, "Illegal value on mass matrix: {}", x);
53+
//assert!(*var != x, "No change in mass matrix from {} to {}", *var, x);
54+
*var = x;
55+
*inv_std = (1. / x).sqrt();
56+
};
5457
});
5558
}
5659

@@ -135,7 +138,6 @@ fn add_sample(self_: &mut ExpWeightedVariance, value: impl Iterator<Item = f64>)
135138
// assert!(x - *mean != 0f64, "var = {}, mean = {}, x = {}, delta = {}, count = {}", var, mean, x, x - *mean, self_.count);
136139
//}
137140
let delta = x - *mean;
138-
//*mean += self_.alpha * delta;
139141
*mean = self_.alpha.mul_add(delta, *mean);
140142
*var = (1f64 - self_.alpha) * (*var + self_.alpha * delta * delta);
141143
},
@@ -165,6 +167,8 @@ pub struct DiagAdaptExpSettings {
165167
/// Switch to a new variance estimator every `window_switch_freq` draws.
166168
pub window_switch_freq: u64,
167169
pub early_window_switch_freq: u64,
170+
/// The fraction of the tuning draws (`early_ratio * num_tune`) that is considered "early"; the `early_*` settings apply during that phase.
171+
pub early_ratio: f64,
168172
pub grad_init: bool,
169173
}
170174

@@ -174,9 +178,10 @@ impl Default for DiagAdaptExpSettings {
174178
variance_decay: 0.02,
175179
final_window: 50,
176180
store_mass_matrix: false,
177-
window_switch_freq: 50,
178-
early_window_switch_freq: 10,
181+
window_switch_freq: 200,
182+
early_window_switch_freq: 20,
179183
early_variance_decay: 0.1,
184+
early_ratio: 0.4,
180185
grad_init: true,
181186
}
182187
}
@@ -201,14 +206,10 @@ impl DrawGradCollector {
201206
impl Collector for DrawGradCollector {
202207
type State = State;
203208

204-
fn register_draw(&mut self, state: &Self::State, info: &crate::nuts::SampleInfo) {
209+
fn register_draw(&mut self, state: &Self::State, _info: &crate::nuts::SampleInfo) {
205210
self.draw.copy_from_slice(&state.q);
206211
self.grad.copy_from_slice(&state.grad);
207212
let idx = state.index_in_trajectory();
208-
if let Some(_) = info.divergence_info {
209-
self.is_good = (idx <= -4) | (idx >= 4);
210-
} else {
211-
self.is_good = idx != 0;
212-
}
213+
self.is_good = idx != 0;
213214
}
214215
}

0 commit comments

Comments
 (0)