File tree Expand file tree Collapse file tree 2 files changed +9
-1
lines changed
Expand file tree Collapse file tree 2 files changed +9
-1
lines changed Original file line number Diff line number Diff line change 33setup (
44 name = 'vector_quantize_pytorch' ,
55 packages = find_packages (),
6- version = '1.6.25 ' ,
6+ version = '1.6.26 ' ,
77 license = 'MIT' ,
88 description = 'Vector Quantization - Pytorch' ,
99 long_description_content_type = 'text/markdown' ,
Original file line number Diff line number Diff line change @@ -469,6 +469,10 @@ def forward(
469469 if self .affine_param :
470470 flatten = (flatten - self .batch_mean ) * (codebook_std / batch_std ) + self .codebook_mean
471471
472+ if exists (mask ):
473+ mask = repeat (mask , 'b n -> h (b n)' , h = flatten .shape [0 ])
474+ embed_onehot [~ mask ] = 0.
475+
472476 cluster_size = embed_onehot .sum (dim = 1 )
473477
474478 self .all_reduce_fn (cluster_size )
@@ -627,6 +631,10 @@ def forward(
627631 quantize = batched_embedding (embed_ind , embed )
628632
629633 if self .training and self .ema_update :
634+ if exists (mask ):
635+ mask = repeat (mask , 'b n -> h (b n)' , h = flatten .shape [0 ])
636+ embed_onehot [~ mask ] = 0.
637+
630638 bins = embed_onehot .sum (dim = 1 )
631639 self .all_reduce_fn (bins )
632640
You can’t perform that action at this time.
0 commit comments