Commit 03da009

for cosine sim, keep track of the exponential moving averages of the embedding sum and cluster sizes, and use both for deriving the next center
1 parent 500db14 commit 03da009
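
In short: rather than renormalizing each batch's raw embedding sum by that batch's bin counts, the cosine-sim codebook now maintains exponential moving averages of both the embedding sums (embed_avg) and the cluster sizes (cluster_size), and derives the next codebook centers from their ratio, re-projected onto the unit sphere. A minimal standalone sketch of that update rule (names like ema_inplace, bins, and smoothed are illustrative, not the library's API):

import torch
import torch.nn.functional as F

def ema_inplace(old, new, decay):
    # old <- decay * old + (1 - decay) * new
    old.lerp_(new, 1 - decay)

# toy setup: 1024 flattened vectors, codebook of 16 centers, dim 64
num_codes, dim, decay, eps = 16, 64, 0.8, 1e-5
x = F.normalize(torch.randn(1024, dim), dim = -1)
embed = F.normalize(torch.randn(num_codes, dim), dim = -1)

# EMA state, analogous to the cluster_size and embed_avg buffers
cluster_size = torch.zeros(num_codes)
embed_avg = embed.clone()

# assign each vector to its nearest center by cosine similarity
assign = (x @ embed.t()).argmax(dim = -1)
onehot = F.one_hot(assign, num_codes).float()

# batch statistics: member count and embedding sum per cluster
bins = onehot.sum(dim = 0)            # (num_codes,)
embed_sum = onehot.t() @ x            # (num_codes, dim)

# track EMAs of both, instead of using the raw batch statistics directly
ema_inplace(cluster_size, bins, decay)
ema_inplace(embed_avg, embed_sum, decay)

# Laplace-smoothed counts avoid dividing by zero for empty clusters
smoothed = (cluster_size + eps) / (cluster_size.sum() + num_codes * eps) * cluster_size.sum()

# next centers: EMA mean, l2-normalized again for cosine similarity
embed = F.normalize(embed_avg / smoothed.unsqueeze(-1), dim = -1)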

2 files changed, +11 −3 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.4.1',
+  version = '1.5.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 10 additions & 2 deletions
@@ -358,6 +358,7 @@ def __init__(

         self.register_buffer('initted', torch.Tensor([not kmeans_init]))
         self.register_buffer('cluster_size', torch.zeros(num_codebooks, codebook_size))
+        self.register_buffer('embed_avg', embed.clone())

         self.learnable_codebook = learnable_codebook
         if learnable_codebook:
@@ -380,6 +381,7 @@ def init_embed_(self, data):
         )

         self.embed.data.copy_(embed)
+        self.embed_avg.data.copy_(embed.clone())
         self.cluster_size.data.copy_(cluster_size)
         self.initted.data.copy_(torch.Tensor([True]))

@@ -394,7 +396,8 @@ def replace(self, batch_samples, batch_mask):
         sampled = rearrange(sampled, '1 ... -> ...')

         self.embed.data[ind][mask] = sampled
-        self.cluster_size.data[ind][mask] = self.reset_cluster_size
+        self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size
+        self.cluster_size.data[ind][mask] = self.reset_cluster_size

     def expire_codes_(self, batch_samples):
         if self.threshold_ema_dead_code == 0:
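
Note the seeding in replace above: a revived code's embed_avg is set to sampled * self.reset_cluster_size while its cluster_size is set to self.reset_cluster_size, so that the next derived center, embed_avg / cluster_size, reproduces the sampled vector instead of shrinking it toward zero. A quick check of that invariant, with hypothetical values:

import torch

reset_cluster_size = 2.0                         # stands in for self.reset_cluster_size
sampled = torch.randn(64)                        # replacement vector for a dead code

embed_avg = sampled * reset_cluster_size         # seeded EMA numerator
cluster_size = torch.tensor(reset_cluster_size)  # seeded EMA denominator

# up to Laplace smoothing, the derived center is the sampled vector itself
assert torch.allclose(embed_avg / cluster_size, sampled)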
@@ -445,8 +448,11 @@ def forward(self, x):

             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum)
+            self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)

-            embed_normalized = embed_sum / rearrange(bins, '... -> ... 1')
+            cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()
+
+            embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
             embed_normalized = l2norm(embed_normalized)

             embed_normalized = torch.where(
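
laplace_smoothing is a helper defined elsewhere in this file (not shown in this diff). Its standard form, which is assumed here, additively smooths the counts so no cluster size is exactly zero; the trailing * self.cluster_size.sum() in the hunk above then rescales the smoothed distribution back into count units:

# presumed definition, following the usual additive-smoothing formulation;
# the authoritative version lives elsewhere in vector_quantize_pytorch.py
def laplace_smoothing(x, n_categories, eps = 1e-5):
    return (x + eps) / (x.sum(dim = -1, keepdim = True) + n_categories * eps)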
@@ -456,6 +462,8 @@ def forward(self, x):
             )

             self.embed.data.lerp_(embed_normalized, 1 - self.decay)
+            self.embed.data.copy_(l2norm(self.embed))
+
             self.expire_codes_(x)

         if needs_codebook_dim:
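
For context, the codebook this commit touches is the one selected with use_cosine_sim = True on VectorQuantize, per the project README. A typical construction and call, assuming the README's documented arguments:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    use_cosine_sim = True,   # selects the cosine-sim codebook updated in this commit
    decay = 0.8,             # EMA decay applied to embed_avg and cluster_size
    kmeans_init = True       # initialize centers with k-means on the first batch
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # (1, 1024, 256), (1, 1024), scalar loss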
