Commit 1c0e73e

buy the conclusions of the new MIT paper with their solution for the internal covariate shift
1 parent e513881 commit 1c0e73e

File tree

3 files changed: +84 -7 lines changed

- README.md
- setup.py
- vector_quantize_pytorch/vector_quantize_pytorch.py


README.md

Lines changed: 11 additions & 1 deletion
@@ -307,7 +307,7 @@ if __name__ == '__main__':
 
 - [x] allow for multi-headed codebooks
 - [x] support masking
-
+- [ ] make sure affine param works in a distributed setting (batch mean and variance must be synced with dist reduce or whatever)
 
 ## Citations
 
@@ -416,3 +416,13 @@ if __name__ == '__main__':
     volume = {abs/2304.08612}
 }
 ```
+
+```bibtex
+@inproceedings{huh2023improvedvqste,
+    title = {Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
+    author = {Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
+    booktitle = {International Conference on Machine Learning},
+    year = {2023},
+    organization = {PMLR}
+}
+```
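The new todo item above concerns the distributed case: the batch mean and variance feeding the affine buffers are per-rank quantities, so under data parallelism they would need to be all-reduced before the EMA update. A minimal sketch of that idea, assuming plain `torch.distributed`; the helper name `synced_batch_stats` and the naive averaging of per-rank variances are illustrative assumptions, not the library's implementation:

```python
import torch
import torch.distributed as distributed

def synced_batch_stats(data):
    # data: (num_vectors, dim) of flattened encoder outputs on this rank
    mean = data.mean(dim = 0)
    variance = data.var(dim = 0, unbiased = False)

    if distributed.is_initialized() and distributed.get_world_size() > 1:
        world_size = distributed.get_world_size()

        # sum across ranks, then average, so every rank updates its
        # EMA buffers with the same statistics
        distributed.all_reduce(mean)
        distributed.all_reduce(variance)
        mean = mean / world_size
        # note: averaging per-rank variances ignores the spread of the
        # per-rank means, so this is only an approximation
        variance = variance / world_size

    return mean, variance
```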

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.5.19',
+  version = '1.6.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 72 additions & 5 deletions
@@ -6,7 +6,7 @@
 import torch.distributed as distributed
 from torch.cuda.amp import autocast
 
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from contextlib import contextmanager
 
 def exists(val):
@@ -239,7 +239,10 @@ def __init__(
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
-        ema_update = True
+        ema_update = True,
+        affine_param = False,
+        affine_param_batch_decay = 0.99,
+        affine_param_codebook_decay = 0.9
     ):
         super().__init__()
         self.transform_input = identity
@@ -278,6 +281,22 @@ def __init__(
         else:
             self.register_buffer('embed', embed)
 
+        # affine related params
+
+        self.affine_param = affine_param
+
+        if not affine_param:
+            return
+
+        self.affine_param_batch_decay = affine_param_batch_decay
+        self.affine_param_codebook_decay = affine_param_codebook_decay
+
+        self.register_buffer('batch_mean', None)
+        self.register_buffer('batch_variance', None)
+
+        self.register_buffer('codebook_mean', None)
+        self.register_buffer('codebook_variance', None)
+
     @torch.jit.ignore
     def init_embed_(self, data):
         if self.initted:
@@ -296,6 +315,29 @@ def init_embed_(self, data):
         self.cluster_size.data.copy_(cluster_size)
         self.initted.data.copy_(torch.Tensor([True]))
 
+    @torch.jit.ignore
+    def update_with_decay(self, buffer_name, new_value, decay):
+        old_value = getattr(self, buffer_name)
+
+        if not exists(old_value):
+            self.register_buffer(buffer_name, new_value)
+            return
+
+        value = old_value * decay + new_value * (1 - decay)
+        self.register_buffer(buffer_name, value)
+
+    @torch.jit.ignore
+    def update_affine(self, data, embed):
+        assert self.affine_param
+
+        var_fn = partial(torch.var, unbiased = False)
+
+        self.update_with_decay('batch_mean', reduce(data, '... d -> d', 'mean'), self.affine_param_batch_decay)
+        self.update_with_decay('batch_variance', reduce(data, '... d -> d', var_fn), self.affine_param_batch_decay)
+
+        self.update_with_decay('codebook_mean', reduce(embed, '... d -> d', 'mean'), self.affine_param_codebook_decay)
+        self.update_with_decay('codebook_variance', reduce(embed, '... d -> d', var_fn), self.affine_param_codebook_decay)
+
     def replace(self, batch_samples, batch_mask):
         for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
             if not torch.any(mask):
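`update_with_decay` is a plain exponential moving average, re-registered as a buffer on every call so the running statistics live on the module; on the first call the buffer is still `None` and the new value is taken as-is. A toy check of the arithmetic (the observed values are made up) showing why the defaults added here, 0.99 for batch statistics and 0.9 for codebook statistics, make the batch statistics the slower-moving of the two:

```python
import torch

def ema(old, new, decay):
    # mirrors update_with_decay: first observation is taken directly,
    # afterwards blend old and new with the given decay
    if old is None:
        return new
    return old * decay + new * (1 - decay)

running_batch, running_codebook = None, None
for observed in (torch.tensor(1.), torch.tensor(0.), torch.tensor(0.)):
    running_batch = ema(running_batch, observed, decay = 0.99)
    running_codebook = ema(running_codebook, observed, decay = 0.9)

print(running_batch)     # tensor(0.9801) - barely moved from the first value
print(running_codebook)  # tensor(0.8100) - moves noticeably faster
```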
@@ -340,8 +382,16 @@ def forward(
 
         self.init_embed_(flatten)
 
+        if self.affine_param:
+            self.update_affine(flatten, self.embed)
+
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
 
+        if self.affine_param:
+            codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
+            batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
+            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
+
         dist = -torch.cdist(flatten, embed, p = 2)
 
         embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
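The rescale of `embed` above is moment matching: codebook vectors are standardized with the codebook's running mean and std, then re-expressed with the batch's running mean and std, so the distance computation that follows happens in the statistics of the incoming data. A self-contained check of that property (tensor names and shapes here are illustrative, not the module's internals):

```python
import torch

dim = 4
embed = torch.randn(16, dim) * 3. + 5.                        # stand-in codebook
batch_mean, batch_variance = torch.zeros(dim), torch.ones(dim)
codebook_mean = embed.mean(dim = 0)
codebook_variance = embed.var(dim = 0, unbiased = False)

codebook_std = codebook_variance.clamp(min = 1e-5).sqrt()
batch_std = batch_variance.clamp(min = 1e-5).sqrt()

rescaled = (embed - codebook_mean) * (batch_std / codebook_std) + batch_mean

# the rescaled codebook now carries the batch statistics
print(rescaled.mean(dim = 0))                   # ~ batch_mean (zeros)
print(rescaled.var(dim = 0, unbiased = False))  # ~ batch_variance (ones)
```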
@@ -355,6 +405,10 @@ def forward(
             quantize = quantize * mult
 
         if self.training and self.ema_update:
+
+            if self.affine_param:
+                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
+
             cluster_size = embed_onehot.sum(dim = 1)
 
             self.all_reduce_fn(cluster_size)
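The mirrored expression applied to `flatten` here is the inverse of the rescale above: before the EMA codebook update, inputs are mapped back into the codebook's own statistics, so the running codebook is maintained in its original, un-rescaled space. A small self-contained check that the two transforms are inverses of each other (all names are illustrative):

```python
import torch

dim = 4
batch_mean, batch_std = torch.randn(dim), torch.rand(dim) + 0.5
codebook_mean, codebook_std = torch.randn(dim), torch.rand(dim) + 0.5

to_batch_space    = lambda t: (t - codebook_mean) * (batch_std / codebook_std) + batch_mean
to_codebook_space = lambda t: (t - batch_mean) * (codebook_std / batch_std) + codebook_mean

x = torch.randn(8, dim)
assert torch.allclose(to_batch_space(to_codebook_space(x)), x, atol = 1e-5)
```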
@@ -565,8 +619,10 @@ def __init__(
         reinmax = False, # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = False,
         ema_update = True,
-        learnable_codebook = False
-
+        learnable_codebook = False,
+        affine_param = False,
+        affine_param_batch_decay = 0.99,
+        affine_param_codebook_decay = 0.9
     ):
         super().__init__()
         self.dim = dim
@@ -598,7 +654,7 @@ def __init__(
             straight_through = straight_through
         )
 
-        self._codebook = codebook_class(
+        codebook_kwargs = dict(
             dim = codebook_dim,
             num_codebooks = heads if separate_codebook_per_head else 1,
             codebook_size = codebook_size,
@@ -615,6 +671,17 @@ def __init__(
             ema_update = ema_update
         )
 
+        if affine_param:
+            assert not use_cosine_sim
+            codebook_kwargs = dict(
+                **codebook_kwargs,
+                affine_param = True,
+                affine_param_batch_decay = affine_param_batch_decay,
+                affine_param_codebook_decay = affine_param_codebook_decay
+            )
+
+        self._codebook = codebook_class(**codebook_kwargs)
+
         self.codebook_size = codebook_size
 
         self.accept_image_fmap = accept_image_fmap
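Taken together, the feature is exposed as three new keyword arguments on the quantizer and forwarded to the codebook class, with an assert against `use_cosine_sim`, presumably because the rescaling is defined for the euclidean codebook. A hedged usage sketch; the `VectorQuantize` class name, the input shape, and the return signature follow the package's README rather than this diff:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    affine_param = True,                  # enable the affine re-parameterization added in this commit
    affine_param_batch_decay = 0.99,      # EMA decay for batch mean / variance
    affine_param_codebook_decay = 0.9     # EMA decay for codebook mean / variance
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)   # (1, 1024, 256), (1, 1024), (1)
```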
