@@ -304,8 +304,10 @@ def __init__(
304304 self .register_buffer ('batch_mean' , None )
305305 self .register_buffer ('batch_variance' , None )
306306
307- self .register_buffer ('codebook_mean' , None )
308- self .register_buffer ('codebook_variance' , None )
307+ self .register_buffer ('codebook_mean_needs_init' , torch .Tensor ([True ]))
308+ self .register_buffer ('codebook_mean' , torch .empty (num_codebooks , 1 , dim ))
309+ self .register_buffer ('codebook_variance_needs_init' , torch .Tensor ([True ]))
310+ self .register_buffer ('codebook_variance' , torch .empty (num_codebooks , 1 , dim ))
309311
310312 @torch .jit .ignore
311313 def init_embed_ (self , data ):
@@ -329,8 +331,14 @@ def init_embed_(self, data):
329331 def update_with_decay (self , buffer_name , new_value , decay ):
330332 old_value = getattr (self , buffer_name )
331333
332- if not exists (old_value ):
334+ needs_init = getattr (self , buffer_name + "_needs_init" , False )
335+
336+ if needs_init :
337+ self .register_buffer (buffer_name + "_needs_init" , torch .Tensor ([False ]))
338+
339+ if not exists (old_value ) or needs_init :
333340 self .register_buffer (buffer_name , new_value .detach ())
341+
334342 return
335343
336344 value = old_value * decay + new_value .detach () * (1 - decay )
@@ -344,8 +352,9 @@ def update_affine(self, data, embed):
344352
345353 data , embed = map (lambda t : rearrange (t , 'h ... d -> h (...) d' ), (data , embed ))
346354
347- self .update_with_decay ('codebook_mean' , reduce (embed , 'h n d -> h 1 d' , 'mean' ), self .affine_param_codebook_decay )
348- self .update_with_decay ('codebook_variance' , reduce (embed , 'h n d -> h 1 d' , var_fn ), self .affine_param_codebook_decay )
355+ if self .training :
356+ self .update_with_decay ('codebook_mean' , reduce (embed , 'h n d -> h 1 d' , 'mean' ), self .affine_param_codebook_decay )
357+ self .update_with_decay ('codebook_variance' , reduce (embed , 'h n d -> h 1 d' , var_fn ), self .affine_param_codebook_decay )
349358
350359 if not self .sync_affine_param :
351360 self .update_with_decay ('batch_mean' , reduce (data , 'h n d -> h 1 d' , 'mean' ), self .affine_param_batch_decay )
0 commit comments