@@ -181,16 +181,16 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
181181 potential : & mut Self :: Potential ,
182182 state : & <Self :: Potential as Hamiltonian >:: State ,
183183 ) {
184- self . exp_variance_draw . set_variance ( iter:: repeat ( 1f64 ) ) ;
184+ self . exp_variance_draw . set_variance ( iter:: repeat ( 0f64 ) ) ;
185185 self . exp_variance_draw . set_mean ( state. q . iter ( ) . copied ( ) ) ;
186186 self . exp_variance_grad
187187 . set_variance ( state. grad . iter ( ) . map ( |& val| {
188188 let diag = if !self . settings . grad_init {
189189 1f64
190190 } else {
191191 let out = val * val;
192- let out = out. clamp ( LOWER_LIMIT * LOWER_LIMIT , UPPER_LIMIT * UPPER_LIMIT ) ;
193- if ( out == 0f64 ) | ( !out. is_finite ( ) ) {
192+ let out = out. clamp ( LOWER_LIMIT , UPPER_LIMIT ) ;
193+ if !out. is_finite ( ) {
194194 1f64
195195 } else {
196196 out
@@ -205,12 +205,13 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
205205 self . exp_variance_draw. current( ) ,
206206 self . exp_variance_grad. current( ) ,
207207 )
208- . map ( |( draw, grad) | {
209- let val = ( draw / grad) . sqrt ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) ;
208+ . map ( |( _draw, grad) | {
209+ //let val = (1f64 / grad).clamp(LOWER_LIMIT, UPPER_LIMIT);
210+ let val = ( 1f64 / grad) . sqrt ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) ;
210211 if val. is_finite ( ) {
211- val
212+ Some ( val)
212213 } else {
213- 1f64
214+ Some ( 1f64 )
214215 }
215216 } ) ,
216217 ) ;
@@ -227,25 +228,44 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
227228 return ;
228229 }
229230
231+ let is_early = ( draw as f64 ) < self . settings . early_ratio * ( self . num_tune as f64 ) ;
232+
233+
230234 let count = self . exp_variance_draw_bg . count ( ) ;
231235
232- let early_switch = ( count == self . settings . early_window_switch_freq )
233- & ( draw < self . settings . window_switch_freq ) ;
236+ let switch_freq = if is_early {
237+ self . settings . early_window_switch_freq
238+ } else {
239+ self . settings . window_switch_freq
240+ } ;
241+
242+ let variance_decay = if is_early {
243+ self . settings . early_variance_decay
244+ } else {
245+ self . settings . variance_decay
246+ } ;
234247
235- if early_switch | ( ( draw % self . settings . window_switch_freq == 0 ) & ( count > 5 ) ) {
248+ let switch = count >= switch_freq;
249+
250+ if switch {
251+ assert ! ( count == switch_freq) ;
236252 self . exp_variance_draw = std:: mem:: replace (
237253 & mut self . exp_variance_draw_bg ,
238- ExpWeightedVariance :: new ( self . dim , self . settings . variance_decay , true ) ,
254+ ExpWeightedVariance :: new ( self . dim , variance_decay, true ) ,
239255 ) ;
240256 self . exp_variance_grad = std:: mem:: replace (
241257 & mut self . exp_variance_grad_bg ,
242- ExpWeightedVariance :: new ( self . dim , self . settings . variance_decay , true ) ,
258+ ExpWeightedVariance :: new ( self . dim , variance_decay, true ) ,
243259 ) ;
244260
245261 self . exp_variance_draw_bg
246262 . set_mean ( collector. draw . iter ( ) . copied ( ) ) ;
263+ self . exp_variance_draw_bg
264+ . set_variance ( iter:: repeat ( 0f64 ) ) ;
247265 self . exp_variance_grad_bg
248- . set_mean ( collector. grad . iter ( ) . copied ( ) ) ;
266+ . set_mean ( iter:: repeat ( 0f64 ) ) ;
267+ self . exp_variance_grad_bg
268+ . set_variance ( collector. grad . iter ( ) . map ( |& x| x * x) ) ;
249269 } else if collector. is_good {
250270 self . exp_variance_draw
251271 . add_sample ( collector. draw . iter ( ) . copied ( ) ) ;
@@ -257,23 +277,23 @@ impl<F: CpuLogpFunc> AdaptStrategy for ExpWindowDiagAdapt<F> {
257277 . add_sample ( collector. grad . iter ( ) . copied ( ) ) ;
258278 }
259279
280+ //if (is_early & (self.exp_variance_draw.count() > 2)) | (!is_early & switch) {
260281 if self . exp_variance_draw . count ( ) > 2 {
261282 assert ! ( self . exp_variance_draw. count( ) == self . exp_variance_grad. count( ) ) ;
262- if ( self . settings . grad_init ) | ( draw > self . settings . window_switch_freq ) {
263- potential. mass_matrix . update_diag (
264- izip ! (
265- self . exp_variance_draw. current( ) ,
266- self . exp_variance_grad. current( ) ,
267- )
268- . map ( |( draw, grad) | {
269- let mut val = ( draw / grad) . sqrt ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) ;
270- if !val. is_finite ( ) {
271- val = 1f64 ;
272- }
273- val
274- } ) ,
275- ) ;
276- }
283+ potential. mass_matrix . update_diag (
284+ izip ! (
285+ self . exp_variance_draw. current( ) ,
286+ self . exp_variance_grad. current( ) ,
287+ )
288+ . map ( |( draw, grad) | {
289+ let val = ( draw / grad) . sqrt ( ) . clamp ( LOWER_LIMIT , UPPER_LIMIT ) ;
290+ if !val. is_finite ( ) {
291+ None
292+ } else {
293+ Some ( val)
294+ }
295+ } ) ,
296+ ) ;
277297 }
278298 }
279299
0 commit comments