
Commit da88a72

address #228
1 parent 9315776 commit da88a72

2 files changed: +8 -4 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.23.3"
+version = "1.23.4"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

vector_quantize_pytorch/vector_quantize_pytorch.py

Lines changed: 7 additions & 3 deletions
@@ -273,12 +273,14 @@ def efficient_rotation_trick_transform(u, q, e):
     e = rearrange(e, 'b d -> b 1 d')
     w = l2norm(u + q, dim = 1).detach()
 
-    return (
+    out = (
         e -
         2 * (e @ rearrange(w, 'b d -> b d 1') @ rearrange(w, 'b d -> b 1 d')) +
         2 * (e @ rearrange(u, 'b d -> b d 1').detach() @ rearrange(q, 'b d -> b 1 d').detach())
     )
 
+    return rearrange(out, '... 1 -> ...')
+
 def rotate_to(src, tgt):
     # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
     src, inverse = pack_one(src, '* d')
@@ -291,7 +293,7 @@ def rotate_to(src, tgt):
         safe_div(src, norm_src),
         safe_div(tgt, norm_tgt),
         src
-    ).squeeze()
+    )
 
     rotated = rotated_tgt * safe_div(norm_tgt, norm_src).detach()
 
@@ -896,7 +898,7 @@ def __init__(
         stochastic_sample_codes = False,
         sample_codebook_temp = 1.,
         straight_through = False,
-        rotation_trick = True,  # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
+        rotation_trick = None,  # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
         sync_codebook = None,
         sync_affine_param = False,
         ema_update = True,
@@ -911,6 +913,8 @@ def __init__(
         return_zeros_for_masked_padding = True
     ):
         super().__init__()
+        rotation_trick = default(rotation_trick, dim > 1)  # only use rotation trick if feature dimension greater than 1
+
         self.dim = dim
         self.heads = heads
         self.separate_codebook_per_head = separate_codebook_per_head
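
Note on the dropped `.squeeze()`: a bare `squeeze()` removes every size-1 axis, not just the singleton axis the transform introduces, so shapes get mangled whenever the batch or feature dimension is itself 1 (which is presumably why the default below now depends on `dim > 1`). A minimal standalone illustration of that behavior in plain PyTorch; the tensor shapes here are illustrative, not taken from the library:

    import torch

    # the transform yields a tensor with a singleton axis in the middle: (batch, 1, d)
    rotated_tgt = torch.randn(4, 1, 1)     # here the feature dim d is also 1

    print(rotated_tgt.squeeze().shape)     # torch.Size([4])    -- the feature axis is dropped too
    print(rotated_tgt.squeeze(1).shape)    # torch.Size([4, 1]) -- only the intended axis removed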

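With `rotation_trick = None` now resolving to `dim > 1` at construction time, a 1-dimensional feature space quietly falls back to the non-rotation (straight-through style) gradient path. A hedged usage sketch, assuming the `VectorQuantize` constructor and return signature documented in the project README; the shapes in the comments are the expected ones, not output captured from this commit:

    import torch
    from vector_quantize_pytorch import VectorQuantize

    # leaving rotation_trick unset lets it default to (dim > 1)
    vq = VectorQuantize(dim = 1, codebook_size = 8)   # resolves to rotation_trick = False for dim == 1
    x = torch.randn(2, 1024, 1)                       # (batch, seq, dim)

    quantized, indices, commit_loss = vq(x)           # expected: (2, 1024, 1), (2, 1024), scalar loss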
0 commit comments
