Skip to content

Commit 500db14

Browse files
committed
make sure get_codes_from_indices work for grouped rvq
1 parent bf53137 commit 500db14

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.4.0',
6+
version = '1.4.1',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/residual_vq.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,10 @@ def __init__(
220220
def codebooks(self):
221221
return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
222222

223+
def get_codes_from_indices(self, indices):
224+
codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
225+
return torch.stack(codes)
226+
223227
def forward(
224228
self,
225229
x,

0 commit comments

Comments
 (0)