
Commit bf53137

add grouped residual vq technique from hifi-codec paper

1 parent 09a778f

4 files changed (+111, −2)

README.md

Lines changed: 31 additions & 0 deletions

````diff
@@ -81,6 +81,28 @@ quantized, indices, commit_loss = residual_vq(x)
 # (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
 ```
 
+<a href="https://arxiv.org/abs/2305.02765">A recent paper</a> further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ`
+
+```python
+import torch
+from vector_quantize_pytorch import GroupedResidualVQ
+
+residual_vq = GroupedResidualVQ(
+    dim = 256,
+    num_quantizers = 8,      # specify number of quantizers
+    groups = 2,
+    codebook_size = 1024,    # codebook size
+)
+
+x = torch.randn(1, 1024, 256)
+
+quantized, indices, commit_loss = residual_vq(x)
+
+# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
+# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
+
+```
+
 ## Initialization
 
 The SoundStream paper proposes that the codebook should be initialized by the kmeans centroids of the first batch. You can easily turn on this feature with one flag `kmeans_init = True`, for either `VectorQuantize` or `ResidualVQ` class
@@ -375,3 +397,12 @@ if __name__ == '__main__':
     year = {2023}
 }
 ```
+
+```bibtex
+@inproceedings{Yang2023HiFiCodecGV,
+    title = {HiFi-Codec: Group-residual Vector quantization for High Fidelity Audio Codec},
+    author = {Dongchao Yang and Songxiang Liu and Rongjie Huang and Jinchuan Tian and Chao Weng and Yuexian Zou},
+    year = {2023}
+}
+```
+
````
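For intuition, the new grouped variant (see `vector_quantize_pytorch/residual_vq.py` below) simply chunks the feature dimension into `groups` pieces, runs an independent `ResidualVQ` over each chunk, and recombines the results. A minimal standalone sketch of that equivalence, hypothetical code rather than the library's API, with the index and loss layouts left to whatever `ResidualVQ` itself returns:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

groups, dim = 2, 256

# one independent residual VQ per group, each over dim // groups features
rvqs = [
    ResidualVQ(dim = dim // groups, num_quantizers = 8, codebook_size = 1024)
    for _ in range(groups)
]

x = torch.randn(1, 1024, dim)

# chunk the feature dimension and quantize each chunk on its own
chunks = x.chunk(groups, dim = -1)
outs = [rvq(chunk) for rvq, chunk in zip(rvqs, chunks)]

quantized, indices, commit_loss = zip(*outs)

quantized   = torch.cat(quantized, dim = -1)  # back to (1, 1024, 256)
indices     = torch.stack(indices)            # per-group indices, stacked along a new leading groups dimension
commit_loss = torch.stack(commit_loss)        # per-group commit losses, stacked the same way
```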

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.2.3',
+  version = '1.4.0',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',
```
vector_quantize_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,3 +1,3 @@
 from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize
-from vector_quantize_pytorch.residual_vq import ResidualVQ
+from vector_quantize_pytorch.residual_vq import ResidualVQ, GroupedResidualVQ
 from vector_quantize_pytorch.random_projection_quantizer import RandomProjectionQuantizer
```

vector_quantize_pytorch/residual_vq.py

Lines changed: 78 additions & 0 deletions

```diff
@@ -1,5 +1,6 @@
 from math import ceil
 from functools import partial
+from itertools import zip_longest
 from random import randrange
 
 import torch
@@ -14,6 +15,9 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def round_up_multiple(num, mult):
     return ceil(num / mult) * mult
 
@@ -183,3 +187,77 @@ def forward(
         ret = (*ret, all_codes)
 
         return ret
+
+# grouped residual vq
+
+class GroupedResidualVQ(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        groups = 1,
+        accept_image_fmap = False,
+        **kwargs
+    ):
+        super().__init__()
+        self.dim = dim
+        self.groups = groups
+        assert (dim % groups) == 0
+        dim_per_group = dim // groups
+
+        self.accept_image_fmap = accept_image_fmap
+
+        self.rvqs = nn.ModuleList([])
+
+        for _ in range(groups):
+            self.rvqs.append(ResidualVQ(
+                dim = dim_per_group,
+                accept_image_fmap = accept_image_fmap,
+                **kwargs
+            ))
+
+    @property
+    def codebooks(self):
+        return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
+
+    def forward(
+        self,
+        x,
+        indices = None,
+        return_all_codes = False
+    ):
+        shape = x.shape
+        split_dim = 1 if self.accept_image_fmap else -1
+        assert shape[split_dim] == self.dim
+
+        # split the feature dimension into groups
+
+        x = x.chunk(self.groups, dim = split_dim)
+
+        indices = default(indices, tuple())
+        return_ce_loss = len(indices) > 0
+        assert len(indices) == 0 or len(indices) == self.groups
+
+        forward_kwargs = dict(return_all_codes = return_all_codes)
+
+        # invoke residual vq on each group
+
+        out = tuple(rvq(chunk, indices = chunk_indices, **forward_kwargs) for rvq, chunk, chunk_indices in zip_longest(self.rvqs, x, indices))
+        out = tuple(zip(*out))
+
+        # if returning cross entropy loss to rvq codebooks
+
+        if return_ce_loss:
+            quantized, ce_losses = out
+            return torch.cat(quantized, dim = split_dim), sum(ce_losses)
+
+        # otherwise, get all the zipped outputs and combine them
+
+        quantized, all_indices, commit_losses, *maybe_all_codes = out
+
+        quantized = torch.cat(quantized, dim = split_dim)
+        all_indices = torch.stack(all_indices)
+        commit_losses = torch.stack(commit_losses)
+
+        ret = (quantized, all_indices, commit_losses, *maybe_all_codes)
+        return ret
```
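Beyond the default path shown in the README, the new `forward` also accepts one indices tensor per group, in which case (judging from the two-element unpacking above) each underlying `ResidualVQ` is assumed to return `(quantized, ce_loss)` and the wrapper returns the concatenated quantization together with the summed cross entropy loss. A short usage sketch under that assumption; constructor arguments mirror the README example:

```python
import torch
from vector_quantize_pytorch import GroupedResidualVQ

group_rvq = GroupedResidualVQ(
    dim = 256,
    groups = 2,
    num_quantizers = 8,
    codebook_size = 1024
)

x = torch.randn(1, 1024, 256)

# default path: quantized output plus per-group indices / commit losses stacked along a leading groups dimension
quantized, indices, commit_loss = group_rvq(x)

# the new `codebooks` property stacks the per-group ResidualVQ codebooks the same way
codebooks = group_rvq.codebooks

# cross entropy path (assumed): one index tensor per group, here reusing the indices from the first call
quantized, ce_loss = group_rvq(x, indices = tuple(indices))
```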
