@@ -37,9 +37,6 @@ def gumbel_sample(t, temperature = 1., dim = -1):
3737
3838 return ((t / temperature ) + gumbel_noise (t )).argmax (dim = dim )
3939
40- def ema_inplace (moving_avg , new , decay ):
41- moving_avg .data .mul_ (decay ).add_ (new , alpha = (1 - decay ))
42-
4340def laplace_smoothing (x , n_categories , eps = 1e-5 ):
4441 return (x + eps ) / (x .sum () + n_categories * eps )
4542
@@ -289,11 +286,11 @@ def forward(self, x):
289286 cluster_size = embed_onehot .sum (dim = 1 )
290287
291288 self .all_reduce_fn (cluster_size )
292- ema_inplace ( self .cluster_size , cluster_size , self .decay )
289+ self .cluster_size . data . lerp_ ( cluster_size , 1 - self .decay )
293290
294291 embed_sum = einsum ('h n d, h n c -> h c d' , flatten , embed_onehot )
295292 self .all_reduce_fn (embed_sum .contiguous ())
296- ema_inplace ( self .embed_avg , embed_sum , self .decay )
293+ self .embed_avg . data . lerp_ ( embed_sum , 1 - self .decay )
297294
298295 cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum ()
299296
@@ -421,7 +418,7 @@ def forward(self, x):
421418 bins = embed_onehot .sum (dim = 1 )
422419 self .all_reduce_fn (bins )
423420
424- ema_inplace ( self .cluster_size , bins , self .decay )
421+ self .cluster_size . data . lerp_ ( bins , 1 - self .decay )
425422
426423 zero_mask = (bins == 0 )
427424 bins = bins .masked_fill (zero_mask , 1. )
@@ -438,7 +435,7 @@ def forward(self, x):
438435 embed_normalized
439436 )
440437
441- ema_inplace ( self .embed , embed_normalized , self .decay )
438+ self .embed . data . lerp_ ( embed_normalized , 1 - self .decay )
442439 self .expire_codes_ (x )
443440
444441 if needs_codebook_dim :
0 commit comments