
Commit ff9363f

when doing ema update in the presence of masked tokens, make sure to omit their contribution

1 parent 1e4dcc9

File tree: 2 files changed, +9 -1 lines changed


setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.25',
+  version = '1.6.26',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 0 deletions

@@ -469,6 +469,10 @@ def forward(
     if self.affine_param:
         flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

+    if exists(mask):
+        mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
+        embed_onehot[~mask] = 0.
+
     cluster_size = embed_onehot.sum(dim = 1)

     self.all_reduce_fn(cluster_size)

@@ -627,6 +631,10 @@ def forward(
     quantize = batched_embedding(embed_ind, embed)

     if self.training and self.ema_update:
+        if exists(mask):
+            mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
+            embed_onehot[~mask] = 0.
+
         bins = embed_onehot.sum(dim = 1)
         self.all_reduce_fn(bins)

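The change above zeroes out the one-hot code assignments of masked tokens before the EMA statistics are accumulated, so padded positions neither inflate cluster counts nor pull code embeddings toward padding content. Below is a minimal standalone sketch of that pattern; the function name, the exact shapes, and the einsum for the per-code sums are assumptions chosen for illustration, not the library's exact code.

import torch
from einops import repeat

def masked_ema_stats(flatten, embed_onehot, mask = None):
    # flatten:      (heads, b * n, dim)            - flattened inputs per codebook head (assumed layout)
    # embed_onehot: (heads, b * n, codebook_size)  - one-hot code assignments
    # mask:         (b, n) bool, True = keep       - optional token mask
    if mask is not None:
        # broadcast the (b, n) mask across heads to match the flattened layout,
        # then zero the one-hot rows of masked tokens so they contribute nothing
        mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
        embed_onehot = embed_onehot.masked_fill(~mask.unsqueeze(-1), 0.)

    # per-code counts and per-code input sums now reflect only unmasked tokens
    cluster_size = embed_onehot.sum(dim = 1)                                  # (heads, codebook_size)
    embed_sum = torch.einsum('h n d, h n c -> h c d', flatten, embed_onehot)  # (heads, codebook_size, dim)
    return cluster_size, embed_sum

As the commit message notes, these counts and sums feed the exponential-moving-average update of the codebook, which is why the masking has to happen before the reductions (`cluster_size` / `bins` and the per-code sums) rather than after the EMA step.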