Commit cfcb4ee

fix an issue with multi-headed codebooks and reduction in cluster sizes for laplace smoothing
1 parent 03da009 commit cfcb4ee

2 files changed, 6 additions and 5 deletions

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.0',
+  version = '1.5.1',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 5 additions & 4 deletions

@@ -43,8 +43,9 @@ def gumbel_sample(t, temperature = 1., dim = -1):
 
     return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
 
-def laplace_smoothing(x, n_categories, eps = 1e-5):
-    return (x + eps) / (x.sum() + n_categories * eps)
+def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
+    denom = x.sum(dim = dim, keepdim = True)
+    return (x + eps) / (denom + n_categories * eps)
 
 def sample_vectors(samples, num):
     num_samples, device = samples.shape[0], samples.device
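
To see why the new dim argument matters, here is a minimal sketch; the laplace_smoothing definition is copied from the diff above, but the toy tensor and the approximate printed values are illustrative assumptions, not from the repo. With a multi-headed codebook, cluster_size carries a leading heads dimension, so the old global x.sum() normalized every head by the combined count of all heads, while the fixed version normalizes each head by its own total.

import torch

def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
    denom = x.sum(dim = dim, keepdim = True)
    return (x + eps) / (denom + n_categories * eps)

# hypothetical per-head bin counts: 2 heads, 3 codes each
cluster_size = torch.tensor([[10., 0., 0.], [100., 100., 100.]])

# old behavior: one global sum (310) is shared across heads, so the
# smoothed counts of head 0 sum to ~0.03 instead of 1
old = (cluster_size + 1e-5) / (cluster_size.sum() + 3 * 1e-5)

# new behavior: each head is normalized by its own total (10 and 300),
# so every head's smoothed distribution sums to ~1
new = laplace_smoothing(cluster_size, n_categories = 3)

print(old.sum(dim = -1))  # approximately [0.03, 0.97]
print(new.sum(dim = -1))  # approximately [1.00, 1.00]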
@@ -305,7 +306,7 @@ def forward(self, x):
             self.all_reduce_fn(embed_sum.contiguous())
             self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
 
             embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
             self.embed.data.copy_(embed_normalized)
@@ -450,7 +451,7 @@ def forward(self, x):
             self.all_reduce_fn(embed_sum)
             self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)
 
-            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)
 
             embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
             embed_normalized = l2norm(embed_normalized)
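
Downstream, the EMA re-normalization now rescales the smoothed distribution by each head's own total count before dividing the running embedding sums. A hedged sketch of that step under assumed shapes (cluster_size as (heads, codebook_size), embed_avg as (heads, codebook_size, dim)); the names mirror the diff, but the standalone setup below is illustrative:

import torch
from einops import rearrange

heads, codebook_size, dim = 2, 4, 8
eps = 1e-5

cluster_size = torch.rand(heads, codebook_size) * 50   # per-head EMA bin counts
embed_avg = torch.randn(heads, codebook_size, dim)     # per-head EMA embedding sums

# laplace-smoothed per-head distribution, then rescaled back to counts
# using each head's own total (the .sum(dim = -1, keepdim = True) from the diff)
smoothed = (cluster_size + eps) / (cluster_size.sum(dim = -1, keepdim = True) + codebook_size * eps)
counts = smoothed * cluster_size.sum(dim = -1, keepdim = True)

# divide the running sums by the smoothed counts to get per-code means
embed_normalized = embed_avg / rearrange(counts, '... -> ... 1')
print(embed_normalized.shape)  # torch.Size([2, 4, 8])

The cosine-sim codebook (the second forward hunk) performs the same rescaling and then applies l2norm to embed_normalized.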
