Skip to content

Commit 9d6a34f

Browse files
committed
need to account for multi-headedness (multiple codebooks). cite einops for immeasurable time saved
1 parent 4c090f1 commit 9d6a34f

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,13 @@ if __name__ == '__main__':
389389
organization = {PMLR}
390390
}
391391
```
392+
393+
```bibtex
394+
@inproceedings{rogozhnikov2022einops,
395+
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
396+
author = {Alex Rogozhnikov},
397+
booktitle = {International Conference on Learning Representations},
398+
year = {2022},
399+
url = {https://openreview.net/forum?id=oapKSVM2bcj}
400+
}
401+
```

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.6.1',
6+
version = '1.6.2',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,11 @@ def update_affine(self, data, embed):
323323

324324
var_fn = partial(torch.var, unbiased = False)
325325

326-
self.update_with_decay('batch_mean', reduce(data, '... d -> d', 'mean'), self.affine_param_batch_decay)
327-
self.update_with_decay('batch_variance', reduce(data, '... d -> d', var_fn), self.affine_param_batch_decay)
326+
self.update_with_decay('batch_mean', reduce(data, 'h ... d -> h 1 d', 'mean'), self.affine_param_batch_decay)
327+
self.update_with_decay('batch_variance', reduce(data, 'h ... d -> h 1 d', var_fn), self.affine_param_batch_decay)
328328

329-
self.update_with_decay('codebook_mean', reduce(embed, '... d -> d', 'mean'), self.affine_param_codebook_decay)
330-
self.update_with_decay('codebook_variance', reduce(embed, '... d -> d', var_fn), self.affine_param_codebook_decay)
329+
self.update_with_decay('codebook_mean', reduce(embed, 'h ... d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
330+
self.update_with_decay('codebook_variance', reduce(embed, 'h ... d -> h 1 d', var_fn), self.affine_param_codebook_decay)
331331

332332
def replace(self, batch_samples, batch_mask):
333333
for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):

0 commit comments

Comments
 (0)