@@ -352,8 +352,9 @@ def update_affine(self, data, embed):
352352
353353 data , embed = map (lambda t : rearrange (t , 'h ... d -> h (...) d' ), (data , embed ))
354354
355- self .update_with_decay ('codebook_mean' , reduce (embed , 'h n d -> h 1 d' , 'mean' ), self .affine_param_codebook_decay )
356- self .update_with_decay ('codebook_variance' , reduce (embed , 'h n d -> h 1 d' , var_fn ), self .affine_param_codebook_decay )
355+ if self .training :
356+ self .update_with_decay ('codebook_mean' , reduce (embed , 'h n d -> h 1 d' , 'mean' ), self .affine_param_codebook_decay )
357+ self .update_with_decay ('codebook_variance' , reduce (embed , 'h n d -> h 1 d' , var_fn ), self .affine_param_codebook_decay )
357358
358359 if not self .sync_affine_param :
359360 self .update_with_decay ('batch_mean' , reduce (data , 'h n d -> h 1 d' , 'mean' ), self .affine_param_batch_decay )
@@ -427,7 +428,7 @@ def forward(
427428
428429 self .init_embed_ (flatten )
429430
430- if self .affine_param and self . training :
431+ if self .affine_param :
431432 self .update_affine (flatten , self .embed )
432433
433434 embed = self .embed if self .learnable_codebook else self .embed .detach ()
0 commit comments