@@ -430,10 +430,7 @@ def update_affine(self, data, embed, mask = None):
430430 self .update_with_decay ('batch_variance' , batch_variance , self .affine_param_batch_decay )
431431
432432 def replace (self , batch_samples , batch_mask ):
433- for ind , (samples , mask ) in enumerate (zip (batch_samples .unbind (dim = 0 ), batch_mask .unbind (dim = 0 ))):
434- if not torch .any (mask ):
435- continue
436-
433+ for ind , (samples , mask ) in enumerate (zip (batch_samples , batch_mask )):
437434 sampled = self .replace_sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
438435 sampled = rearrange (sampled , '1 ... -> ...' )
439436
@@ -619,10 +616,7 @@ def init_embed_(self, data, mask = None):
619616 def replace (self , batch_samples , batch_mask ):
620617 batch_samples = l2norm (batch_samples )
621618
622- for ind , (samples , mask ) in enumerate (zip (batch_samples .unbind (dim = 0 ), batch_mask .unbind (dim = 0 ))):
623- if not torch .any (mask ):
624- continue
625-
619+ for ind , (samples , mask ) in enumerate (zip (batch_samples , batch_mask )):
626620 sampled = self .replace_sample_fn (rearrange (samples , '... -> 1 ...' ), mask .sum ().item ())
627621 sampled = rearrange (sampled , '1 ... -> ...' )
628622
0 commit comments