@@ -47,7 +47,8 @@ def __init__(
4747 num_codebooks = 1 ,
4848 keep_num_codebooks_dim : Optional [bool ] = None ,
4949 scale : Optional [float ] = None ,
50- allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 )
50+ allowed_dtypes : Tuple [torch .dtype , ...] = (torch .float32 , torch .float64 ),
51+ channel_first : bool = False
5152 ):
5253 super ().__init__ ()
5354 _levels = torch .tensor (levels , dtype = int32 )
@@ -71,14 +72,17 @@ def __init__(
7172
7273 self .dim = default (dim , len (_levels ) * num_codebooks )
7374
75+ self .channel_first = channel_first
76+
7477 has_projections = self .dim != effective_codebook_dim
7578 self .project_in = nn .Linear (self .dim , effective_codebook_dim ) if has_projections else nn .Identity ()
7679 self .project_out = nn .Linear (effective_codebook_dim , self .dim ) if has_projections else nn .Identity ()
80+
7781 self .has_projections = has_projections
7882
7983 self .codebook_size = self ._levels .prod ().item ()
8084
81- implicit_codebook = self .indices_to_codes (torch .arange (self .codebook_size ), project_out = False )
85+ implicit_codebook = self ._indices_to_codes (torch .arange (self .codebook_size ))
8286 self .register_buffer ("implicit_codebook" , implicit_codebook , persistent = False )
8387
8488 self .allowed_dtypes = allowed_dtypes
@@ -103,33 +107,35 @@ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
103107 def _scale_and_shift_inverse (self , zhat : Tensor ) -> Tensor :
104108 half_width = self ._levels // 2
105109 return (zhat - half_width ) / half_width
106-
110+
111+ def _indices_to_codes (self , indices : Tensor ):
112+ indices = rearrange (indices , '... -> ... 1' )
113+ codes_non_centered = (indices // self ._basis ) % self ._levels
114+ codes = self ._scale_and_shift_inverse (codes_non_centered )
115+ return codes
116+
107117 def codes_to_indices (self , zhat : Tensor ) -> Tensor :
108118 """Converts a `code` to an index in the codebook."""
109119 assert zhat .shape [- 1 ] == self .codebook_dim
110120 zhat = self ._scale_and_shift (zhat )
111121 return (zhat * self ._basis ).sum (dim = - 1 ).to (int32 )
112-
122+
113123 def indices_to_codes (
114124 self ,
115- indices : Tensor ,
116- project_out = True
125+ indices : Tensor
117126 ) -> Tensor :
118127 """Inverse of `codes_to_indices`."""
119128
120129 is_img_or_video = indices .ndim >= (3 + int (self .keep_num_codebooks_dim ))
121130
122- indices = rearrange (indices , '... -> ... 1' )
123- codes_non_centered = (indices // self ._basis ) % self ._levels
124- codes = self ._scale_and_shift_inverse (codes_non_centered )
131+ codes = self ._indices_to_codes (indices )
125132
126133 if self .keep_num_codebooks_dim :
127134 codes = rearrange (codes , '... c d -> ... (c d)' )
128135
129- if project_out :
130- codes = self .project_out (codes )
136+ codes = self .project_out (codes )
131137
132- if is_img_or_video :
138+ if is_img_or_video or self . channel_first :
133139 codes = rearrange (codes , 'b ... d -> b d ...' )
134140
135141 return codes
@@ -146,10 +152,11 @@ def forward(self, z: Tensor) -> Tensor:
146152
147153 orig_dtype = z .dtype
148154 is_img_or_video = z .ndim >= 4
155+ need_move_channel_last = is_img_or_video or self .channel_first
149156
150157 # standardize image or video into (batch, seq, dimension)
151158
152- if is_img_or_video :
159+ if need_move_channel_last :
153160 z = rearrange (z , 'b d ... -> b ... d' )
154161 z , ps = pack_one (z , 'b * d' )
155162
@@ -180,7 +187,7 @@ def forward(self, z: Tensor) -> Tensor:
180187
181188 # reconstitute image or video dimensions
182189
183- if is_img_or_video :
190+ if need_move_channel_last :
184191 out = unpack_one (out , ps , 'b * d' )
185192 out = rearrange (out , 'b ... d -> b d ...' )
186193
0 commit comments