Skip to content

Commit 2fd172a

Browse files
committed
add ability to randomly sample codes for orthogonality loss, with orthogonal_reg_max_codes param
1 parent f4604a7 commit 2fd172a

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ vq = VectorQuantize(
141141
codebook_size = 256,
142142
accept_image_fmap = True, # set this true to be able to pass in an image feature map
143143
orthogonal_reg_weight = 10, # in paper, they recommended a value of 10
144+
orthogonal_reg_max_codes = 128, # this would randomly sample from the codebook for the orthogonal regularization loss, for limiting memory usage
144145
orthogonal_reg_active_codes_only = False # set this to True if you have a very large codebook, and would only like to enforce the loss on the activated codes per batch
145146
)
146147

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.4.5',
6+
version = '0.4.7',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ def __init__(
263263
commitment_weight = None,
264264
commitment = 1., # deprecate in next version, turn off by default
265265
orthogonal_reg_weight = 0.,
266-
orthogonal_reg_active_codes_only = False
266+
orthogonal_reg_active_codes_only = False,
267+
orthogonal_reg_max_codes = None
267268
):
268269
super().__init__()
269270
n_embed = default(n_embed, codebook_size)
@@ -280,6 +281,7 @@ def __init__(
280281

281282
self.orthogonal_reg_weight = orthogonal_reg_weight
282283
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
284+
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
283285

284286
codebook_class = EuclideanCodebook if not use_cosine_sim \
285287
else CosineSimCodebook
@@ -337,6 +339,11 @@ def forward(self, x):
337339
unique_code_ids = torch.unique(embed_ind)
338340
codebook = codebook[unique_code_ids]
339341

342+
num_codes = codebook.shape[0]
343+
if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
344+
rand_ids = torch.randperm(num_codes, device = device)[:self.orthogonal_reg_max_codes]
345+
codebook = codebook[rand_ids]
346+
340347
orthogonal_reg_loss = orthgonal_loss_fn(codebook)
341348
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
342349

0 commit comments

Comments (0)