@@ -1,3 +1,5 @@
+from functools import partial
+
 import torch
 from torch import nn, einsum
 import torch.nn.functional as F
@@ -40,11 +42,41 @@ def gumbel_noise(t):
     noise = torch.zeros_like(t).uniform_(0, 1)
     return -log(-log(noise))
 
-def gumbel_sample(t, temperature = 1., dim = -1):
-    if temperature == 0:
-        return t.argmax(dim = dim)
+def gumbel_sample(
+    logits,
+    temperature = 1.,
+    stochastic = False,
+    straight_through = False,
+    reinmax = False,
+    dim = -1
+):
+    dtype, size = logits.dtype, logits.shape[dim]
+
+    if stochastic:
+        logits = logits + gumbel_noise(logits)
+
+    ind = logits.argmax(dim = dim)
+    one_hot = F.one_hot(ind, size).type(dtype)
 
-    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+    assert not (reinmax and not straight_through), 'reinmax can only be turned on if using straight through gumbel softmax'
+
+    if not straight_through:
+        return ind, one_hot
+
+    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
+    # algorithm 2
+
+    if reinmax:
+        π0 = logits.softmax(dim = dim)
+        π1 = (one_hot + (logits / temperature).softmax(dim = dim)) / 2
+        π1 = ((π1.log() - logits).detach() + logits).softmax(dim = dim)
+        π2 = 2 * π1 - 0.5 * π0
+        one_hot = π2 - π2.detach() + one_hot
+    else:
+        π1 = (logits / temperature).softmax(dim = dim)
+        one_hot = one_hot + π1 - π1.detach()
+
+    return ind, one_hot
 
 def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
     denom = x.sum(dim = dim, keepdim = True)
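A minimal sketch (not part of the commit) of how the reworked `gumbel_sample` is meant to be called: it now returns both the argmax indices and a matching one-hot, and with `straight_through = True` the one-hot is hard in the forward pass while gradients flow through the tempered softmax. The import path and tensor shape below are assumptions for illustration.

```python
import torch
from vector_quantize_pytorch.vector_quantize_pytorch import gumbel_sample  # module path assumed

logits = torch.randn(1, 8, 512, requires_grad = True)  # (heads, n, codebook_size) - assumed shape

# deterministic argmax, no gradient path through the returned one-hot
ind, one_hot = gumbel_sample(logits, stochastic = False, straight_through = False)

# stochastic straight-through: hard codes forward, soft gumbel-softmax gradients backward
ind, one_hot = gumbel_sample(logits, temperature = 1., stochastic = True, straight_through = True)
(one_hot * torch.randn_like(one_hot)).sum().backward()
assert logits.grad is not None  # gradients reached the logits via the soft distribution
```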
@@ -200,7 +232,9 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0
+        sample_codebook_temp = 0,
+        straight_through = False,
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = identity
@@ -216,7 +250,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample
 
         assert not (use_ddp and num_codebooks > 1 and kmeans_init), 'kmeans init is not compatible with multiple codebooks in distributed environment for now'
 
@@ -295,8 +331,7 @@ def forward(self, x):
 
         dist = -torch.cdist(flatten, embed, p = 2)
 
-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
@@ -339,7 +374,7 @@ def __init__(
         reset_cluster_size = None,
         use_ddp = False,
         learnable_codebook = False,
-        sample_codebook_temp = 0.
+        gumbel_sample = gumbel_sample
     ):
         super().__init__()
         self.transform_input = l2norm
@@ -358,7 +393,9 @@ def __init__(
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
         self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
-        self.sample_codebook_temp = sample_codebook_temp
+
+        assert callable(gumbel_sample)
+        self.gumbel_sample = gumbel_sample
 
         self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
         self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
@@ -437,8 +474,7 @@ def forward(self, x):
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
 
         dist = einsum('h n d, h c d -> h n c', flatten, embed)
-        embed_ind = gumbel_sample(dist, dim = -1, temperature = self.sample_codebook_temp)
-        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+        embed_ind, embed_onehot = self.gumbel_sample(dist, dim = -1)
         embed_ind = unpack_one(embed_ind, ps, 'h *')
 
         quantize = batched_embedding(embed_ind, self.embed)
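Both codebook forwards now delegate to the injected `self.gumbel_sample`, which only needs to be a callable that takes the distance logits plus `dim` and returns `(indices, one_hot)`. A hypothetical drop-in replacement (not part of the commit) satisfying that contract:

```python
import torch
import torch.nn.functional as F

def plain_argmax_sample(logits, dim = -1, **kwargs):
    # same return contract as gumbel_sample: hard indices plus a matching one-hot
    ind = logits.argmax(dim = dim)
    one_hot = F.one_hot(ind, logits.shape[dim]).type(logits.dtype)
    return ind, one_hot

# could then be injected at construction time, e.g.
# EuclideanCodebook(..., gumbel_sample = plain_argmax_sample)  # other constructor args elided
```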
@@ -491,8 +527,11 @@ def __init__(
         orthogonal_reg_weight = 0.,
         orthogonal_reg_active_codes_only = False,
         orthogonal_reg_max_codes = None,
+        stochastic_sample_codes = False,
         sample_codebook_temp = 0.,
-        sync_codebook = False
+        straight_through = False,
+        reinmax = False,  # using reinmax for improved straight-through, assuming straight through helps at all
+        sync_codebook = False,
     ):
         super().__init__()
         self.dim = dim
@@ -517,6 +556,14 @@ def __init__(
 
         codebook_class = EuclideanCodebook if not use_cosine_sim else CosineSimCodebook
 
+        gumbel_sample_fn = partial(
+            gumbel_sample,
+            stochastic = stochastic_sample_codes,
+            temperature = sample_codebook_temp,
+            reinmax = reinmax,
+            straight_through = straight_through
+        )
+
         self._codebook = codebook_class(
             dim = codebook_dim,
             num_codebooks = heads if separate_codebook_per_head else 1,
@@ -529,7 +576,7 @@ def __init__(
             threshold_ema_dead_code = threshold_ema_dead_code,
             use_ddp = sync_codebook,
             learnable_codebook = has_codebook_orthogonal_loss,
-            sample_codebook_temp = sample_codebook_temp
+            gumbel_sample = gumbel_sample_fn
         )
 
         self.codebook_size = codebook_size
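Hypothetical end-to-end usage (only the new constructor flags come from this diff; the import path and the forward return signature follow the library's existing public API): the flags are bound once via `functools.partial` and handed to whichever codebook class is selected.

```python
import torch
from vector_quantize_pytorch import VectorQuantize  # package-level export assumed

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    stochastic_sample_codes = True,   # add gumbel noise to the distances before argmax
    sample_codebook_temp = 0.1,       # temperature of the soft distribution used for gradients
    straight_through = True,          # hard codes forward, soft gradients backward
    reinmax = True                    # second-order correction, https://arxiv.org/abs/2304.08612
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # (1, 1024, 256), (1, 1024), scalar
```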