
import torch
from torch import nn
+import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

from einops import rearrange, repeat

@@ -48,11 +49,22 @@ def codebooks(self):
        return codebooks

    def get_codes_from_indices(self, indices):
-        batch = indices.shape[0]
+        batch, quantize_dim = indices.shape[0], indices.shape[-1]
+
+        # because of quantize dropout, one can pass in indices that are coarse
+        # and the network should be able to reconstruct
+
+        if quantize_dim < self.num_quantizers:
+            assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
+            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
+
+        # get ready for gathering
+
        codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
        gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])

        # take care of quantizer dropout
+
        mask = gather_indices == -1.
        gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later

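
For reference, below is a minimal, self-contained sketch of the pad-and-mask behavior this commit adds, using toy tensors rather than the library's actual ResidualVQ class; the toy sizes and the final gather and zero-fill steps are assumptions about how get_codes_from_indices continues beyond this hunk.

import torch
import torch.nn.functional as F
from einops import repeat

# toy sizes, purely illustrative
num_quantizers, batch, seq, codebook_size, dim = 4, 2, 5, 16, 8

# stand-in for self.codebooks: one codebook per quantizer, shape (q, c, d)
codebooks = torch.randn(num_quantizers, codebook_size, dim)

# coarse indices from only the first 2 quantizers, as quantize dropout permits
indices = torch.randint(0, codebook_size, (batch, seq, 2))

# pad the missing (finer) quantizer slots with -1, mirroring the diff above
indices = F.pad(indices, (0, num_quantizers - indices.shape[-1]), value = -1)

# broadcast codebooks over the batch and expand indices over the feature dim
codebooks_b = repeat(codebooks, 'q c d -> q b c d', b = batch)
gather_indices = repeat(indices, 'b n q -> q b n d', d = dim)

# fetch a dummy code (index 0) wherever a quantizer was dropped, then zero it out
mask = gather_indices == -1
gather_indices = gather_indices.masked_fill(mask, 0)
all_codes = codebooks_b.gather(2, gather_indices)  # (q, b, n, d), assumed next step
all_codes = all_codes.masked_fill(mask, 0.)        # dropped quantizers contribute nothing

print(all_codes.shape)  # torch.Size([4, 2, 5, 8])

The padded -1 entries never index the codebook directly; they are redirected to a dummy code and then zeroed, so a reconstruction from only the coarse quantizers simply omits the finer residual contributions.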
|
|