@@ -304,8 +304,10 @@ def __init__(
304304 self .register_buffer ('batch_mean' , None )
305305 self .register_buffer ('batch_variance' , None )
306306
307- self .register_buffer ('codebook_mean' , None )
308- self .register_buffer ('codebook_variance' , None )
307+ self .register_buffer ('codebook_mean_needs_init' , torch .Tensor ([True ]))
308+ self .register_buffer ('codebook_mean' , torch .empty (num_codebooks , 1 , dim ))
309+ self .register_buffer ('codebook_variance_needs_init' , torch .Tensor ([True ]))
310+ self .register_buffer ('codebook_variance' , torch .empty (num_codebooks , 1 , dim ))
309311
310312 @torch .jit .ignore
311313 def init_embed_ (self , data ):
@@ -329,8 +331,14 @@ def init_embed_(self, data):
329331 def update_with_decay (self , buffer_name , new_value , decay ):
330332 old_value = getattr (self , buffer_name )
331333
332- if not exists (old_value ):
334+ needs_init = getattr (self , buffer_name + "_needs_init" , False )
335+
336+ if needs_init :
337+ self .register_buffer (buffer_name + "_needs_init" , torch .Tensor ([False ]))
338+
339+ if not exists (old_value ) or needs_init :
333340 self .register_buffer (buffer_name , new_value .detach ())
341+
334342 return
335343
336344 value = old_value * decay + new_value .detach () * (1 - decay )
@@ -419,7 +427,7 @@ def forward(
419427
420428 self .init_embed_ (flatten )
421429
422- if self .affine_param :
430+ if self .affine_param and self . training :
423431 self .update_affine (flatten , self .embed )
424432
425433 embed = self .embed if self .learnable_codebook else self .embed .detach ()
0 commit comments