@@ -358,6 +358,7 @@ def __init__(
358358
359359 self .register_buffer ('initted' , torch .Tensor ([not kmeans_init ]))
360360 self .register_buffer ('cluster_size' , torch .zeros (num_codebooks , codebook_size ))
361+ self .register_buffer ('embed_avg' , embed .clone ())
361362
362363 self .learnable_codebook = learnable_codebook
363364 if learnable_codebook :
@@ -380,6 +381,7 @@ def init_embed_(self, data):
380381 )
381382
382383 self .embed .data .copy_ (embed )
384+ self .embed_avg .data .copy_ (embed .clone ())
383385 self .cluster_size .data .copy_ (cluster_size )
384386 self .initted .data .copy_ (torch .Tensor ([True ]))
385387
@@ -394,7 +396,8 @@ def replace(self, batch_samples, batch_mask):
394396 sampled = rearrange (sampled , '1 ... -> ...' )
395397
396398 self .embed .data [ind ][mask ] = sampled
397- self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
399+ self .embed_avg .data [ind ][mask ] = sampled * self .reset_cluster_size
400+ self .cluster_size .data [ind ][mask ] = self .reset_cluster_size
398401
399402 def expire_codes_ (self , batch_samples ):
400403 if self .threshold_ema_dead_code == 0 :
@@ -445,8 +448,11 @@ def forward(self, x):
445448
446449 embed_sum = einsum ('h n d, h n c -> h c d' , flatten , embed_onehot )
447450 self .all_reduce_fn (embed_sum )
451+ self .embed_avg .data .lerp_ (embed_sum , 1 - self .decay )
448452
449- embed_normalized = embed_sum / rearrange (bins , '... -> ... 1' )
453+ cluster_size = laplace_smoothing (self .cluster_size , self .codebook_size , self .eps ) * self .cluster_size .sum ()
454+
455+ embed_normalized = self .embed_avg / rearrange (cluster_size , '... -> ... 1' )
450456 embed_normalized = l2norm (embed_normalized )
451457
452458 embed_normalized = torch .where (
@@ -456,6 +462,8 @@ def forward(self, x):
456462 )
457463
458464 self .embed .data .lerp_ (embed_normalized , 1 - self .decay )
465+ self .embed .data .copy_ (l2norm (self .embed ))
466+
459467 self .expire_codes_ (x )
460468
461469 if needs_codebook_dim :
0 commit comments