@@ -87,44 +87,46 @@ def __init__(
8787
8888 self .allowed_dtypes = allowed_dtypes
8989
90- def bound (self , z : Tensor , eps : float = 1e-3 ) -> Tensor :
91- """Bound `z`, an array of shape (..., d)."""
90+ def bound (self , z , eps : float = 1e-3 ):
91+ """ Bound `z`, an array of shape (..., d). """
9292 half_l = (self ._levels - 1 ) * (1 + eps ) / 2
9393 offset = torch .where (self ._levels % 2 == 0 , 0.5 , 0.0 )
9494 shift = (offset / half_l ).atanh ()
9595 return (z + shift ).tanh () * half_l - offset
9696
97- def quantize (self , z : Tensor ) -> Tensor :
98- """Quantizes z, returns quantized zhat, same shape as z."""
97+ def quantize (self , z ) :
98+ """ Quantizes z, returns quantized zhat, same shape as z. """
9999 quantized = round_ste (self .bound (z ))
100100 half_width = self ._levels // 2 # Renormalize to [-1, 1].
101101 return quantized / half_width
102102
103- def _scale_and_shift (self , zhat_normalized : Tensor ) -> Tensor :
103+ def _scale_and_shift (self , zhat_normalized ) :
104104 half_width = self ._levels // 2
105105 return (zhat_normalized * half_width ) + half_width
106106
107- def _scale_and_shift_inverse (self , zhat : Tensor ) -> Tensor :
107+ def _scale_and_shift_inverse (self , zhat ) :
108108 half_width = self ._levels // 2
109109 return (zhat - half_width ) / half_width
110110
111- def _indices_to_codes (self , indices : Tensor ):
112- indices = rearrange (indices , '... -> ... 1' )
113- codes_non_centered = (indices // self ._basis ) % self ._levels
114- codes = self ._scale_and_shift_inverse (codes_non_centered )
111+ def _indices_to_codes (self , indices ):
112+ level_indices = self .indices_to_level_indices (indices )
113+ codes = self ._scale_and_shift_inverse (level_indices )
115114 return codes
116115
117- def codes_to_indices (self , zhat : Tensor ) -> Tensor :
118- """Converts a `code` to an index in the codebook."""
116+ def codes_to_indices (self , zhat ) :
117+ """ Converts a `code` to an index in the codebook. """
119118 assert zhat .shape [- 1 ] == self .codebook_dim
120119 zhat = self ._scale_and_shift (zhat )
121120 return (zhat * self ._basis ).sum (dim = - 1 ).to (int32 )
122121
123- def indices_to_codes (
124- self ,
125- indices : Tensor
126- ) -> Tensor :
127- """Inverse of `codes_to_indices`."""
122+ def indices_to_level_indices (self , indices ):
123+ """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
124+ indices = rearrange (indices , '... -> ... 1' )
125+ codes_non_centered = (indices // self ._basis ) % self ._levels
126+ return codes_non_centered
127+
128+ def indices_to_codes (self , indices ):
129+ """ Inverse of `codes_to_indices`. """
128130
129131 is_img_or_video = indices .ndim >= (3 + int (self .keep_num_codebooks_dim ))
130132
@@ -141,7 +143,7 @@ def indices_to_codes(
141143 return codes
142144
143145 @autocast (enabled = False )
144- def forward (self , z : Tensor ) -> Tensor :
146+ def forward (self , z ) :
145147 """
146148 einstein notation
147149 b - batch
0 commit comments