Commit f867b33

address straight through #59
1 parent 1a59a89

2 files changed: 16 additions & 18 deletions


setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.21',
+  version = '1.6.22',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 15 additions & 17 deletions
@@ -74,7 +74,7 @@ def gumbel_sample(
     assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'

     if not straight_through or temperature <= 0. or not training:
-        return ind, one_hot, None
+        return ind, one_hot

     # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
     # algorithm 2
@@ -89,9 +89,7 @@ def gumbel_sample(
         π1 = (logits / temperature).softmax(dim = dim)
         one_hot = one_hot + π1 - π1.detach()

-    st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1')) # multiplier for straight-through
-
-    return ind, one_hot, st_mult
+    return ind, one_hot

 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
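
Both gumbel_sample hunks above drop the third return value because `one_hot` already carries the straight-through gradient on its own: the straight-through branch builds it as `one_hot + π1 - π1.detach()`, so its forward value equals the hard one-hot while the backward pass is routed through the soft probabilities. A minimal sketch of that pattern in plain torch (illustrative shapes and names, not this repo's helpers):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, requires_grad = True)
temperature = 0.9

# hard selection in the forward pass
ind = logits.argmax(dim = -1)
hard = F.one_hot(ind, num_classes = 8).float()

# straight-through: the value is `hard`, the gradient flows through `soft`
soft = (logits / temperature).softmax(dim = -1)
one_hot = hard + soft - soft.detach()

# use the one-hot downstream, e.g. to pick code vectors
codebook = torch.randn(8, 4)
quantize = one_hot @ codebook
quantize.pow(2).sum().backward()

assert logits.grad is not None  # argmax alone would have blocked this gradient

gumbel_sample can therefore return just (ind, one_hot), and callers recover the straight-through behaviour by using one_hot directly, as the forward hunks below do.
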
@@ -433,15 +431,15 @@ def forward(

         dist = -torch.cdist(flatten, embed, p = 2)

-        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)

         embed_ind = unpack_one(embed_ind, ps, 'h *')

-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if exists(straight_through_mult):
-            mult = unpack_one(straight_through_mult, ps, 'h * d')
-            quantize = quantize * mult
+        if self.training:
+            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
+            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+        else:
+            quantize = batched_embedding(embed_ind, embed)

         if self.training and self.ema_update:

@@ -595,14 +593,14 @@ def forward(

         dist = einsum('h n d, h c d -> h n c', flatten, embed)

-        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if exists(straight_through_mult):
-            mult = unpack_one(straight_through_mult, ps, 'h * d')
-            quantize = quantize * mult
+        if self.training:
+            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
+            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+        else:
+            quantize = batched_embedding(embed_ind, embed)

         if self.training and self.ema_update:
             bins = embed_onehot.sum(dim = 1)
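
Both forward hunks make the same swap: during training, quantization contracts the straight-through one-hot against the codebook instead of doing an index lookup. An embedding lookup by index is not differentiable with respect to the logits that produced the index, which is why the old code had to gather a separate st_mult and multiply it back in; the einsum selects the same code vectors while keeping the whole path inside the autograd graph. A small sketch with made-up shapes (h heads, b batch, n tokens, c codes, d dims; not the repo's actual tensors):

import torch

# hypothetical shapes: h heads, b batch, n tokens, c codes, d dims
h, b, n, c, d = 2, 3, 5, 16, 4

one_hot = torch.zeros(h, b, n, c)
one_hot[..., 0] = 1.                     # pretend code 0 was selected everywhere
embed = torch.randn(h, c, d)

# matmul-style selection: same values as an index lookup of code 0,
# but the one-hot (and any logits behind it) stays in the autograd graph
quantize = torch.einsum('h b n c, h c d -> h b n d', one_hot, embed)

assert torch.allclose(quantize, embed[:, 0][:, None, None, :].expand(h, b, n, d))

At eval time there is no gradient to preserve, so the else branch keeps the cheaper batched_embedding lookup.
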
@@ -726,7 +724,7 @@ def __init__(
         )

         if affine_param:
-            assert not use_cosine_sim
+            assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook'
             codebook_kwargs = dict(
                 **codebook_kwargs,
                 affine_param = True,
