|
14 | 14 | from torch import Tensor, int32 |
15 | 15 | from torch.amp import autocast |
16 | 16 |
|
| 17 | +import einx |
17 | 18 | from einops import rearrange, pack, unpack |
18 | 19 |
|
19 | 20 | import random |
@@ -45,11 +46,15 @@ def unpack_one(t, ps, pattern): |
45 | 46 |
|
46 | 47 | # tensor helpers |
47 | 48 |
|
48 | | -def round_ste(z: Tensor) -> Tensor: |
| 49 | +def round_ste(z): |
49 | 50 |     """Round with straight through gradients.""" |
50 | 51 |     zhat = z.round() |
51 | 52 |     return z + (zhat - z).detach() |
52 | 53 |
|
| 54 | +def floor_ste(z): |
| 55 | +    zhat = z.floor() |
| 56 | +    return z + (zhat - z).detach() |
| 57 | + |
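Both helpers rely on the straight-through estimator: the forward pass emits the discretized value, while the backward pass treats the op as identity, so gradients reach `z` untouched. A minimal standalone check (not part of this diff):

```python
import torch

def round_ste(z):
    """Round with straight through gradients."""
    zhat = z.round()
    return z + (zhat - z).detach()

z = torch.tensor([0.2, 0.8, -1.4], requires_grad = True)
out = round_ste(z)          # tensor([ 0.,  1., -1.], grad_fn=<AddBackward0>)

out.sum().backward()
print(z.grad)               # tensor([1., 1., 1.]): identity gradient, the rounding is invisible to backward
```

`floor_ste` is the same construction with `floor` in place of `round`.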
53 | 58 | # main class |
54 | 59 |
|
55 | 60 | class FSQ(Module): |
@@ -127,41 +132,43 @@ def symmetry_preserving_bound(self, z): |
127 | 132 |         levels_minus_1 = (self._levels - 1) |
128 | 133 |         scale = 2.0 / levels_minus_1 |
129 | 134 |         bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.0) + 0.5 |
| 135 | +        bracket = floor_ste(bracket) |
130 | 136 |         return scale * bracket - 1.0 |
131 | 137 |
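With `floor_ste` applied inside the bound, the bracket term is discretized where it is computed, so QL(x) = 2/(L-1) * floor((L-1) * (tanh(x)+1)/2 + 0.5) - 1 lands exactly on L levels symmetric about zero, while gradients still pass straight through. A standalone sketch for a single level count L = 5 (plain `floor` stands in for `floor_ste`, since only forward values are inspected here):

```python
import torch

def symmetry_preserving_bound(z, L = 5):
    # single-level-count version of the method above
    scale = 2.0 / (L - 1)
    bracket = ((L - 1) * (torch.tanh(z) + 1) / 2.0) + 0.5
    return scale * bracket.floor() - 1.0

print(symmetry_preserving_bound(torch.linspace(-1, 1, 9)))
# tensor([-1.0000, -0.5000, -0.5000,  0.0000,  0.0000,  0.0000,  0.5000,  0.5000,  1.0000])
# five levels {-1, -0.5, 0, 0.5, 1}, symmetric about zero
```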
|
132 | | -    def quantize(self, z, preserve_symmetry = False): |
| 138 | +    def quantize(self, z): |
133 | 139 |         """ Quantizes z, returns quantized zhat, same shape as z. """ |
134 | 140 |
|
| 141 | +        preserve_symmetry = self.preserve_symmetry |
135 | 142 |         half_width = self._levels // 2 |
136 | 143 |
|
137 | | -        if self.training: |
138 | | -            unquantized = z |
139 | | - |
140 | | -            # determine where to quantize elementwise |
141 | | - |
142 | | -            quantize_mask = torch.bernoulli( |
143 | | -                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device) |
144 | | -            ).bool().expand_as(z) |
145 | | - |
146 | | -            if preserve_symmetry: |
147 | | -                quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width |
148 | | -            else: |
149 | | -                quantized = round_ste(self.bound(z)) / half_width |
150 | | -            quantized = torch.where(quantize_mask, unquantized, quantized) |
151 | | - |
152 | | -            # determine where to add a random offset elementwise |
153 | | - |
154 | | -            offset_mask = torch.bernoulli( |
155 | | -                torch.full([z.shape[0], 1, 1, 1], self.noise_dropout, device = z.device) |
156 | | -            ).bool().expand_as(z) |
157 | | - |
158 | | -            offset = (torch.rand_like(z) - 0.5) / half_width |
159 | | -            quantized = torch.where(offset_mask, unquantized + offset, quantized) |
160 | | -        elif preserve_symmetry: |
| 144 | +        if preserve_symmetry: |
161 | 145 |             quantized = round_ste(self.symmetry_preserving_bound(z)) / half_width |
162 | 146 |         else: |
163 | 147 |             quantized = round_ste(self.bound(z)) / half_width |
164 | 148 |
|
| 149 | +        if not self.training: |
| 150 | +            return quantized |
| 151 | + |
| 152 | +        batch, device, noise_dropout = z.shape[0], z.device, self.noise_dropout |
| 153 | +        unquantized = z |
| 154 | + |
| 155 | +        # determine which samples keep their unquantized values (quantization dropout) |
| 156 | + |
| 157 | +        quantize_mask = torch.bernoulli( |
| 158 | +            torch.full((batch,), noise_dropout, device = device) |
| 159 | +        ).bool() |
| 160 | + |
| 161 | +        quantized = einx.where('b, b ..., b ...', quantize_mask, unquantized, quantized) |
| 162 | + |
| 163 | +        # determine which samples receive the unquantized value plus a small random offset |
| 164 | + |
| 165 | +        offset_mask = torch.bernoulli( |
| 166 | +            torch.full((batch,), noise_dropout, device = device) |
| 167 | +        ).bool() |
| 168 | + |
| 169 | +        offset = (torch.rand_like(z) - 0.5) / half_width |
| 170 | +        quantized = einx.where('b, b ..., b ...', offset_mask, unquantized + offset, quantized) |
| 171 | + |
165 | 172 |         return quantized |
166 | 173 |
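Note the masks are now drawn once per sample, shape `(batch,)`, rather than expanded elementwise as before. `einx.where` with the pattern `'b, b ..., b ...'` broadcasts that per-sample mask over all trailing dimensions; a bare `torch.where` would try to align the `(batch,)` mask against the last axis instead of the first, which is why both selections go through einx. A small sketch of what the pattern does, with made-up shapes:

```python
import torch
import einx

batch = 4
unquantized = torch.zeros(batch, 3, 8)   # stand-ins for the real activations
quantized = torch.ones(batch, 3, 8)
mask = torch.tensor([True, False, True, False])

out = einx.where('b, b ..., b ...', mask, unquantized, quantized)

# equivalent manual broadcasting: unsqueeze the mask to (batch, 1, 1)
ref = torch.where(mask[:, None, None], unquantized, quantized)
assert torch.equal(out, ref)
```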
|
167 | 174 | def _scale_and_shift(self, zhat_normalized): |
@@ -242,7 +249,7 @@ def forward(self, z): |
242 | 249 |         if force_f32 and orig_dtype not in self.allowed_dtypes: |
243 | 250 |             z = z.float() |
244 | 251 |
|
245 | | -        codes = self.quantize(z, preserve_symmetry=self.preserve_symmetry) |
| 252 | +        codes = self.quantize(z) |
246 | 253 |
|
247 | 254 | # returning indices could be optional |
248 | 255 |
|
|