Skip to content

Commit f4604a7

Browse files
committed
fix
1 parent 918322a commit f4604a7

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.4.4',
6+
version = '0.4.5',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
author = 'Phil Wang',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,9 @@ def forward(self, x):
335335
if self.orthogonal_reg_active_codes_only:
336336
# only calculate orthogonal loss for the activated codes for this batch
337337
unique_code_ids = torch.unique(embed_ind)
338-
codebook = self.codebook[unique_code_ids]
339-
else:
340-
codebook = self.codebook
338+
codebook = codebook[unique_code_ids]
341339

342-
orthogonal_reg_loss = orthgonal_loss_fn(self.codebook)
340+
orthogonal_reg_loss = orthgonal_loss_fn(codebook)
343341
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
344342

345343
quantize = self.project_out(quantize)

0 commit comments

Comments
 (0)