@@ -610,6 +610,34 @@ def patchify_and_embed(
610610 all_cap_pad_mask ,
611611 )
612612
def patchify(
    self,
    all_image: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
) -> List[torch.Tensor]:
    """Split each (C, F, H, W) tensor into flattened patch tokens, padded in count.

    Every input tensor is partitioned into non-overlapping patches of shape
    (f_patch_size, patch_size, patch_size), flattened to a 2-D matrix of
    shape (num_tokens, f_patch_size * patch_size**2 * C) with token order
    "c f pf h ph w pw -> (f h w) (pf ph pw c)", then padded along the token
    dimension — by repeating the last token row — so the token count is a
    multiple of the module-level constant SEQ_MULTI_OF.

    Args:
        all_image: list of tensors, each of shape (C, F, H, W).
            NOTE(review): assumes F, H, W are exactly divisible by
            f_patch_size / patch_size — confirm upstream guarantees this.
        patch_size: spatial patch edge, applied to both H and W.
        f_patch_size: temporal (frame-axis) patch size.

    Returns:
        List of 2-D tensors, one per input, each of shape
        (padded_num_tokens, f_patch_size * patch_size**2 * C).
    """
    pH = pW = patch_size
    pF = f_patch_size
    all_image_out: List[torch.Tensor] = []

    # NOTE: the original loop used `enumerate` but never read the index;
    # iterate directly instead.
    for image in all_image:
        C, F, H, W = image.size()
        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

        # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(
            F_tokens * H_tokens * W_tokens, pF * pH * pW * C
        )

        # Pad the token count up to the next multiple of SEQ_MULTI_OF by
        # repeating the final token row; (-n) % m gives the shortfall.
        padding_len = (-len(image)) % SEQ_MULTI_OF
        padded = torch.cat([image, image[-1:].repeat(padding_len, 1)], dim=0)
        all_image_out.append(padded)

    return all_image_out
640+
613641 def forward (
614642 self ,
615643 x : List [torch .Tensor ],
@@ -719,6 +747,7 @@ def forward(
719747
720748 controlnet_block_samples = None
721749 if control_context is not None :
750+ control_context = self .patchify (control_context , patch_size , f_patch_size )
722751 control_context = torch .cat (control_context , dim = 0 )
723752 control_context = self .control_all_x_embedder [f"{ patch_size } -{ f_patch_size } " ](control_context )
724753
0 commit comments