Commit dcf0c3b

add ability to get straight through gradients, as well as use reinmax
Parent: 8dc0b71

File tree

- README.md
- setup.py
- vector_quantize_pytorch/vector_quantize_pytorch.py

3 files changed (+74, -17 lines)


README.md

Lines changed: 12 additions & 2 deletions
@@ -70,8 +70,9 @@ residual_vq = ResidualVQ(
     dim = 256,
     num_quantizers = 8,
     codebook_size = 1024,
-    sample_codebook_temp = 0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
-    shared_codebook = True      # whether to share the codebooks for all quantizers or not
+    stochastic_sample_codes = True,
+    sample_codebook_temp = 0.1, # temperature for stochastically sampling codes, 0 would be equivalent to non-stochastic
+    shared_codebook = True      # whether to share the codebooks for all quantizers or not
 )

 x = torch.randn(1, 1024, 256)
@@ -406,3 +407,12 @@ if __name__ == '__main__':
     }
 ```

+```bibtex
+@article{Liu2023BridgingDA,
+    title   = {Bridging Discrete and Backpropagation: Straight-Through and Beyond},
+    author  = {Liyuan Liu and Chengyu Dong and Xiaodong Liu and Bin Yu and Jianfeng Gao},
+    journal = {ArXiv},
+    year    = {2023},
+    volume  = {abs/2304.08612}
+}
+```
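For reference, a minimal sketch of how the flags introduced in this commit could be combined on `VectorQuantize`. The codebook size, temperature, and input shape below are arbitrary, and the three-value return follows the style of the existing README examples rather than anything added here:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    stochastic_sample_codes = True,   # add gumbel noise before picking codes
    sample_codebook_temp = 0.1,       # temperature for the stochastic sampling
    straight_through = True,          # pass gradients straight through the hard code selection
    reinmax = True                    # layer the ReinMax estimator (Liu et al. 2023) on top of straight-through
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # quantized (1, 1024, 256), indices (1, 1024)
```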

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name = 'vector_quantize_pytorch',
     packages = find_packages(),
-    version = '1.5.8',
+    version = '1.5.9',
     license='MIT',
     description = 'Vector Quantization - Pytorch',
     long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 61 additions & 14 deletions
@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -40,11 +42,41 @@ def gumbel_noise(t):
     noise = torch.zeros_like(t).uniform_(0, 1)
     return -log(-log(noise))

-def gumbel_sample(t, temperature = 1., dim = -1):
-    if temperature == 0:
-        return t.argmax(dim = dim)
+def gumbel_sample(
+    logits,
+    temperature = 1.,
+    stochastic = False,
+    straight_through = False,
+    reinmax = False,
+    dim = -1
+):
+    dtype, size = logits.dtype, logits.shape[dim]
+
+    if stochastic:
+        logits = logits + gumbel_noise(logits)
+
+    ind = logits.argmax(dim = dim)
+    one_hot = F.one_hot(ind, size).type(dtype)

-    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
+
+    if not straight_through:
+        return ind, one_hot
+
+    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
+    # algorithm 2
+
+    if reinmax:
+        π0 = logits.softmax(dim = dim)
+        π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
+        π1 = ((π1.log() - logits).detach() + logits).softmax(dim = dim)
+        π2 = 2 * π1 - 0.5 * π0
+        one_hot = π2 - π2.detach() + one_hot
+    else:
+        π1 = (logits / temperature).softmax(dim = dim)
+        one_hot = one_hot + π1 - π1.detach()
+
+    return ind, one_hot

 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
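The core of the new `straight_through` path is the reparameterization `one_hot + π1 - π1.detach()`: the forward value is exactly the hard one-hot code selection, while the backward pass borrows the gradient of the tempered softmax. The `reinmax` branch applies Algorithm 2 of Liu et al. 2023 on top of this to make the gradient estimate second-order accurate. Below is a minimal, self-contained sketch of the basic straight-through trick only; the toy shapes and the 0.9 temperature are arbitrary:

```python
import torch
import torch.nn.functional as F

codebook = torch.randn(5, 4)                       # 5 codes of dimension 4
logits = torch.randn(3, 5, requires_grad = True)   # similarity of 3 tokens to each code

ind = logits.argmax(dim = -1)
hard = F.one_hot(ind, 5).float()                   # hard one-hot selection, no gradient path
soft = (logits / 0.9).softmax(dim = -1)            # tempered softmax, differentiable surrogate

# straight-through: forward value equals `hard`, gradient flows through `soft`
one_hot = hard + soft - soft.detach()

quantized = one_hot @ codebook                     # hard codebook lookup in the forward pass
quantized.pow(2).mean().backward()

print(logits.grad.abs().sum() > 0)                 # tensor(True): the argmax no longer blocks gradients
```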
@@ -200,7 +232,9 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0
+        sample_codebook_temp = 0,
+        straight_through = False,
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = identity
@@ -216,7 +250,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample

         assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
@@ -295,8 +331,7 @@ def forward(self, x):

         dist = -torch.cdist(flatten, embed, p = 2)

-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

         quantize = batched_embedding(embed_ind, self.embed)
@@ -339,7 +374,7 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0.
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = l2norm
@@ -358,7 +393,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample

         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -437,8 +474,7 @@ def forward(self, x):
         embed = self.embed if not self.learnable_codebook else self.embed.detach()

         dist = einsum('h n d, h c d -> h n c', flatten, embed)
-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')

         quantize = batched_embedding(embed_ind, self.embed)
@@ -491,8 +527,11 @@ def __init__(
         orthogonal_reg_weight = 0.,
         orthogonal_reg_active_codes_only = False,
         orthogonal_reg_max_codes = None,
+        stochastic_sample_codes = False,
         sample_codebook_temp = 0.,
-        sync_codebook = False
+        straight_through = False,
+        reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
+        sync_codebook = False,
     ):
         super().__init__()
         self.dim = dim
@@ -517,6 +556,14 @@ def __init__(

         codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook

+        gumbel_sample_fn = partial(
+            gumbel_sample,
+            stochastic = stochastic_sample_codes,
+            temperature = sample_codebook_temp,
+            reinmax = reinmax,
+            straight_through = straight_through
+        )
+
         self._codebook = codebook_class(
             dim = codebook_dim,
             num_codebooks = heads if separate_codebook_per_head else 1,
@@ -529,7 +576,7 @@ def __init__(
             threshold_ema_dead_code = threshold_ema_dead_code,
             use_ddp = sync_codebook,
             learnable_codebook = has_codebook_orthogonal_loss,
-            sample_codebook_temp = sample_codebook_temp
+            gumbel_sample = gumbel_sample_fn
         )

         self.codebook_size = codebook_size
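The wiring above follows a simple pattern: `VectorQuantize` binds every sampling option into `gumbel_sample_fn` with `functools.partial`, and the codebook classes then treat it as an opaque callable, invoking it with only the distances and the dim. A small standalone sketch of that pattern; the `sample` function here is a deliberately simplified, hypothetical stand-in, not the real `gumbel_sample`:

```python
from functools import partial

import torch
import torch.nn.functional as F

# simplified stand-in with the same call signature as the new gumbel_sample,
# kept trivial on purpose - it ignores the sampling options it receives
def sample(logits, temperature = 1., stochastic = False, straight_through = False, reinmax = False, dim = -1):
    ind = logits.argmax(dim = dim)
    return ind, F.one_hot(ind, logits.shape[dim]).type(logits.dtype)

# the wrapper binds all sampling options once, up front ...
sample_fn = partial(sample, stochastic = True, temperature = 0.1, straight_through = True, reinmax = True)

# ... so a codebook's forward pass only supplies the distances and the dim,
# mirroring `embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)` above
dist = torch.randn(1, 16, 512)                      # (codebook heads, tokens, codebook size)
embed_ind, embed_onehot = sample_fn(dist, dim = -1)
print(embed_ind.shape, embed_onehot.shape)          # torch.Size([1, 16]) torch.Size([1, 16, 512])
```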
