Skip to content

Commit 970d416

Browse files
committed
break out a function needed for audiolm
1 parent e7ced71 commit 970d416

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.10.2',
6+
version = '0.10.3',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ def codebooks(self):
3737
codebooks = rearrange(codebooks, 'q 1 c d -> q c d')
3838
return codebooks
3939

40+
def get_codes_from_indices(self, indices):
41+
batch = indices.shape[0]
42+
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = batch)
43+
gather_indices = repeat(indices, 'b n q -> q b n d', d = codebooks.shape[-1])
44+
45+
all_codes = codebooks.gather(2, gather_indices) # gather all codes
46+
return all_codes
47+
4048
def forward(
4149
self,
4250
x,
@@ -62,11 +70,7 @@ def forward(
6270

6371
if return_all_codes:
6472
# whether to return all codes from all codebooks across layers
65-
66-
codebooks = repeat(self.codebooks, 'q c d -> q b c d', b = x.shape[0])
67-
gather_indices = repeat(all_indices, 'b n q -> q b n d', d = codebooks.shape[-1])
68-
69-
all_codes = codebooks.gather(2, gather_indices) # gather all codes
73+
all_codes = self.get_codes_from_indices(all_indices)
7074

7175
# will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
7276
ret = (*ret, all_codes)

0 commit comments

Comments
 (0)