
Commit 918322a

add ability to only enforce orthogonality loss on activated codes within a batch, for large codebooks (say taming transformers 16k)
1 parent 4c5726f commit 918322a

3 files changed, +19 -6 lines changed

README.md

Lines changed: 3 additions & 2 deletions
@@ -139,8 +139,9 @@ from vector_quantize_pytorch import VectorQuantize
 vq = VectorQuantize(
     dim = 256,
     codebook_size = 256,
-    accept_image_fmap = True,       # set this true to be able to pass in an image feature map
-    orthogonal_reg_weight = 10,     # in paper, they recommended a value of 10
+    accept_image_fmap = True,                  # set this true to be able to pass in an image feature map
+    orthogonal_reg_weight = 10,                # in paper, they recommended a value of 10
+    orthogonal_reg_active_codes_only = False   # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
 )
 
 img_fmap = torch.randn(1, 256, 32, 32)
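
As a quick end-to-end illustration of the new flag, here is a minimal, self-contained usage sketch following the README example above. The 16384-entry codebook, setting the flag to `True`, and the three-value return (`quantized`, `indices`, `loss`) are illustrative assumptions based on the library's documented usage, not part of this commit.

```python
# minimal usage sketch, assuming the 0.4.4 API shown in the README diff above
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 16384,                    # a large codebook, e.g. the 16k codes used by taming-transformers
    accept_image_fmap = True,                 # accept (batch, dim, height, width) feature maps
    orthogonal_reg_weight = 10,               # weight on the orthogonality regularization
    orthogonal_reg_active_codes_only = True   # only penalize the codes activated in this batch
)

img_fmap = torch.randn(1, 256, 32, 32)

# quantized feature map, per-position code indices, and the auxiliary loss
# (commitment term plus the weighted orthogonality term from the forward pass below)
quantized, indices, loss = vq(img_fmap)
```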

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '0.4.3',
+  version = '0.4.4',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 15 additions & 3 deletions
@@ -253,16 +253,17 @@ def __init__(
         n_embed = None,
         codebook_dim = None,
         decay = 0.8,
-        orthogonal_reg_weight = 0.,
-        commitment_weight = None,
         eps = 1e-5,
         kmeans_init = False,
         kmeans_iters = 10,
         use_cosine_sim = False,
         threshold_ema_dead_code = 0,
         channel_last = True,
         accept_image_fmap = False,
-        commitment = 1. # deprecate in next version, turn off by default
+        commitment_weight = None,
+        commitment = 1., # deprecate in next version, turn off by default
+        orthogonal_reg_weight = 0.,
+        orthogonal_reg_active_codes_only = False
     ):
         super().__init__()
         n_embed = default(n_embed, codebook_size)
@@ -276,7 +277,9 @@ def __init__(
 
         self.eps = eps
         self.commitment_weight = default(commitment_weight, commitment)
+
         self.orthogonal_reg_weight = orthogonal_reg_weight
+        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
 
         codebook_class = EuclideanCodebook if not use_cosine_sim \
             else CosineSimCodebook
@@ -327,6 +330,15 @@ def forward(self, x):
             loss = loss + commit_loss * self.commitment_weight
 
         if self.orthogonal_reg_weight > 0:
+            codebook = self.codebook
+
+            if self.orthogonal_reg_active_codes_only:
+                # only calculate orthogonal loss for the activated codes for this batch
+                unique_code_ids = torch.unique(embed_ind)
+                codebook = self.codebook[unique_code_ids]
+            else:
+                codebook = self.codebook
+
             orthogonal_reg_loss = orthgonal_loss_fn(self.codebook)
             loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
 
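
The orthogonality penalty itself (`orthgonal_loss_fn`) is not part of this diff, so below is a small standalone sketch of the idea the new flag targets: gather the codebook rows actually assigned in the current batch and apply the penalty to that subset only. The function name, the squared cosine-similarity-minus-identity form, and the tensor shapes are assumptions for illustration, not the library's exact implementation.

```python
# standalone sketch (not the library's exact code): orthogonality penalty
# restricted to the codes that were activated for the current batch
import torch
import torch.nn.functional as F

def orthogonal_loss_on_active_codes(codebook, embed_ind):
    # codebook: (codebook_size, dim) code vectors
    # embed_ind: integer tensor of code indices assigned in this batch
    unique_code_ids = torch.unique(embed_ind)      # codes actually used this batch
    active_codes = codebook[unique_code_ids]       # (n_active, dim) subset

    n = active_codes.shape[0]
    normed = F.normalize(active_codes, dim = -1)   # compare code directions only
    cosine_sim = normed @ normed.t()               # (n_active, n_active) pairwise similarities
    identity = torch.eye(n, device = codebook.device)

    # push off-diagonal similarities toward zero, i.e. mutually orthogonal active codes
    return ((cosine_sim - identity) ** 2).sum() / (n ** 2)

# hypothetical shapes: a 16k codebook, but only the codes hit by one 32x32 map are penalized
codebook = torch.randn(16384, 256)
embed_ind = torch.randint(0, 16384, (1, 32, 32))
loss = orthogonal_loss_on_active_codes(codebook, embed_ind)
```

For a 16k codebook where a batch only touches a few hundred codes, restricting the penalty keeps the pairwise similarity matrix at n_active × n_active instead of 16384 × 16384, which is the motivation given in the commit message.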
