 import torch.distributed as distributed
 from torch.cuda.amp import autocast
 
-from einops import rearrange, repeat, pack, unpack
+from einops import rearrange, repeat, reduce, pack, unpack
 from contextlib import contextmanager
 
 def exists(val):
@@ -239,7 +239,10 @@ def __init__(
         learnable_codebook = False,
         gumbel_sample = gumbel_sample,
         sample_codebook_temp = 1.,
-        ema_update = True
+        ema_update = True,
+        affine_param = False,
+        affine_param_batch_decay = 0.99,
+        affine_param_codebook_decay = 0.9
     ):
         super().__init__()
         self.transform_input = identity
@@ -278,6 +281,22 @@ def __init__(
         else:
             self.register_buffer('embed', embed)
 
+        # affine related params
+
+        self.affine_param = affine_param
+
+        if not affine_param:
+            return
+
+        self.affine_param_batch_decay = affine_param_batch_decay
+        self.affine_param_codebook_decay = affine_param_codebook_decay
+
+        self.register_buffer('batch_mean', None)
+        self.register_buffer('batch_variance', None)
+
+        self.register_buffer('codebook_mean', None)
+        self.register_buffer('codebook_variance', None)
+
     @torch.jit.ignore
     def init_embed_(self, data):
         if self.initted:
@@ -296,6 +315,29 @@ def init_embed_(self, data):
         self.cluster_size.data.copy_(cluster_size)
         self.initted.data.copy_(torch.Tensor([True]))
 
+    @torch.jit.ignore
+    def update_with_decay(self, buffer_name, new_value, decay):
+        old_value = getattr(self, buffer_name)
+
+        if not exists(old_value):
+            self.register_buffer(buffer_name, new_value)
+            return
+
+        value = old_value * decay + new_value * (1 - decay)
+        self.register_buffer(buffer_name, value)
+
+    @torch.jit.ignore
+    def update_affine(self, data, embed):
+        assert self.affine_param
+
+        var_fn = partial(torch.var, unbiased = False)
+
+        self.update_with_decay('batch_mean', reduce(data, '... d -> d', 'mean'), self.affine_param_batch_decay)
+        self.update_with_decay('batch_variance', reduce(data, '... d -> d', var_fn), self.affine_param_batch_decay)
+
+        self.update_with_decay('codebook_mean', reduce(embed, '... d -> d', 'mean'), self.affine_param_codebook_decay)
+        self.update_with_decay('codebook_variance', reduce(embed, '... d -> d', var_fn), self.affine_param_codebook_decay)
+
     def replace(self, batch_samples, batch_mask):
         for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim = 0), batch_mask.unbind(dim = 0))):
            if not torch.any(mask):
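For context, the two helpers introduced above (`update_with_decay` and `update_affine`) keep exponentially decayed running estimates of the per-dimension mean and variance of the incoming batch and of the codebook, lazily registering each buffer on first use. A minimal sketch of the decay rule, with illustrative values rather than anything from this commit:

```python
import torch

# running estimate as in update_with_decay: value = old * decay + new * (1 - decay)
old_mean = torch.zeros(4)   # previously tracked statistic
new_mean = torch.ones(4)    # statistic measured on the current batch
decay = 0.99                # the affine_param_batch_decay default in this commit

running_mean = old_mean * decay + new_mean * (1 - decay)
print(running_mean)         # tensor of 0.0100s, drifting slowly toward the new statistic
```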
@@ -340,8 +382,16 @@ def forward(
 
         self.init_embed_(flatten)
 
+        if self.affine_param:
+            self.update_affine(flatten, self.embed)
+
         embed = self.embed if not self.learnable_codebook else self.embed.detach()
 
+        if self.affine_param:
+            codebook_std = self.codebook_variance.clamp(min = 1e-5).sqrt()
+            batch_std = self.batch_variance.clamp(min = 1e-5).sqrt()
+            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean
+
         dist = -torch.cdist(flatten, embed, p = 2)
 
         embed_ind, embed_onehot, straight_through_mult = self.gumbel_sample(dist, dim = -1, temperature = sample_codebook_temp, training = self.training)
@@ -355,6 +405,10 @@ def forward(
             quantize = quantize * mult
 
         if self.training and self.ema_update:
+
+            if self.affine_param:
+                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
+
             cluster_size = embed_onehot.sum(dim = 1)
 
             self.all_reduce_fn(cluster_size)
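The forward-pass changes above amount to an affine re-parameterization of the codebook: codes are expressed in the running batch statistics before the distance computation, and the flattened inputs are mapped back into codebook statistics before the EMA cluster update, so nearest-neighbor search and the codebook update happen in a consistent frame. A minimal sketch of the two transforms, using single-shot statistics in place of the running buffers and illustrative shapes:

```python
import torch

d = 8
embed = torch.randn(512, d)      # codebook entries
flatten = torch.randn(1024, d)   # flattened encoder outputs

codebook_mean = embed.mean(dim = 0)
codebook_var = embed.var(dim = 0, unbiased = False)
batch_mean = flatten.mean(dim = 0)
batch_var = flatten.var(dim = 0, unbiased = False)

codebook_std = codebook_var.clamp(min = 1e-5).sqrt()
batch_std = batch_var.clamp(min = 1e-5).sqrt()

# codebook expressed in batch statistics (used for the distance computation)
embed_in_batch_frame = (embed - codebook_mean) * (batch_std / codebook_std) + batch_mean

# batch expressed in codebook statistics (used for the EMA cluster update)
flatten_in_codebook_frame = (flatten - batch_mean) * (codebook_std / batch_std) + codebook_mean
```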
@@ -565,8 +619,10 @@ def __init__(
         reinmax = False,  # using reinmax for improved straight-through, assuming straight through helps at all
         sync_codebook = False,
         ema_update = True,
-        learnable_codebook = False
-
+        learnable_codebook = False,
+        affine_param = False,
+        affine_param_batch_decay = 0.99,
+        affine_param_codebook_decay = 0.9
     ):
         super().__init__()
         self.dim = dim
@@ -598,7 +654,7 @@ def __init__(
             straight_through = straight_through
         )
 
-        self._codebook = codebook_class(
+        codebook_kwargs = dict(
             dim = codebook_dim,
             num_codebooks = heads if separate_codebook_per_head else 1,
             codebook_size = codebook_size,
@@ -615,6 +671,17 @@ def __init__(
             ema_update = ema_update
         )
 
+        if affine_param:
+            assert not use_cosine_sim
+            codebook_kwargs = dict(
+                **codebook_kwargs,
+                affine_param = True,
+                affine_param_batch_decay = affine_param_batch_decay,
+                affine_param_codebook_decay = affine_param_codebook_decay
+            )
+
+        self._codebook = codebook_class(**codebook_kwargs)
+
         self.codebook_size = codebook_size
 
         self.accept_image_fmap = accept_image_fmap
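Finally, a minimal usage sketch at the `VectorQuantize` level, assuming the package import path and the three return values used elsewhere in the library; the hyperparameters and tensor shape here are illustrative:

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,
    affine_param = True,                 # enable the affine re-parameterized codebook
    affine_param_batch_decay = 0.99,     # decay for the running batch statistics
    affine_param_codebook_decay = 0.9    # decay for the running codebook statistics
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)  # quantized: (1, 1024, 256), indices: (1, 1024)
```

Note that the commit asserts `not use_cosine_sim` when `affine_param` is enabled, so the option only applies to the default euclidean codebook.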