Skip to content

Commit 8dc0b71

Browse files
committed
handle variable sequence lengths for cross entropy commitment loss
1 parent e190b8e commit 8dc0b71

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.5.7',
6+
version = '1.5.8',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -602,9 +602,6 @@ def forward(
602602
ein_rhs_eq = 'h b n d' if self.separate_codebook_per_head else '1 (b h) n d'
603603
x = rearrange(x, f'b n (h d) -> {ein_rhs_eq}', h = heads)
604604

605-
if exists(mask):
606-
mask = repeat(mask, 'b n -> h b n', h = heads)
607-
608605
# quantize
609606

610607
quantize, embed_ind, distances = self._codebook(x)
@@ -657,13 +654,23 @@ def calculate_ce_loss(codes):
657654
if self.training:
658655
if self.commitment_weight > 0:
659656
if self.commitment_use_cross_entropy_loss:
657+
if exists(mask):
658+
if is_multiheaded:
659+
mask = repeat(mask, 'b n -> b n h', h = heads)
660+
661+
embed_ind.masked_fill_(~mask, -1)
662+
660663
commit_loss = calculate_ce_loss(embed_ind)
661664
else:
662665
detached_quantize = quantize.detach()
663666

664667
if exists(mask):
665668
# with variable lengthed sequences
666669
commit_loss = F.mse_loss(detached_quantize, x, reduction = 'none')
670+
671+
if is_multiheaded:
672+
mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
673+
667674
commit_loss = commit_loss[mask].mean()
668675
else:
669676
commit_loss = F.mse_loss(detached_quantize, x)

0 commit comments

Comments
 (0)