Commit f9540cb

patchify control_context
1 parent f63a5a8 commit f9540cb

File tree

1 file changed: +29 -0 lines changed

src/diffusers/models/transformers/transformer_z_image_control.py

Lines changed: 29 additions & 0 deletions
@@ -610,6 +610,34 @@ def patchify_and_embed(
             all_cap_pad_mask,
         )
 
+    def patchify(
+        self,
+        all_image: List[torch.Tensor],
+        patch_size: int,
+        f_patch_size: int,
+    ):
+        pH = pW = patch_size
+        pF = f_patch_size
+        all_image_out = []
+
+        for i, image in enumerate(all_image):
+            ### Process Image
+            C, F, H, W = image.size()
+            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+
+            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
+            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+
+            image_ori_len = len(image)
+            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+
+            # padded feature
+            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
+            all_image_out.append(image_padded_feat)
+
+        return all_image_out
+
     def forward(
         self,
         x: List[torch.Tensor],
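For orientation, the permute/reshape pair in the new `patchify` implements the einops-style pattern quoted in its inline comment ("c f pf h ph w pw -> (f h w) (pf ph pw c)"), then pads the token count up to a multiple of `SEQ_MULTI_OF` by repeating the last token. A minimal standalone sketch of the same rearrangement; the `SEQ_MULTI_OF` value and the example shapes here are illustrative assumptions, not values taken from the model:

```python
import torch

SEQ_MULTI_OF = 32  # assumption: the real constant is defined in the module; 32 is only illustrative


def patchify_one(image: torch.Tensor, patch_size: int, f_patch_size: int) -> torch.Tensor:
    """Rearrange (C, F, H, W) into (F_tokens*H_tokens*W_tokens, pF*pH*pW*C),
    then pad the token dimension to a multiple of SEQ_MULTI_OF by repeating the last row."""
    pH = pW = patch_size
    pF = f_patch_size
    C, F, H, W = image.size()
    F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW

    # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
    tokens = (
        image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        .permute(1, 3, 5, 2, 4, 6, 0)
        .reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
    )

    pad_len = (-len(tokens)) % SEQ_MULTI_OF
    if pad_len:
        tokens = torch.cat([tokens, tokens[-1:].repeat(pad_len, 1)], dim=0)
    return tokens


# 16-channel latent, 1 frame, 64x64 spatial, patch_size=2, f_patch_size=1 (made-up example shapes)
out = patchify_one(torch.randn(16, 1, 64, 64), patch_size=2, f_patch_size=1)
print(out.shape)  # torch.Size([1024, 64]): 32*32 tokens (already a multiple of 32), 2*2*1*16 features
```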
@@ -719,6 +747,7 @@ def forward(
 
         controlnet_block_samples = None
         if control_context is not None:
+            control_context = self.patchify(control_context, patch_size, f_patch_size)
             control_context = torch.cat(control_context, dim=0)
             control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
 
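The second hunk is why the new method exists: after `patchify`, every control tensor has feature width `pF * pH * pW * C`, so the list can be concatenated along the token dimension and pushed through the per-patch-size embedder. A hedged sketch of that flow, reusing `patchify_one` from the sketch above; the `nn.Linear` embedder, channel count, and hidden size are assumptions for illustration (the real `control_all_x_embedder` is whatever the model defines):

```python
import torch
from torch import nn

C, hidden = 16, 1024                 # assumed latent channels and hidden size
patch_size, f_patch_size = 2, 1
# assumption: the embedder projects flattened patch features to the hidden size
embedder = nn.Linear(f_patch_size * patch_size * patch_size * C, hidden)

# two control latents of different spatial sizes
control_context = [torch.randn(C, 1, 64, 64), torch.randn(C, 1, 32, 32)]
control_context = [patchify_one(img, patch_size, f_patch_size) for img in control_context]
control_context = torch.cat(control_context, dim=0)   # (total_tokens, pF*pH*pW*C) = (1280, 64)
control_tokens = embedder(control_context)             # (total_tokens, hidden)   = (1280, 1024)
print(control_tokens.shape)
```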