@@ -43,8 +43,9 @@ def gumbel_sample(t, temperature = 1., dim = -1):
4343
4444 return ((t / temperature ) + gumbel_noise (t )).argmax (dim = dim )
4545
def laplace_smoothing(x, n_categories, eps = 1e-5, dim = -1):
    """Additively (Laplace) smooth the counts in `x` along `dim`.

    Every entry gets `eps` added, and the result is renormalized by the
    smoothed total so zero-count categories receive a small nonzero mass.
    `x` is assumed to be a tensor of (nonnegative) counts — TODO confirm
    against callers.
    """
    total = x.sum(dim = dim, keepdim = True)
    smoothed = x + eps
    return smoothed / (total + n_categories * eps)
4849
4950def sample_vectors (samples , num ):
5051 num_samples , device = samples .shape [0 ], samples .device
@@ -305,7 +306,7 @@ def forward(self, x):
305306 self .all_reduce_fn (embed_sum .contiguous ())
306307 self .embed_avg .data .lerp_ (embed_sum , 1 - self .decay )
307308
308- cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum ()
309+ cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum (dim = - 1 , keepdim = True )
309310
310311 embed_normalized = self .embed_avg / rearrange (cluster_size , '... -> ... 1' )
311312 self .embed .data .copy_ (embed_normalized )
@@ -450,7 +451,7 @@ def forward(self, x):
450451 self .all_reduce_fn (embed_sum )
451452 self .embed_avg .data .lerp_ (embed_sum , 1 - self .decay )
452453
453- cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum ()
454+ cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum (dim = - 1 , keepdim = True )
454455
455456 embed_normalized = self .embed_avg / rearrange (cluster_size , '... -> ... 1' )
456457 embed_normalized = l2norm (embed_normalized )
0 commit comments