Skip to content

Commit db1c1d9

Browse files
authored
Fix updating moving average of embed
1 parent ecf2f7c commit db1c1d9

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ def forward(self, x):
295295

296296
embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
297297
self.all_reduce_fn(embed_sum.contiguous())
298+
ema_inplace(self.embed_avg, embed_sum, self.decay)
298299

299300
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
300301

0 commit comments

Comments
 (0)