
Commit 14b58b6

manually calculate cdist, as NaNs were observed using torch.cdist in another project
1 parent 8e847d1 commit 14b58b6
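
The replacement expands the pairwise Euclidean distance via the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y> and clamps the result at a small eps before taking the square root, so floating-point cancellation cannot push the argument of sqrt negative. The root cause of the NaNs in torch.cdist is not diagnosed in this commit; the clamp-before-sqrt guard is simply what the manual version adds.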

2 files changed: 8 additions, 2 deletions


setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.28',
+  version = '1.6.29',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 7 additions & 1 deletion
@@ -26,6 +26,12 @@ def identity(t):
 def l2norm(t):
     return F.normalize(t, p = 2, dim = -1)
 
+def cdist(x, y, eps = 1e-10):
+    x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
+    y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
+    xy = einsum('b i d, b j d -> b i j', x, y) * -2
+    return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = eps).sqrt()
+
 def log(t, eps = 1e-20):
     return torch.log(t.clamp(min = eps))
 
@@ -458,7 +464,7 @@ def forward(
         batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
         embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
 
-        dist = -torch.cdist(flatten, embed, p = 2)
+        dist = -cdist(flatten, embed)
 
         embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)

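For reference, below is a minimal self-contained sketch of the new helper with the einops imports it relies on elsewhere in the repository; the comparison against torch.cdist at the end is illustrative only and not part of the commit.

import torch
from torch import einsum
from einops import rearrange, reduce

def cdist(x, y, eps = 1e-10):
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>, computed batch-wise
    x2 = reduce(x ** 2, 'b n d -> b n', 'sum')
    y2 = reduce(y ** 2, 'b n d -> b n', 'sum')
    xy = einsum('b i d, b j d -> b i j', x, y) * -2
    # clamp before sqrt so cancellation error cannot make the argument negative
    return (rearrange(x2, 'b i -> b i 1') + rearrange(y2, 'b j -> b 1 j') + xy).clamp(min = eps).sqrt()

x = torch.randn(2, 8, 16)
y = torch.randn(2, 4, 16)
print((cdist(x, y) - torch.cdist(x, y)).abs().max())  # expected to be tiny, on the order of 1e-5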