@@ -147,11 +147,6 @@ def forward(self, z: Tensor) -> Tensor:
147147 orig_dtype = z .dtype
148148 is_img_or_video = z .ndim >= 4
149149
150- # make sure allowed dtype
151-
152- if z .dtype not in self .allowed_dtypes :
153- z = z .float ()
154-
155150 # standardize image or video into (batch, seq, dimension)
156151
157152 if is_img_or_video :
@@ -164,11 +159,23 @@ def forward(self, z: Tensor) -> Tensor:
164159
165160 z = rearrange (z , 'b n (c d) -> b n c d' , c = self .num_codebooks )
166161
162+ # make sure allowed dtype before quantizing
163+
164+ if z .dtype not in self .allowed_dtypes :
165+ z = z .float ()
166+
167167 codes = self .quantize (z )
168168 indices = self .codes_to_indices (codes )
169169
170170 codes = rearrange (codes , 'b n c d -> b n (c d)' )
171171
172+ # cast codes back to original dtype
173+
174+ if codes .dtype != orig_dtype :
175+ codes = codes .type (orig_dtype )
176+
177+ # project out
178+
172179 out = self .project_out (codes )
173180
174181 # reconstitute image or video dimensions
@@ -182,11 +189,6 @@ def forward(self, z: Tensor) -> Tensor:
182189 if not self .keep_num_codebooks_dim :
183190 indices = rearrange (indices , '... 1 -> ...' )
184191
185- # cast back to original dtype
186-
187- if out .dtype != orig_dtype :
188- out = out .type (orig_dtype )
189-
190192 # return quantized output and indices
191193
192194 return out , indices
0 commit comments