@@ -74,7 +74,7 @@ def gumbel_sample(
     assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
 
     if not straight_through or temperature <= 0. or not training:
-        return ind, one_hot, None
+        return ind, one_hot
 
     # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
     # algorithm 2
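For context on the code path above, here is a minimal sketch of straight-through Gumbel sampling: not the repo's exact `gumbel_sample`, just the technique it implements; `gumbel_noise` and `straight_through_gumbel` are hypothetical names written for this sketch.

```python
import torch
import torch.nn.functional as F

def gumbel_noise(t):
    # sample Gumbel(0, 1) noise with the same shape as t
    u = torch.rand_like(t).clamp(min = 1e-10)
    return -torch.log(-torch.log(u))

def straight_through_gumbel(logits, temperature = 1.):
    # hard sample in the forward pass
    ind = ((logits / temperature) + gumbel_noise(logits)).argmax(dim = -1)
    one_hot = F.one_hot(ind, logits.shape[-1]).float()

    # forward value stays the hard one-hot; gradient flows through the soft relaxation
    soft = (logits / temperature).softmax(dim = -1)
    one_hot = one_hot + soft - soft.detach()
    return ind, one_hot
```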
@@ -89,9 +89,7 @@ def gumbel_sample(
         π1 = (logits / temperature).softmax(dim = dim)
         one_hot = one_hot + π1 - π1.detach()
 
-    st_mult = one_hot.gather(-1, rearrange(ind, '... -> ... 1')) # multiplier for straight-through
-
-    return ind, one_hot, st_mult
+    return ind, one_hot
 
 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
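The `one_hot + π1 - π1.detach()` line above is the standard straight-through identity: the forward value equals the hard one-hot, while gradients follow the soft distribution. A quick toy check (hypothetical tensors, not repo code):

```python
import torch

logits = torch.randn(4, 8, requires_grad = True)
soft = logits.softmax(dim = -1)
hard = torch.zeros_like(soft).scatter(-1, soft.argmax(dim = -1, keepdim = True), 1.)

st = hard + soft - soft.detach()

assert torch.equal(st.detach(), hard)         # forward: exactly the hard one-hot
(st * torch.randn_like(st)).sum().backward()
assert logits.grad.abs().sum() > 0            # backward: gradient reaches the logits
```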
@@ -433,15 +431,15 @@ def forward(
 
         dist = -torch.cdist(flatten, embed, p = 2)
 
-        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
 
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if exists(straight_through_mult):
-            mult = unpack_one(straight_through_mult, ps, 'h * d')
-            quantize = quantize * mult
+        if self.training:
+            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
+            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+        else:
+            quantize = batched_embedding(embed_ind, embed)
 
         if self.training and self.ema_update:
 
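The new training branch produces the same values as the `batched_embedding` gather, but as a matmul against the straight-through one-hot, so gradients can reach the logits and the codebook. A toy equivalence check, dropping the batch dim `b` from the einsum equation for brevity (all names here are illustrative):

```python
import torch
import torch.nn.functional as F

h, n, c, d = 2, 5, 16, 32
embed = torch.randn(h, c, d)
ind = torch.randint(0, c, (h, n))
one_hot = F.one_hot(ind, c).float()

# gather-based lookup, as in the eval path
gathered = embed.gather(1, ind[..., None].expand(-1, -1, d))

# one-hot matmul, as in the training path: identical values, differentiable in one_hot
matmul = torch.einsum('h n c, h c d -> h n d', one_hot, embed)

assert torch.allclose(gathered, matmul)
```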
@@ -595,14 +593,14 @@ def forward(
 
         dist = einsum('h n d, h c d -> h n c', flatten, embed)
 
-        embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
-        quantize = batched_embedding(embed_ind, self.embed)
-
-        if exists(straight_through_mult):
-            mult = unpack_one(straight_through_mult, ps, 'h * d')
-            quantize = quantize * mult
+        if self.training:
+            unpacked_onehot = unpack_one(embed_onehot, ps, 'h * c')
+            quantize = einsum('h b n c, h c d -> h b n d', unpacked_onehot, embed)
+        else:
+            quantize = batched_embedding(embed_ind, embed)
 
         if self.training and self.ema_update:
             bins = embed_onehot.sum(dim = 1)
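The same change applies in this second forward; here `dist` is a plain inner product, which equals cosine similarity when `flatten` and `embed` are l2-normalized upstream, as in a cosine-sim codebook. A minimal sketch of that distance computation (toy shapes, illustrative names):

```python
import torch
import torch.nn.functional as F

h, n, c, d = 2, 5, 16, 32
flatten = F.normalize(torch.randn(h, n, d), dim = -1)
embed = F.normalize(torch.randn(h, c, d), dim = -1)

# inner product of unit vectors is cosine similarity
dist = torch.einsum('h n d, h c d -> h n c', flatten, embed)
embed_ind = dist.argmax(dim = -1)  # nearest code by cosine similarity
```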
@@ -726,7 +724,7 @@ def __init__(
         )
 
         if affine_param:
-            assert not use_cosine_sim
+            assert not use_cosine_sim, 'affine param is only compatible with euclidean codebook'
             codebook_kwargs = dict(
                 **codebook_kwargs,
                 affine_param = True,
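With the added message, a misconfiguration now fails loudly at construction time; for example, assuming the enclosing class is this repo's `VectorQuantize`:

```python
from vector_quantize_pytorch import VectorQuantize

# raises AssertionError: 'affine param is only compatible with euclidean codebook'
vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    use_cosine_sim = True,
    affine_param = True
)
```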