
Commit 1e4dcc9

handle some more edge cases with variable-length sequences
1 parent 5a36b94 commit 1e4dcc9

2 files changed: 48 additions, 13 deletions

2 files changed

+48
-13
lines changed

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.24',
+  version = '1.6.25',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 47 additions & 12 deletions
@@ -345,17 +345,30 @@ def update_with_decay(self, buffer_name, new_value, decay):
         self.register_buffer(buffer_name, value)

     @torch.jit.ignore
-    def update_affine(self, data, embed):
+    def update_affine(self, data, embed, mask = None):
         assert self.affine_param

         var_fn = partial(torch.var, unbiased = False)

-        data, embed = map(lambda t: rearrange(t, 'h ... d -> h (...) d'), (data, embed))
+        # calculate codebook mean and variance
+
+        embed = rearrange(embed, 'h ... d -> h (...) d')

         if self.training:
             self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
             self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)

+        # prepare batch data, which depends on whether it has masking
+
+        data = rearrange(data, 'h ... d -> h (...) d')
+
+        if exists(mask):
+            h = data.shape[0]
+            mask = repeat(mask, 'b n -> h (b n)', h = h)
+            data = rearrange(data[mask], '(h n) d -> h n d', h = h)
+
+        # calculate batch mean and variance
+
         if not self.sync_affine_param:
             self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
             self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay)
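
For reference, a minimal sketch of the masking step introduced in `update_affine` above: the `(batch, seq)` boolean mask is broadcast across codebook heads and used to drop padded positions before the batch statistics are computed. The shapes and lengths below are illustrative assumptions, not values from the repository.

import torch
from einops import rearrange, repeat

# illustrative shapes (assumption): 2 heads, batch of 3, seq len 5, feature dim 4
h, b, n, d = 2, 3, 5, 4

data = torch.randn(h, b, n, d)                        # per-head batch data, as passed to update_affine
lengths = torch.tensor([5, 3, 2])                     # variable sequence lengths
mask = torch.arange(n)[None, :] < lengths[:, None]    # (b, n), True for real tokens

# flatten all but the head and feature dims, mirroring 'h ... d -> h (...) d'
data = rearrange(data, 'h ... d -> h (...) d')        # (h, b*n, d)

# broadcast the mask over heads, keep only the valid rows, then restore the head dim
# (every head keeps the same number of valid positions, so the reshape is safe)
mask_h = repeat(mask, 'b n -> h (b n)', h = h)        # (h, b*n)
data = rearrange(data[mask_h], '(h n) d -> h n d', h = h)

# statistics now ignore padding; in the commit these feed update_with_decay
batch_mean = data.mean(dim = 1, keepdim = True)                   # (h, 1, d)
batch_var  = data.var(dim = 1, unbiased = False, keepdim = True)  # (h, 1, d)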
@@ -413,7 +426,8 @@ def expire_codes_(self, batch_samples):
     def forward(
         self,
         x,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        mask = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -429,7 +443,7 @@ def forward(
             self.init_embed_(flatten)

         if self.affine_param:
-            self.update_affine(flatten, self.embed)
+            self.update_affine(flatten, self.embed, mask = mask)

         embed = self.embed if self.learnable_codebook else self.embed.detach()

@@ -582,7 +596,8 @@ def expire_codes_(self, batch_samples):
     def forward(
         self,
         x,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        mask = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -783,9 +798,12 @@ def forward(
         mask = None,
         sample_codebook_temp = None
     ):
+        orig_input = x
+
         only_one = x.ndim == 2

         if only_one:
+            assert not exists(mask)
             x = rearrange(x, 'b d -> b 1 d')

         shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
@@ -816,9 +834,16 @@ def forward(

         x = self._codebook.transform_input(x)

+        # codebook forward kwargs
+
+        codebook_forward_kwargs = dict(
+            sample_codebook_temp = sample_codebook_temp,
+            mask = mask
+        )
+
         # quantize

-        quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)
+        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

         # one step in-place update

@@ -827,8 +852,9 @@ def forward(
             self.in_place_codebook_optimizer.step()
             self.in_place_codebook_optimizer.zero_grad()

-            # Quantize again
-            quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)
+            # quantize again
+
+            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

         if self.training:
             # determine code to use for commitment loss
@@ -891,9 +917,9 @@ def calculate_ce_loss(codes):
                if self.commitment_use_cross_entropy_loss:
                    if exists(mask):
                        if is_multiheaded:
-                            mask = repeat(mask, 'b n -> b n h', h = heads)
+                            mask_with_heads = repeat(mask, 'b n -> b n h', h = heads)

-                        embed_ind.masked_fill_(~mask, -1)
+                        embed_ind.masked_fill_(~mask_with_heads, -1)

                        commit_loss = calculate_ce_loss(embed_ind)
                    else:
@@ -902,9 +928,9 @@ def calculate_ce_loss(codes):
                        commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')

                        if is_multiheaded:
-                            mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                            mask_with_heads = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])

-                        commit_loss = commit_loss[mask].mean()
+                        commit_loss = commit_loss[mask_with_heads].mean()
                    else:
                        commit_loss = F.mse_loss(commit_quantize, x)
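
The two hunks above stop overwriting `mask` in place once heads are involved, so the original `(batch, seq)` mask stays available for the steps that follow. The masked-loss patterns themselves are standard PyTorch; here is a minimal sketch, where the use of `ignore_index = -1` is an assumption about what `calculate_ce_loss` does with the `-1` fill rather than something shown in this diff, and all shapes are illustrative.

import torch
import torch.nn.functional as F

b, n, num_codes, d = 2, 6, 16, 4
logits    = torch.randn(b, n, num_codes)          # per-position code logits
embed_ind = torch.randint(0, num_codes, (b, n))   # chosen code indices
mask      = torch.tensor([[True, True, True, True, False, False],
                          [True, True, True, True, True,  True ]])  # lengths 4 and 6

# cross-entropy variant: fill padded targets with -1 and skip them in the loss
# (assumes the loss is computed with ignore_index = -1)
targets = embed_ind.masked_fill(~mask, -1)
ce_loss = F.cross_entropy(logits.transpose(1, 2), targets, ignore_index = -1)

# mse variant: elementwise loss, then average over valid positions only
x, quantize = torch.randn(b, n, d), torch.randn(b, n, d)
mse = F.mse_loss(quantize, x, reduction = 'none')   # (b, n, d)
masked_mse = mse[mask].mean()                       # boolean index keeps valid rows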

@@ -952,4 +978,13 @@ def calculate_ce_loss(codes):
        if only_one:
            quantize = rearrange(quantize, 'b 1 d -> b d')

+        # if masking, only return quantized for where mask has True
+
+        if exists(mask):
+            quantize = torch.where(
+                rearrange(mask, '... -> ... 1'),
+                quantize,
+                orig_input
+            )
+
        return quantize, embed_ind, loss
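
Taken together, the mask now flows from `VectorQuantize.forward` into the codebook, the commitment loss, and the final `torch.where`, so padded positions come back as the original, un-quantized inputs. A rough usage sketch follows; the constructor arguments are assumptions about typical usage of the library, while the `mask` keyword and the `(quantize, indices, loss)` return are taken from this diff.

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(2, 1024, 256)                           # batch padded to length 1024
lengths = torch.tensor([1024, 700])
mask = torch.arange(1024)[None, :] < lengths[:, None]   # (b, n), True for real tokens

quantized, indices, commit_loss = vq(x, mask = mask)

# per the final torch.where in this commit, padded positions pass through untouched
assert torch.allclose(quantized[~mask], x[~mask])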
