Skip to content

Commit a7bc53b

Browse files
committed
in cosine sim, update with the l2norm of the smoothed embed_avg over smoothed cluster_size
1 parent afa2275 commit a7bc53b

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.5.2',
6+
version = '1.5.3',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,7 @@ def forward(self, x):
453453
embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
454454
embed_normalized = l2norm(embed_normalized)
455455

456-
self.embed.data.lerp_(embed_normalized, 1 - self.decay)
457-
self.embed.data.copy_(l2norm(self.embed))
458-
456+
self.embed.data.copy_(l2norm(embed_normalized))
459457
self.expire_codes_(x)
460458

461459
if needs_codebook_dim:

0 commit comments

Comments
 (0)