Commit 8e847d1

address yet more edge cases with masked tokens, losses, and kmeans init
1 parent ff9363f commit 8e847d1

File tree: setup.py · vector_quantize_pytorch/vector_quantize_pytorch.py

2 files changed: +41 −15 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.26',
+  version = '1.6.28',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 40 additions & 14 deletions
@@ -310,10 +310,14 @@ def __init__(
         self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
 
     @torch.jit.ignore
-    def init_embed_(self, data):
+    def init_embed_(self, data, mask = None):
         if self.initted:
            return
 
+        if exists(mask):
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
+
         embed, cluster_size = kmeans(
             data,
             self.codebook_size,
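
Note: the masked path added here leans on a standard einops idiom: boolean indexing with data[mask] flattens the kept vectors across codebooks, and rearrange(..., '(c n) d -> c n d', c = c) regroups them so kmeans only ever sees unmasked tokens. A minimal standalone sketch of that idiom (tensor names and sizes below are invented for illustration, not taken from the library):

# illustrative sketch of the masked-select + regroup pattern used in init_embed_
# assumes every codebook keeps the same number of tokens (the mask is shared across codebooks)
import torch
from einops import rearrange

c, n, d = 2, 6, 4                         # codebooks, tokens per codebook, feature dim
data = torch.randn(c, n, d)
mask = torch.tensor([True, True, False, True, False, True])
mask = mask.unsqueeze(0).expand(c, -1)    # same keep/drop decision for every codebook

kept = rearrange(data[mask], '(c n) d -> c n d', c = data.shape[0])
print(kept.shape)                         # torch.Size([2, 4, 4]): only unmasked tokens reach kmeans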
@@ -363,9 +367,8 @@ def update_affine(self, data, embed, mask = None):
         data = rearrange(data, 'h ... d -> h (...) d')
 
         if exists(mask):
-            h = data.shape[0]
-            mask = repeat(mask, 'b n -> h (b n)', h = h)
-            data = rearrange(data[mask], '(h n) d -> h n d', h = h)
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
 
         # calculate batch mean and variance

@@ -440,7 +443,10 @@ def forward(
         dtype = x.dtype
         flatten, ps = pack_one(x, 'h * d')
 
-        self.init_embed_(flatten)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        self.init_embed_(flatten, mask = mask)
 
         if self.affine_param:
             self.update_affine(flatten, self.embed, mask = mask)
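
Note: because pack_one(x, 'h * d') leaves flatten shaped (codebooks, batch · heads · seq, dim), the (batch, seq) mask has to be broadcast to that flattened token axis before init_embed_ and update_affine can index with it. A standalone sketch of the shape arithmetic, with sizes invented for illustration:

# illustrative sketch: broadcasting a (batch, seq) mask to the packed (codebooks, batch*heads*seq) view
import torch
from einops import repeat

b, n, heads, codebooks, dim = 2, 5, 3, 1, 4
mask = torch.rand(b, n) > 0.2                          # True = real token, False = padding
flatten = torch.randn(codebooks, b * heads * n, dim)   # stand-in for the packed codebook input

mask_flat = repeat(mask, 'b n -> c (b h n)',
                   c = flatten.shape[0],
                   h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))

print(mask_flat.shape)           # torch.Size([1, 30]): lines up with flatten's token axis
print(flatten[mask_flat].shape)  # (kept_tokens, 4) after boolean indexing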
@@ -470,7 +476,6 @@ def forward(
                 flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
 
             if exists(mask):
-                mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
                 embed_onehot[~mask] = 0.
 
             cluster_size = embed_onehot.sum(dim = 1)
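
Note: with the mask already flattened once in forward, this EMA branch no longer needs its own repeat; it simply zeroes the one-hot assignments of masked positions so padding contributes nothing to the cluster statistics. A toy illustration with invented shapes:

# toy illustration: masked tokens are removed from the EMA cluster statistics
import torch
import torch.nn.functional as F

codebooks, tokens, codebook_size = 1, 6, 4
assignments = torch.randint(0, codebook_size, (codebooks, tokens))
embed_onehot = F.one_hot(assignments, codebook_size).float()
mask = torch.tensor([[True, True, True, True, False, False]])  # last two tokens are padding

embed_onehot[~mask] = 0.
cluster_size = embed_onehot.sum(dim = 1)   # padded tokens add nothing to any cluster
print(cluster_size.sum())                  # 4., the number of real tokens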
@@ -552,10 +557,14 @@ def __init__(
         self.register_buffer('embed', embed)
 
     @torch.jit.ignore
-    def init_embed_(self, data):
+    def init_embed_(self, data, mask = None):
         if self.initted:
            return
 
+        if exists(mask):
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
+
         embed, cluster_size = kmeans(
             data,
             self.codebook_size,
@@ -615,7 +624,10 @@ def forward(
 
         flatten, ps = pack_one(x, 'h * d')
 
-        self.init_embed_(flatten)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        self.init_embed_(flatten, mask = mask)
 
         embed = self.embed if self.learnable_codebook else self.embed.detach()

@@ -632,7 +644,6 @@ def forward(
 
         if self.training and self.ema_update:
             if exists(mask):
-                mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
                 embed_onehot[~mask] = 0.
 
             bins = embed_onehot.sum(dim = 1)
@@ -856,7 +867,20 @@ def forward(
         # one step in-place update
 
         if should_inplace_optimize and self.training:
-            F.mse_loss(quantize, x.detach()).backward()
+
+            if exists(mask):
+                loss = F.mse_loss(quantize, x.detach(), reduction = 'none')
+
+                loss_mask = mask
+                if is_multiheaded:
+                    loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0])
+
+                loss = loss[loss_mask].mean()
+
+            else:
+                loss = F.mse_loss(quantize, x.detach())
+
+            loss.backward()
             self.in_place_codebook_optimizer.step()
             self.in_place_codebook_optimizer.zero_grad()

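Note: the in-place codebook optimizer now follows the same masked-loss recipe as the commitment loss further down: compute the element-wise MSE with reduction = 'none', then average only over unmasked positions, so padded tokens no longer pull the codebook toward padding vectors. A minimal sketch of the single-headed case, with tensors invented for illustration:

# minimal sketch of the masked MSE reduction (single-headed case, made-up shapes)
import torch
import torch.nn.functional as F

quantize = torch.randn(2, 5, 4)            # stand-in for the quantized output
x = torch.randn(2, 5, 4)                   # stand-in for the input
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, False, False, False]])

loss = F.mse_loss(quantize, x.detach(), reduction = 'none')  # per-element loss, shape (2, 5, 4)
loss = loss[mask].mean()                                     # mean over real positions only
print(loss)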
@@ -924,21 +948,23 @@ def calculate_ce_loss(codes):
             if self.commitment_weight > 0:
                 if self.commitment_use_cross_entropy_loss:
                     if exists(mask):
+                        ce_loss_mask = mask
                         if is_multiheaded:
-                            mask_with_heads = repeat(mask, 'b n -> b n h', h = heads)
+                            ce_loss_mask = repeat(ce_loss_mask, 'b n -> b n h', h = heads)
 
-                        embed_ind.masked_fill_(~mask_with_heads, -1)
+                        embed_ind.masked_fill_(~ce_loss_mask, -1)
 
                     commit_loss = calculate_ce_loss(embed_ind)
                 else:
                     if exists(mask):
                         # with variable lengthed sequences
                         commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')
 
+                        loss_mask = mask
                         if is_multiheaded:
-                            mask_with_heads = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                            loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
 
-                            commit_loss = commit_loss[mask_with_heads].mean()
+                        commit_loss = commit_loss[loss_mask].mean()
                     else:
                         commit_loss = F.mse_loss(commit_quantize, x)

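Note: taken together, these changes make a boolean padding mask flow through every stage: kmeans initialization, the affine and EMA statistics, the in-place codebook update, and both commitment-loss variants. A hedged usage sketch, assuming the VectorQuantize constructor and the mask keyword on forward as documented in the repository README (check the source at this version for the exact signature):

# hedged usage sketch: argument names follow the library's documented interface,
# but verify against the README / source at this version
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(1, 1024, 256)
mask = torch.ones(1, 1024, dtype = torch.bool)
mask[:, 768:] = False                          # treat the tail of the sequence as padding

quantized, indices, commit_loss = vq(x, mask = mask)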