
Commit 6fd0547

allow for freezing codebook on forward, also convenience setter for codebook
1 parent da2fb35 commit 6fd0547

File tree

2 files changed, +22 -9 lines

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.6.31',
+  version = '1.6.32',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 21 additions & 8 deletions
@@ -438,7 +438,8 @@ def forward(
         self,
         x,
         sample_codebook_temp = None,
-        mask = None
+        mask = None,
+        freeze_codebook = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -478,7 +479,7 @@ def forward(
         else:
             quantize = batched_embedding(embed_ind, embed)
 
-        if self.training and self.ema_update:
+        if self.training and self.ema_update and not freeze_codebook:
 
             if self.affine_param:
                 flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
@@ -620,7 +621,8 @@ def forward(
         self,
         x,
         sample_codebook_temp = None,
-        mask = None
+        mask = None,
+        freeze_codebook = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -652,7 +654,7 @@ def forward(
         else:
             quantize = batched_embedding(embed_ind, embed)
 
-        if self.training and self.ema_update:
+        if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):
                 embed_onehot[~mask] = 0.
 
@@ -691,6 +693,7 @@ def __init__(
         separate_codebook_per_head = False,
         decay = 0.8,
         eps = 1e-5,
+        freeze_codebook = False,
         kmeans_init = False,
         kmeans_iters = 10,
         sync_kmeans = True,
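
The constructor of VectorQuantize now also accepts a freeze_codebook flag. A minimal usage sketch, assuming the init-time flag acts as a standing default for the per-call freeze_codebook argument added to forward below (that wiring is not visible in this hunk):

    # hypothetical sketch: freeze the codebook for the lifetime of the module
    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(
        dim = 256,
        codebook_size = 512,
        freeze_codebook = True   # new constructor flag from this commit
    )

    x = torch.randn(1, 1024, 256)
    quantized, indices, commit_loss = vq(x)   # inputs are quantized, codebook is not updated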
@@ -796,11 +799,19 @@ def __init__(
     @property
     def codebook(self):
         codebook = self._codebook.embed
+
         if self.separate_codebook_per_head:
             return codebook
 
         return rearrange(codebook, '1 ... -> ...')
 
+    @codebook.setter
+    def codebook(self, codes):
+        if not self.separate_codebook_per_head:
+            codes = rearrange(codes, '... -> 1 ...')
+
+        self._codebook.embed.copy_(codes)
+
     def get_codes_from_indices(self, indices):
         codebook = self.codebook
         is_multiheaded = codebook.ndim > 2
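
The new @codebook.setter mirrors the getter: when a single shared codebook is used it re-adds the leading head dimension before copying the values into self._codebook.embed in place. A small sketch of warm-starting the codes (the pretrained_codes tensor is hypothetical, e.g. k-means centroids computed offline):

    # hypothetical sketch: overwrite the codebook with externally computed centroids
    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(dim = 256, codebook_size = 512)

    pretrained_codes = torch.randn(512, 256)   # (codebook_size, dim)

    with torch.no_grad():
        vq.codebook = pretrained_codes         # new setter, copies in place via copy_

    assert torch.allclose(vq.codebook, pretrained_codes)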
@@ -825,7 +836,8 @@ def forward(
         x,
         indices = None,
         mask = None,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        freeze_codebook = False
     ):
         orig_input = x
 
@@ -867,7 +879,8 @@ def forward(
 
         codebook_forward_kwargs = dict(
             sample_codebook_temp = sample_codebook_temp,
-            mask = mask
+            mask = mask,
+            freeze_codebook = freeze_codebook
         )
 
         # quantize
@@ -876,7 +889,7 @@ def forward(
 
         # one step in-place update
 
-        if should_inplace_optimize and self.training:
+        if should_inplace_optimize and self.training and not freeze_codebook:
 
             if exists(mask):
                 loss = F.mse_loss(quantize, x.detach(), reduction = 'none')
@@ -900,7 +913,7 @@ def forward(
 
         if self.training:
             # determine code to use for commitment loss
-            maybe_detach = torch.detach if not self.learnable_codebook else identity
+            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity
 
             commit_quantize = maybe_detach(quantize)
 
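
Taken together, passing freeze_codebook = True into forward skips the EMA codebook update, skips the one-step in-place optimizer update, and detaches the quantized codes in the commitment loss even when learnable_codebook is enabled. A hedged sketch of a two-phase schedule (the warmup threshold is made up for illustration):

    # hypothetical sketch: let the codebook train for a warmup period, then freeze it
    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(dim = 256, codebook_size = 512)
    warmup_steps = 1000   # illustrative threshold, not from the commit

    for step in range(2000):
        x = torch.randn(1, 1024, 256)

        quantized, indices, commit_loss = vq(
            x,
            freeze_codebook = step >= warmup_steps   # new per-call flag from this commit
        )

        # ... add commit_loss to the task loss and backprop as usual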

0 commit comments