Skip to content

Commit 61b821d

Browse files
committed
handle only 1 embedding given, release 1.0.0, should be matured
1 parent 4467f8d commit 61b821d

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '0.10.15',
6+
version = '1.0.0',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,11 @@ def forward(
526526
x,
527527
mask = None
528528
):
529+
only_one = x.ndim == 2
530+
531+
if only_one:
532+
x = rearrange(x, 'b d -> b 1 d')
533+
529534
shape, device, heads, is_multiheaded, codebook_size = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size
530535

531536
need_transpose = not self.channel_last and not self.accept_image_fmap
@@ -600,4 +605,8 @@ def forward(
600605
quantize = rearrange(quantize, 'b (h w) c -> b c h w', h = height, w = width)
601606
embed_ind = rearrange(embed_ind, 'b (h w) ... -> b h w ...', h = height, w = width)
602607

608+
if only_one:
609+
quantize = rearrange(quantize, 'b 1 d -> b d')
610+
embed_ind = rearrange(embed_ind, 'b 1 -> b')
611+
603612
return quantize, embed_ind, loss

0 commit comments

Comments
 (0)