@@ -310,10 +310,14 @@ def __init__(
         self.register_buffer('codebook_variance', torch.empty(num_codebooks, 1, dim))
 
     @torch.jit.ignore
-    def init_embed_(self, data):
+    def init_embed_(self, data, mask = None):
         if self.initted:
             return
 
+        if exists(mask):
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
+
         embed, cluster_size = kmeans(
             data,
             self.codebook_size,
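`init_embed_` can now run k-means on only the unmasked tokens. The gather works because the mask it receives from `forward` is already repeated across codebooks, so each codebook row keeps the same number of valid tokens and the flat selection folds back cleanly. A minimal sketch of the shape mechanics (all sizes hypothetical):

```python
import torch
from einops import rearrange

c, n, d = 2, 6, 4                       # codebooks, flattened tokens, dim (made up)
data = torch.randn(c, n, d)

# the same (b n) mask repeated across codebooks -> equal True-count per row
mask = torch.tensor([True, True, True, True, False, False]).expand(c, n)

# data[mask] flattens kept tokens to ((c n') d); rearrange restores the
# per-codebook axis so kmeans sees only valid tokens
kept = rearrange(data[mask], '(c n) d -> c n d', c = c)
assert kept.shape == (c, 4, d)
```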
@@ -363,9 +367,8 @@ def update_affine(self, data, embed, mask = None):
         data = rearrange(data, 'h ... d -> h (...) d')
 
         if exists(mask):
-            h = data.shape[0]
-            mask = repeat(mask, 'b n -> h (b n)', h = h)
-            data = rearrange(data[mask], '(h n) d -> h n d', h = h)
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
 
         # calculate batch mean and variance
 
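With the repeat hoisted into `forward` (next hunk), `update_affine` receives a pre-expanded mask and only needs the gather; the `h` → `c` rename just reflects that the leading axis counts codebooks. The batch statistics computed below this hunk then see valid tokens only. A sketch under those assumptions (variable names illustrative, not the module's buffers):

```python
import torch

# hypothetical masked-gather result: (codebooks, valid tokens, dim)
kept = torch.randn(2, 4, 8)

# padding no longer skews the affine statistics
batch_mean = kept.mean(dim = 1, keepdim = True)                       # (c, 1, d)
batch_variance = kept.var(dim = 1, unbiased = False, keepdim = True)  # (c, 1, d)
```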
@@ -440,7 +443,10 @@ def forward(
         dtype = x.dtype
         flatten, ps = pack_one(x, 'h * d')
 
-        self.init_embed_(flatten)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        self.init_embed_(flatten, mask = mask)
 
         if self.affine_param:
             self.update_affine(flatten, self.embed, mask = mask)
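The `'b n -> c (b h n)'` repeat is the subtle piece: `flatten` has one row per codebook and `b * h * n` tokens, so the head count is recovered by dividing the flattened token count by the mask's `b * n`. Note the `(b h n)` grouping assumes the upstream flattening laid tokens out batch-major, then head, then position; the sketch below (hypothetical sizes) only checks the shape arithmetic:

```python
import torch
from einops import repeat

b, n, h, c = 2, 3, 2, 1                  # batch, seq, heads, codebooks (made up)
mask = torch.tensor([[True, True, False],
                     [True, False, False]])          # (b, n)

flat_tokens = b * h * n                  # plays the role of flatten.shape[-2]
h_inferred = flat_tokens // (mask.shape[0] * mask.shape[1])
assert h_inferred == h

expanded = repeat(mask, 'b n -> c (b h n)', c = c, h = h_inferred)
assert expanded.shape == (c, b * h * n)  # one flag per flattened token
```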
@@ -470,7 +476,6 @@ def forward(
                 flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean
 
             if exists(mask):
-                mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
                 embed_onehot[~mask] = 0.
 
             cluster_size = embed_onehot.sum(dim = 1)
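Zeroing the one-hot rows for masked positions keeps padding out of the EMA statistics: a zeroed row adds nothing to `cluster_size` or to the weighted sums that follow. A small concrete check:

```python
import torch

# 1 codebook, 4 tokens, 3 codes; the last token is padding (all made up)
embed_onehot = torch.tensor([[[1., 0., 0.],
                              [0., 1., 0.],
                              [0., 1., 0.],
                              [0., 0., 1.]]])        # (c, n, codes)
mask = torch.tensor([[True, True, True, False]])     # (c, n)

embed_onehot[~mask] = 0.                 # padded rows now contribute nothing
cluster_size = embed_onehot.sum(dim = 1)
assert cluster_size.tolist() == [[1., 2., 0.]]       # code 2 lost its padded count
```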
@@ -552,10 +557,14 @@ def __init__(
         self.register_buffer('embed', embed)
 
     @torch.jit.ignore
-    def init_embed_(self, data):
+    def init_embed_(self, data, mask = None):
         if self.initted:
             return
 
+        if exists(mask):
+            c = data.shape[0]
+            data = rearrange(data[mask], '(c n) d -> c n d', c = c)
+
         embed, cluster_size = kmeans(
             data,
             self.codebook_size,
@@ -615,7 +624,10 @@ def forward(
 
         flatten, ps = pack_one(x, 'h * d')
 
-        self.init_embed_(flatten)
+        if exists(mask):
+            mask = repeat(mask, 'b n -> c (b h n)', c = flatten.shape[0], h = flatten.shape[-2] // (mask.shape[0] * mask.shape[1]))
+
+        self.init_embed_(flatten, mask = mask)
 
         embed = self.embed if self.learnable_codebook else self.embed.detach()
 
@@ -632,7 +644,6 @@ def forward(
 
         if self.training and self.ema_update:
             if exists(mask):
-                mask = repeat(mask, 'b n -> h (b n)', h = flatten.shape[0])
                 embed_onehot[~mask] = 0.
 
             bins = embed_onehot.sum(dim = 1)
@@ -856,7 +867,20 @@ def forward(
         # one step in-place update
 
         if should_inplace_optimize and self.training:
-            F.mse_loss(quantize, x.detach()).backward()
+
+            if exists(mask):
+                loss = F.mse_loss(quantize, x.detach(), reduction = 'none')
+
+                loss_mask = mask
+                if is_multiheaded:
+                    loss_mask = repeat(mask, 'b n -> c (b h) n', c = loss.shape[0], h = loss.shape[1] // mask.shape[0])
+
+                loss = loss[loss_mask].mean()
+
+            else:
+                loss = F.mse_loss(quantize, x.detach())
+
+            loss.backward()
             self.in_place_codebook_optimizer.step()
             self.in_place_codebook_optimizer.zero_grad()
 
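The in-place optimizer path previously averaged the MSE over every position, padding included; now masked positions are excluded before the mean. A self-contained sketch of the masked branch (shapes assumed, single-headed case):

```python
import torch
import torch.nn.functional as F

quantize = torch.randn(2, 4, 8, requires_grad = True)  # stands in for the codebook output
x = torch.randn(2, 4, 8)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])      # (b, n)

loss = F.mse_loss(quantize, x.detach(), reduction = 'none')  # (b, n, d), per element
loss = loss[mask].mean()                 # mean over valid positions only
loss.backward()                          # gradient is zero at masked tokens
```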
@@ -924,21 +948,23 @@ def calculate_ce_loss(codes):
             if self.commitment_weight > 0:
                 if self.commitment_use_cross_entropy_loss:
                     if exists(mask):
+                        ce_loss_mask = mask
                         if is_multiheaded:
-                            mask_with_heads = repeat(mask, 'b n -> b n h', h = heads)
+                            ce_loss_mask = repeat(ce_loss_mask, 'b n -> b n h', h = heads)
 
-                        embed_ind.masked_fill_(~mask_with_heads, -1)
+                        embed_ind.masked_fill_(~ce_loss_mask, -1)
 
                     commit_loss = calculate_ce_loss(embed_ind)
                 else:
                     if exists(mask):
                         # with variable lengthed sequences
                         commit_loss = F.mse_loss(commit_quantize, x, reduction = 'none')
 
+                        loss_mask = mask
                         if is_multiheaded:
-                            mask_with_heads = repeat(mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
+                            loss_mask = repeat(loss_mask, 'b n -> c (b h) n', c = commit_loss.shape[0], h = commit_loss.shape[1] // mask.shape[0])
 
-                        commit_loss = commit_loss[mask_with_heads].mean()
+                        commit_loss = commit_loss[loss_mask].mean()
                     else:
                         commit_loss = F.mse_loss(commit_quantize, x)
 
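The renames (`mask_with_heads` → `ce_loss_mask` / `loss_mask`) keep the two branches from sharing a confusingly named temporary, but the `'b n -> c (b h) n'` repeat is worth spelling out: the per-element loss carries codebook and head axes folded in, and the boolean mask must match its leading dimensions to index it. A shape check under assumed sizes:

```python
import torch
from einops import repeat

b, n, h, c, d = 2, 3, 2, 1, 4            # all hypothetical
commit_loss = torch.randn(c, b * h, n, d) ** 2       # stands in for the per-element MSE
mask = torch.tensor([[True, True, False],
                     [True, False, False]])          # (b, n)

loss_mask = repeat(mask, 'b n -> c (b h) n',
                   c = commit_loss.shape[0],
                   h = commit_loss.shape[1] // mask.shape[0])
assert loss_mask.shape == (c, b * h, n)

# boolean mask over the leading dims selects valid (codebook, head, token) rows
masked_mean = commit_loss[loss_mask].mean()
```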