
Commit 936d9be

remove what reportedly does not work well
1 parent 1c0e73e commit 936d9be

File tree

3 files changed: +2 −75 lines


README.md

Lines changed: 0 additions & 37 deletions
@@ -182,31 +182,6 @@ x = torch.randn(1, 1024, 256)
 quantized, indices, commit_loss = vq(x)
 ```
 
-### Orthogonal regularization loss
-
-VQ-VAE / VQ-GAN is quickly gaining popularity. A <a href="https://arxiv.org/abs/2112.00384">recent paper</a> proposes that when using vector quantization on images, enforcing the codebook to be orthogonal leads to translation equivariance of the discretized codes, leading to large improvements in downstream text to image generation tasks.
-
-You can use this feature by simply setting the `orthogonal_reg_weight` to be greater than `0`, in which case the orthogonal regularization will be added to the auxiliary loss outputted by the module.
-
-```python
-import torch
-from vector_quantize_pytorch import VectorQuantize
-
-vq = VectorQuantize(
-    dim = 256,
-    codebook_size = 256,
-    accept_image_fmap = True,                   # set this true to be able to pass in an image feature map
-    orthogonal_reg_weight = 10,                 # in paper, they recommended a value of 10
-    orthogonal_reg_max_codes = 128,             # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
-    orthogonal_reg_active_codes_only = False    # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
-)
-
-img_fmap = torch.randn(1, 256, 32, 32)
-quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
-
-# loss now contains the orthogonal regularization loss with the weight as assigned
-```
-
 ### Multi-headed VQ
 
 There has been a number of papers that proposes variants of discrete latent representations with a multi-headed approach (multiple codes per feature). I have decided to offer one variant where the same codebook is used to vector quantize across the input dimension `head` times.
@@ -230,7 +205,6 @@ img_fmap = torch.randn(1, 256, 32, 32)
 quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)
 
 # indices shape - (batch, height, width, heads)
-# loss now contains the orthogonal regularization loss with the weight as assigned
 ```
 ### Random Projection Quantizer
 
@@ -344,17 +318,6 @@ if __name__ == '__main__':
 }
 ```
 
-```bibtex
-@misc{shin2021translationequivariant,
-    title = {Translation-equivariant Image Quantizer for Bi-directional Image-Text Generation},
-    author = {Woncheol Shin and Gyubok Lee and Jiyoung Lee and Joonseok Lee and Edward Choi},
-    year = {2021},
-    eprint = {2112.00384},
-    archivePrefix = {arXiv},
-    primaryClass = {cs.CV}
-}
-```
-
 ```bibtex
 @unknown{unknown,
     author = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
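
Note: the deleted README section documented the orthogonal regularization of Shin et al. (arXiv:2112.00384), which this commit drops from the library. For anyone who still wants that penalty after upgrading, below is a minimal standalone sketch of the same eq. (2) term, reconstructed from the deleted `orthogonal_loss_fn`; `F.normalize` stands in for the repo's internal `l2norm` helper, and the `orthogonal_penalty` name is illustrative, not part of the package.

```python
import torch
import torch.nn.functional as F

def orthogonal_penalty(codebook: torch.Tensor) -> torch.Tensor:
    # codebook: (num_codebooks, num_codes, dim), the layout the deleted orthogonal_loss_fn expected
    h, n = codebook.shape[:2]
    normed = F.normalize(codebook, dim = -1)  # stand-in for the repo's l2norm helper
    cosine_sim = torch.einsum('h i d, h j d -> h i j', normed, normed)
    # eq (2) of https://arxiv.org/abs/2112.00384: mean squared cosine similarity minus the identity term
    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)

# hypothetical external use after this commit, with weight 10 as in the deleted README example:
# loss = commit_loss + 10 * orthogonal_penalty(vq._codebook.embed)
```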

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.0',
+  version = '1.6.1',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 1 addition & 37 deletions
@@ -211,15 +211,6 @@ def batched_embedding(indices, embeds):
     embeds = repeat(embeds, 'h c d -> h b c d', b = batch)
     return embeds.gather(2, indices)
 
-# regularization losses
-
-def orthogonal_loss_fn(t):
-    # eq (2) from https://arxiv.org/abs/2112.00384
-    h, n = t.shape[:2]
-    normed_codes = l2norm(t)
-    cosine_sim = einsum('h i d, h j d -> h i j', normed_codes, normed_codes)
-    return (cosine_sim ** 2).sum() / (h * n ** 2) - (1 / n)
-
 # distance types
 
 class EuclideanCodebook(nn.Module):
@@ -610,9 +601,6 @@ def __init__(
         accept_image_fmap = False,
         commitment_weight = 1.,
         commitment_use_cross_entropy_loss = False,
-        orthogonal_reg_weight = 0.,
-        orthogonal_reg_active_codes_only = False,
-        orthogonal_reg_max_codes = None,
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
@@ -640,11 +628,6 @@ def __init__(
         self.commitment_weight = commitment_weight
         self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss
 
-        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
-        self.orthogonal_reg_weight = orthogonal_reg_weight
-        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
-        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
-
         codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
 
         gumbel_sample_fn = partial(
@@ -665,7 +648,7 @@ def __init__(
             eps = eps,
             threshold_ema_dead_code = threshold_ema_dead_code,
             use_ddp = sync_codebook,
-            learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
+            learnable_codebook = learnable_codebook,
             sample_codebook_temp = sample_codebook_temp,
             gumbel_sample = gumbel_sample_fn,
             ema_update = ema_update
@@ -828,25 +811,6 @@ def calculate_ce_loss(codes):
 
                 loss = loss + commit_loss * self.commitment_weight
 
-            if self.orthogonal_reg_weight > 0:
-                codebook = self._codebook.embed
-
-                # only calculate orthogonal loss for the activated codes for this batch
-
-                if self.orthogonal_reg_active_codes_only:
-                    assert not (is_multiheaded and self.separate_codebook_per_head), 'orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet'
-                    unique_code_ids = torch.unique(embed_ind)
-                    codebook = codebook[:, unique_code_ids]
-
-                num_codes = codebook.shape[-2]
-
-                if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
-                    rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
-                    codebook = codebook[:, rand_ids]
-
-                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
-                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
-
         # handle multi-headed quantized embeddings
 
         if is_multiheaded:
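
Note: after this change the three `orthogonal_reg_*` keyword arguments are gone from `VectorQuantize.__init__`, and the codebook is only learnable when `learnable_codebook = True` is passed explicitly. A minimal sketch of the surviving call, pared down from the deleted README example (the shapes in the comment follow that example):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 256,
    accept_image_fmap = True    # image feature maps are still supported; only the orthogonal_reg_* kwargs were removed
)

img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap)   # (1, 256, 32, 32), (1, 32, 32), (1,)
```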

0 commit comments
