@@ -273,12 +273,14 @@ def efficient_rotation_trick_transform(u, q, e):
273273 e = rearrange (e , 'b d -> b 1 d' )
274274 w = l2norm (u + q , dim = 1 ).detach ()
275275
276- return (
276+ out = (
277277 e -
278278 2 * (e @ rearrange (w , 'b d -> b d 1' ) @ rearrange (w , 'b d -> b 1 d' )) +
279279 2 * (e @ rearrange (u , 'b d -> b d 1' ).detach () @ rearrange (q , 'b d -> b 1 d' ).detach ())
280280 )
281281
282+ return rearrange (out , '... 1 -> ...' )
283+
282284def rotate_to (src , tgt ):
283285 # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
284286 src , inverse = pack_one (src , '* d' )
@@ -291,7 +293,7 @@ def rotate_to(src, tgt):
291293 safe_div (src , norm_src ),
292294 safe_div (tgt , norm_tgt ),
293295 src
294- ). squeeze ()
296+ )
295297
296298 rotated = rotated_tgt * safe_div (norm_tgt , norm_src ).detach ()
297299
@@ -896,7 +898,7 @@ def __init__(
896898 stochastic_sample_codes = False ,
897899 sample_codebook_temp = 1. ,
898900 straight_through = False ,
899- rotation_trick = True , # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
901+ rotation_trick = None , # Propagate grads through VQ layer w/ rotation trick: https://arxiv.org/abs/2410.06424 by @cfifty
900902 sync_codebook = None ,
901903 sync_affine_param = False ,
902904 ema_update = True ,
@@ -911,6 +913,8 @@ def __init__(
911913 return_zeros_for_masked_padding = True
912914 ):
913915 super ().__init__ ()
916+ rotation_trick = default (rotation_trick , dim > 1 ) # only use rotation trick if feature dimension greater than 1
917+
914918 self .dim = dim
915919 self .heads = heads
916920 self .separate_codebook_per_head = separate_codebook_per_head
0 commit comments