@@ -345,17 +345,30 @@ def update_with_decay(self, buffer_name, new_value, decay):
         self.register_buffer(buffer_name, value)

     @torch.jit.ignore
-    def update_affine(self, data, embed):
+    def update_affine(self, data, embed, mask = None):
         assert self.affine_param

         var_fn = partial(torch.var, unbiased = False)

-        data, embed = map(lambda t: rearrange(t, 'h ... d -> h (...) d'), (data, embed))
+        # calculate codebook mean and variance
+
+        embed = rearrange(embed, 'h ... d -> h (...) d')

         if self.training:
             self.update_with_decay('codebook_mean', reduce(embed, 'h n d -> h 1 d', 'mean'), self.affine_param_codebook_decay)
             self.update_with_decay('codebook_variance', reduce(embed, 'h n d -> h 1 d', var_fn), self.affine_param_codebook_decay)

+        # prepare batch data, which depends on whether it has masking
+
+        data = rearrange(data, 'h ... d -> h (...) d')
+
+        if exists(mask):
+            h = data.shape[0]
+            mask = repeat(mask, 'b n -> h (b n)', h = h)
+            data = rearrange(data[mask], '(h n) d -> h n d', h = h)
+
+        # calculate batch mean and variance
+
         if not self.sync_affine_param:
             self.update_with_decay('batch_mean', reduce(data, 'h n d -> h 1 d', 'mean'), self.affine_param_batch_decay)
             self.update_with_decay('batch_variance', reduce(data, 'h n d -> h 1 d', var_fn), self.affine_param_batch_decay)
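The masking pattern added above leans on a useful property: repeating the identical (b, n) mask across heads means boolean indexing drops the same number of positions from every head, so the flat result can be reshaped back to one row of valid vectors per head before the running batch statistics are updated. A standalone sketch of just that pattern (shapes and names here are illustrative, not from the commit):

import torch
from einops import rearrange, repeat

h, b, n, d = 2, 3, 5, 4                  # heads, batch, sequence, dim
data = torch.randn(h, b * n, d)          # already flattened to 'h (b n) d'

lengths = torch.tensor([5, 2, 4])
mask = torch.arange(n)[None, :] < lengths[:, None]   # (b, n), True = valid

# broadcast the same mask to every head, flattened to line up with data
mask = repeat(mask, 'b n -> h (b n)', h = h)

# boolean indexing drops padded positions; every head keeps the same
# count, so the flat result reshapes cleanly back to one row per head
valid = rearrange(data[mask], '(h n) d -> h n d', h = h)

batch_mean = valid.mean(dim = 1, keepdim = True)                       # (h, 1, d)
batch_variance = valid.var(dim = 1, unbiased = False, keepdim = True)  # (h, 1, d)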
@@ -413,7 +426,8 @@ def expire_codes_(self, batch_samples):
     def forward(
         self,
         x,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        mask = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -429,7 +443,7 @@ def forward(
         self.init_embed_(flatten)

         if self.affine_param:
-            self.update_affine(flatten, self.embed)
+            self.update_affine(flatten, self.embed, mask = mask)

         embed = self.embed if self.learnable_codebook else self.embed.detach()

@@ -582,7 +596,8 @@ def expire_codes_(self, batch_samples):
     def forward(
         self,
         x,
-        sample_codebook_temp = None
+        sample_codebook_temp = None,
+        mask = None
     ):
         needs_codebook_dim = x.ndim < 4
         sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
@@ -783,9 +798,12 @@ def forward(
         mask = None,
         sample_codebook_temp = None
     ):
+        orig_input = x
+
         only_one = x.ndim == 2

         if only_one:
+            assert not exists(mask)
             x = rearrange(x, 'b d -> b 1 d')

         shape, device, heads, is_multiheaded, codebook_size, return_loss = x.shape, x.device, self.heads, self.heads > 1, self.codebook_size, exists(indices)
@@ -816,9 +834,16 @@ def forward(

         x = self._codebook.transform_input(x)

+        # codebook forward kwargs
+
+        codebook_forward_kwargs = dict(
+            sample_codebook_temp = sample_codebook_temp,
+            mask = mask
+        )
+
         # quantize

-        quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)
+        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

         # one step in-place update

@@ -827,8 +852,9 @@ def forward(
             self.in_place_codebook_optimizer.step()
             self.in_place_codebook_optimizer.zero_grad()

-            # Quantize again
-            quantize, embed_ind, distances = self._codebook(x, sample_codebook_temp = sample_codebook_temp)
+            # quantize again
+
+            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

         if self.training:
             # determine code to use for commitment loss
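For context on the hunk above: the module quantizes once, lets a dedicated optimizer nudge the codebook toward the current batch, then quantizes again so the returned codes reflect the updated codebook. A self-contained toy sketch of that pattern (the codebook, loss, and optimizer here are illustrative stand-ins, not the library's internals):

import torch
import torch.nn.functional as F

# toy stand-ins: a small codebook with its own optimizer
codebook = torch.nn.Parameter(torch.randn(8, 4))
opt = torch.optim.SGD([codebook], lr = 1e-1)

def quantize(x):
    # nearest-neighbor lookup into the codebook
    dist = torch.cdist(x.flatten(0, 1), codebook)
    idx = dist.argmin(dim = -1)
    return codebook[idx].view_as(x), idx

x = torch.randn(2, 5, 4)

quantized, idx = quantize(x)

# one in-place step: pull the selected codes toward the inputs
loss = F.mse_loss(quantized, x.detach())
loss.backward()
opt.step()
opt.zero_grad()

# quantize again, so downstream consumers see the updated codebook
quantized, idx = quantize(x)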
@@ -891,9 +917,11 @@ def calculate_ce_loss(codes):
                 if self.commitment_use_cross_entropy_loss:
                     if exists(mask):
                         if is_multiheaded:
-                            mask = repeat(mask, 'b n -> b n h', h = heads)
+                            mask_with_heads = repeat(mask, 'b n -> b n h', h = heads)
+                        else:
+                            mask_with_heads = mask

-                        embed_ind.masked_fill_(~mask, -1)
+                        embed_ind.masked_fill_(~mask_with_heads, -1)

                     commit_loss = calculate_ce_loss(embed_ind)
                 else:
@@ -902,9 +930,11 @@ def calculate_ce_loss(codes):
                         commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')

                         if is_multiheaded:
-                            mask = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                            mask_with_heads = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                        else:
+                            mask_with_heads = mask

-                        commit_loss = commit_loss[mask].mean()
+                        commit_loss = commit_loss[mask_with_heads].mean()
                     else:
                         commit_loss = F.mse_loss(commit_quantize, x)

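The same rename applies to the plain MSE commitment loss: an unreduced elementwise loss is computed, then averaged only over positions the mask marks as valid. A minimal single-head sketch (illustrative shapes):

import torch
import torch.nn.functional as F

b, n, d = 2, 5, 4
x = torch.randn(b, n, d)                # encoder output
quantize = torch.randn(b, n, d)         # stand-in for the quantized output

lengths = torch.tensor([5, 3])
mask = torch.arange(n)[None, :] < lengths[:, None]   # (b, n), True = valid

# unreduced loss keeps per-element terms, so padding can be dropped
commit_loss = F.mse_loss(quantize, x, reduction = 'none')   # (b, n, d)
commit_loss = commit_loss[mask].mean()                      # mean over valid only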
@@ -952,4 +982,13 @@ def calculate_ce_loss(codes):
         if only_one:
             quantize = rearrange(quantize, 'b 1 d -> b d')

+        # if masking, only return quantized for where mask has True
+
+        if exists(mask):
+            quantize = torch.where(
+                rearrange(mask, '... -> ... 1'),
+                quantize,
+                orig_input
+            )
+
         return quantize, embed_ind, loss
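End to end, the new behavior: masked-out positions are excluded from the affine batch statistics and the commitment loss, and the final torch.where returns them untouched from the original input. A hedged usage sketch against this version of vector-quantize-pytorch (the sizes and sequence length here are illustrative):

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)

x = torch.randn(1, 1024, 256)

# True marks real tokens, False marks right-padding past position 768
mask = torch.arange(1024)[None, :] < 768

quantized, indices, commit_loss = vq(x, mask = mask)

# padded positions pass straight through, per the torch.where above
assert torch.allclose(quantized[~mask], x[~mask])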