Skip to content

Commit 1fcd6b2

Browse files
committed
want to update batch stats during eval
1 parent 6d415ef commit 1fcd6b2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,9 @@ def update_affine(self, data, embed):
352352

353353
data, embed = map(lambda t: rearrange(t, 'h ... d -> h (...) d'), (data, embed))
354354

355-
self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
356-
self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)
355+
if self.training:
356+
self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
357+
self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)
357358

358359
if not self.sync_affine_param:
359360
self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
@@ -427,7 +428,7 @@ def forward(
427428

428429
self.init_embed_(flatten)
429430

430-
if self.affine_param and self.training:
431+
if self.affine_param:
431432
self.update_affine(flatten, self.embed)
432433

433434
embed = self.embed if self.learnable_codebook else self.embed.detach()

0 commit comments

Comments
 (0)