@@ -61,7 +61,7 @@ def gumbel_sample(
6161 assert not (reinmax and not straight_through ), 'reinmax can only be turned on if using straight through gumbel softmax'
6262
6363 if not straight_through :
64- return ind , one_hot
64+ return ind , one_hot , None
6565
6666 # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
6767 # algorithm 2
@@ -78,7 +78,9 @@ def gumbel_sample(
7878 π1 = (logits / temperature ).softmax (dim = dim )
7979 one_hot = one_hot + π1 - π1 .detach ()
8080
81- return ind , one_hot
81+ st_mult = one_hot .gather (- 1 , rearrange (ind , '... -> ... 1' )) # multiplier for straight-through
82+
83+ return ind , one_hot , st_mult
8284
8385def laplace_smoothing (x , n_categories , eps = 1e-5 , dim = - 1 ):
8486 denom = x .sum (dim = dim , keepdim = True )
@@ -333,11 +335,16 @@ def forward(self, x):
333335
334336 dist = - torch .cdist (flatten , embed , p = 2 )
335337
336- embed_ind , embed_onehot = self .gumbel_sample (dist , dim = - 1 )
338+ embed_ind , embed_onehot , straight_through_mult = self .gumbel_sample (dist , dim = - 1 )
339+
337340 embed_ind = unpack_one (embed_ind , ps , 'h *' )
338341
339342 quantize = batched_embedding (embed_ind , self .embed )
340343
344+ if exists (straight_through_mult ):
345+ mult = unpack_one (straight_through_mult , ps , 'h * d' )
346+ quantize = quantize * mult
347+
341348 if self .training :
342349 cluster_size = embed_onehot .sum (dim = 1 )
343350
@@ -476,11 +483,15 @@ def forward(self, x):
476483 embed = self .embed if not self .learnable_codebook else self .embed .detach ()
477484
478485 dist = einsum ('h n d, h c d -> h n c' , flatten , embed )
479- embed_ind , embed_onehot = self .gumbel_sample (dist , dim = - 1 )
486+ embed_ind , embed_onehot , straight_through_mult = self .gumbel_sample (dist , dim = - 1 )
480487 embed_ind = unpack_one (embed_ind , ps , 'h *' )
481488
482489 quantize = batched_embedding (embed_ind , self .embed )
483490
491+ if exists (straight_through_mult ):
492+ mult = unpack_one (straight_through_mult , ps , 'h * d' )
493+ quantize = quantize * mult
494+
484495 if self .training :
485496 bins = embed_onehot .sum (dim = 1 )
486497 self .all_reduce_fn (bins )
0 commit comments