@@ -438,7 +438,8 @@ def forward(
         self,
         x,
         sample_codebook_temp = None,
-        mask = None
+        mask = None,
+        freeze_codebook = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -478,7 +479,7 @@ def forward(
         else:
             quantize = batched_embedding(embed_ind, embed)

-        if self.training and self.ema_update:
+        if self.training and self.ema_update and not freeze_codebook:

             if self.affine_param:
                 flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
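For orientation, the update that this flag now gates is the standard VQ-VAE exponential-moving-average codebook update. Below is a minimal single-head sketch of that update, not the library's exact code; the buffer names `embed_avg` and `cluster_size` mirror the ones used by the euclidean codebook here:

```python
import torch

def ema_update(embed, embed_avg, cluster_size, flatten, embed_onehot,
               decay = 0.8, eps = 1e-5, freeze_codebook = False):
    # with the new flag, a frozen codebook skips all running-statistic updates
    if freeze_codebook:
        return

    # decay the per-code counts and summed input vectors toward the new batch
    cluster_size.mul_(decay).add_(embed_onehot.sum(0), alpha = 1 - decay)
    embed_avg.mul_(decay).add_(embed_onehot.t() @ flatten, alpha = 1 - decay)

    # Laplace-smoothed normalization yields the updated code vectors
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + cluster_size.numel() * eps) * n
    embed.copy_(embed_avg / smoothed.unsqueeze(-1))
```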
@@ -620,7 +621,8 @@ def forward(
         self,
         x,
         sample_codebook_temp = None,
-        mask = None
+        mask = None,
+        freeze_codebook = False
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -652,7 +654,7 @@ def forward(
         else:
             quantize = batched_embedding(embed_ind, embed)

-        if self.training and self.ema_update:
+        if self.training and self.ema_update and not freeze_codebook:
             if exists(mask):
                 embed_onehot[~mask] = 0.

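The `embed_onehot[~mask] = 0.` line in this hunk keeps padded positions out of those EMA statistics, so only real tokens vote for codes. A tiny illustration with hypothetical shapes (not library code):

```python
import torch

n, codebook_size = 6, 4
assignments = torch.randint(codebook_size, (n,))
embed_onehot = torch.eye(codebook_size)[assignments]          # (n, codebook_size) one-hot votes
mask = torch.tensor([True, True, True, True, False, False])   # last two positions are padding

embed_onehot[~mask] = 0.            # padded rows no longer contribute
counts = embed_onehot.sum(dim = 0)  # per-code counts over real tokens only
```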
@@ -691,6 +693,7 @@ def __init__(
         separate_codebook_per_head = False,
         decay = 0.8,
         eps = 1e-5,
+        freeze_codebook = False,
         kmeans_init = False,
         kmeans_iters = 10,
         sync_kmeans = True,
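The flag is also exposed at construction time. This hunk only shows the new `__init__` argument; presumably it is stored and consulted as a default by `forward`. A hedged usage sketch, assuming this diff is applied:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

# a quantizer whose codebook should never be updated during training
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    freeze_codebook = True  # added by this diff
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```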
@@ -796,11 +799,19 @@ def __init__(
     @property
     def codebook(self):
         codebook = self._codebook.embed
+
         if self.separate_codebook_per_head:
             return codebook

         return rearrange(codebook, '1 ... -> ...')

+    @codebook.setter
+    def codebook(self, codes):
+        if not self.separate_codebook_per_head:
+            codes = rearrange(codes, '... -> 1 ...')
+
+        self._codebook.embed.copy_(codes)
+
     def get_codes_from_indices(self, indices):
         codebook = self.codebook
         is_multiheaded = codebook.ndim > 2
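The new setter allows loading an external (e.g. pre-trained or frozen) codebook directly; for a single-head quantizer it re-adds the leading head dimension that the getter strips, then copies in-place. A usage sketch with hypothetical codes:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

# hypothetical externally trained codes of shape (codebook_size, dim)
pretrained_codes = torch.randn(512, 256)
vq.codebook = pretrained_codes  # setter rearranges '... -> 1 ...' and copy_s into the buffer

assert torch.allclose(vq.codebook, pretrained_codes)
```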
@@ -825,7 +836,8 @@ def forward(
         x,
         indices = None,
         mask = None,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        freeze_codebook = False
     ):
         orig_input = x

@@ -867,7 +879,8 @@ def forward(

         codebook_forward_kwargs = dict(
             sample_codebook_temp = sample_codebook_temp,
-            mask = mask
+            mask = mask,
+            freeze_codebook = freeze_codebook
         )

         # quantize
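With the kwarg threaded through to the inner codebook, freezing can also be toggled per call, e.g. when training an encoder/decoder against a fixed codebook. Assuming this diff is applied:

```python
# EMA / in-place codebook updates are skipped for this call only,
# and (below) the commitment target is detached from the graph
quantized, indices, commit_loss = vq(x, freeze_codebook = True)
```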
@@ -876,7 +889,7 @@ def forward(

         # one step in-place update

-        if should_inplace_optimize and self.training:
+        if should_inplace_optimize and self.training and not freeze_codebook:

             if exists(mask):
                 loss = F.mse_loss(quantize, x.detach(), reduction = 'none')
@@ -900,7 +913,7 @@ def forward(

         if self.training:
             # determine code to use for commitment loss
-            maybe_detach = torch.detach if not self.learnable_codebook else identity
+            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

             commit_quantize = maybe_detach(quantize)

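Note the operator precedence in the new ternary: it parses as `(not self.learnable_codebook) or freeze_codebook`, so the commitment target is detached whenever the codebook is not learnable *or* freezing was requested. An equivalent spelled-out sketch:

```python
import torch

def pick_commit_target(quantize, learnable_codebook, freeze_codebook):
    # gradients from the commitment loss reach the codebook only when
    # it is learnable and not frozen; otherwise detach
    if not learnable_codebook or freeze_codebook:
        return quantize.detach()
    return quantize

q = torch.randn(4, 8, requires_grad = True)
assert not pick_commit_target(q, learnable_codebook = True, freeze_codebook = True).requires_grad
```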