
Commit 87b54c1

start using lerp for ema
1 parent 61b821d commit 87b54c1

2 files changed: 5 additions, 8 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.0.0',
+  version = '1.0.1',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 7 deletions
@@ -37,9 +37,6 @@ def gumbel_sample(t, temperature = 1., dim = -1):

     return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)

-def ema_inplace(moving_avg, new, decay):
-    moving_avg.data.mul_(decay).add_(new, alpha = (1 - decay))
-
 def laplace_smoothing(x, n_categories, eps = 1e-5):
     return (x + eps) / (x.sum() + n_categories * eps)
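For reference, the removed ema_inplace helper and the in-place lerp_ calls used below compute the same exponential-moving-average update, decay * old + (1 - decay) * new. A minimal sketch of the equivalence (tensor names and shapes here are illustrative only):

import torch

decay = 0.99
old = torch.randn(8, 512)
new = torch.randn(8, 512)

# update as the removed helper did it
a = old.clone()
a.mul_(decay).add_(new, alpha = 1 - decay)

# update via lerp_, as introduced in this commit
# lerp_(end, weight) computes a + weight * (end - a) in place
b = old.clone()
b.lerp_(new, 1 - decay)

assert torch.allclose(a, b, atol = 1e-6)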

@@ -289,11 +286,11 @@ def forward(self, x):
             cluster_size = embed_onehot.sum(dim = 1)

             self.all_reduce_fn(cluster_size)
-            ema_inplace(self.cluster_size, cluster_size, self.decay)
+            self.cluster_size.data.lerp_(cluster_size, 1 - self.decay)

             embed_sum = einsum('h n d, h n c -> h c d', flatten, embed_onehot)
             self.all_reduce_fn(embed_sum.contiguous())
-            ema_inplace(self.embed_avg, embed_sum, self.decay)
+            self.embed_avg.data.lerp_(embed_sum, 1 - self.decay)

             cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()

@@ -421,7 +418,7 @@ def forward(self, x):
             bins = embed_onehot.sum(dim = 1)
             self.all_reduce_fn(bins)

-            ema_inplace(self.cluster_size, bins, self.decay)
+            self.cluster_size.data.lerp_(bins, 1 - self.decay)

             zero_mask = (bins == 0)
             bins = bins.masked_fill(zero_mask, 1.)
@@ -438,7 +435,7 @@ def forward(self, x):
                 embed_normalized
             )

-            ema_inplace(self.embed, embed_normalized, self.decay)
+            self.embed.data.lerp_(embed_normalized, 1 - self.decay)
            self.expire_codes_(x)

         if needs_codebook_dim:
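As a rough illustration of where these lerp_ calls sit, the sketch below runs one EMA codebook update end to end using the same pattern; the names, shapes, and nearest-neighbour lookup are simplified assumptions for the example, not the library's actual API (gradient tracking and distributed all-reduce are omitted):

import torch
import torch.nn.functional as F

codebook_size, dim, decay, eps = 512, 64, 0.99, 1e-5

embed = torch.randn(codebook_size, dim)      # codebook vectors
cluster_size = torch.zeros(codebook_size)    # EMA of per-code assignment counts
embed_avg = embed.clone()                    # EMA of summed assigned vectors

def ema_codebook_update(flatten):
    # flatten: (n, dim) batch of encoder outputs
    dist = torch.cdist(flatten, embed)
    embed_ind = dist.argmin(dim = -1)
    embed_onehot = F.one_hot(embed_ind, codebook_size).type(flatten.dtype)

    # the lerp_-based EMA updates, matching this commit
    cluster_size.lerp_(embed_onehot.sum(dim = 0), 1 - decay)
    embed_avg.lerp_(embed_onehot.t() @ flatten, 1 - decay)

    # laplace-smooth the counts, then turn running sums into codebook vectors
    smoothed = (cluster_size + eps) / (cluster_size.sum() + codebook_size * eps) * cluster_size.sum()
    embed.copy_(embed_avg / smoothed.unsqueeze(-1))

ema_codebook_update(torch.randn(1024, dim))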
