@@ -28,7 +28,9 @@ class ResidualVQ(nn.Module):
2828 def __init__ (
2929 self ,
3030 * ,
31+ dim ,
3132 num_quantizers ,
33+ codebook_dim = None ,
3234 shared_codebook = False ,
3335 heads = 1 ,
3436 quantize_dropout = False ,
@@ -39,11 +41,17 @@ def __init__(
3941 ):
4042 super ().__init__ ()
4143 assert heads == 1 , 'residual vq is not compatible with multi-headed codes'
44+ codebook_dim = default (codebook_dim , dim )
45+ codebook_input_dim = codebook_dim * heads
46+
47+ requires_projection = codebook_input_dim != dim
48+ self .project_in = nn .Linear (dim , codebook_input_dim ) if requires_projection else nn .Identity ()
49+ self .project_out = nn .Linear (codebook_input_dim , dim ) if requires_projection else nn .Identity ()
4250
4351 self .num_quantizers = num_quantizers
4452
4553 self .accept_image_fmap = accept_image_fmap
46- self .layers = nn .ModuleList ([VectorQuantize (accept_image_fmap = accept_image_fmap , ** kwargs ) for _ in range (num_quantizers )])
54+ self .layers = nn .ModuleList ([VectorQuantize (dim = codebook_dim , codebook_dim = codebook_dim , accept_image_fmap = accept_image_fmap , ** kwargs ) for _ in range (num_quantizers )])
4755
4856 self .quantize_dropout = quantize_dropout and num_quantizers > 1
4957
@@ -114,6 +122,8 @@ def forward(
114122 ):
115123 num_quant , quant_dropout_multiple_of , return_loss , device = self .num_quantizers , self .quantize_dropout_multiple_of , exists (indices ), x .device
116124
125+ x = self .project_in (x )
126+
117127 assert not (self .accept_image_fmap and exists (indices ))
118128
119129 quantized_out = 0.
@@ -169,6 +179,10 @@ def forward(
169179 all_indices .append (embed_indices )
170180 all_losses .append (loss )
171181
182+ # project out, if needed
183+
184+ quantized_out = self .project_out (quantized_out )
185+
172186 # whether to early return the cross entropy loss
173187
174188 if return_loss :
0 commit comments