Skip to content

Commit bb9e878

Browse files
committed
an attempt to address #142
1 parent fea8f22 commit bb9e878

File tree

2 files changed

+3
-9
lines changed

2 files changed

+3
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "vector-quantize-pytorch"
3-
version = "1.14.40"
3+
version = "1.14.41"
44
description = "Vector Quantization - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,7 @@ def update_affine(self, data, embed, mask = None):
430430
self.update_with_decay('batch_variance', batch_variance, self.affine_param_batch_decay)
431431

432432
def replace(self, batch_samples, batch_mask):
433-
for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
434-
if not torch.any(mask):
435-
continue
436-
433+
for ind, (samples, mask) in enumerate(zip(batch_samples, batch_mask)):
437434
sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
438435
sampled = rearrange(sampled, '1 ... -> ...')
439436

@@ -619,10 +616,7 @@ def init_embed_(self, data, mask = None):
619616
def replace(self, batch_samples, batch_mask):
620617
batch_samples = l2norm(batch_samples)
621618

622-
for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
623-
if not torch.any(mask):
624-
continue
625-
619+
for ind, (samples, mask) in enumerate(zip(batch_samples, batch_mask)):
626620
sampled = self.replace_sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
627621
sampled = rearrange(sampled, '1 ... -> ...')
628622

0 commit comments

Comments
 (0)