Skip to content

Commit 6f65e72

Browse files
committed
Prepare for AudioLM to be able to reconstruct from only coarse quantized signals
1 parent 33c380d commit 6f65e72

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.10.10',
6+
version = '0.10.11',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
import torch
55
from torch import nn
6+
import torch.nn.functional as F
67
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
78

89
from einops import rearrange, repeat
@@ -48,11 +49,22 @@ def codebooks(self):
4849
return codebooks
4950

5051
def get_codes_from_indices(self, indices):
51-
batch = indices.shape[0]
52+
batch, quantize_dim = indices.shape[0], indices.shape[-1]
53+
54+
# because of quantize dropout, one can pass in indices that are coarse
55+
# and the network should be able to reconstruct
56+
57+
if quantize_dim < self.num_quantizers:
58+
assert self.quantize_dropout > 0., 'quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations'
59+
indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value = -1)
60+
61+
# get ready for gathering
62+
5263
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
5364
gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
5465

5566
# take care of quantizer dropout
67+
5668
mask = gather_indices == -1.
5769
gather_indices = gather_indices.masked_fill(mask, 0) # have it fetch a dummy code to be masked out later
5870

0 commit comments

Comments
 (0)