Commit 9379b23

JerryWu-code, lime-j, and github-actions[bot] authored
Fix TPU (torch_xla) compatibility error with the tensor repeat function on empty dims (#12770)
* Refactor image padding logic to prevent zero tensors in transformer_z_image.py
* Apply style fixes
* Add more support to fix the repeat bug on TPU devices
* Fix Dynamo compile error for multiple if-branches

---------

Co-authored-by: Mingjia Li <mingjiali@tju.edu.cn>
Co-authored-by: Mingjia Li <mail@mingjia.li>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4f136f8 commit 9379b23
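
For context, a minimal sketch (hypothetical names, not taken from the repository) of the guard pattern this commit adopts: when no padding is required, the code falls back to the original tensor instead of concatenating a repeat over a zero-length dimension, which torch_xla on TPU can reject.

```python
import torch

def pad_feature(feat: torch.Tensor, padding_len: int) -> torch.Tensor:
    # feat: (seq_len, dim); pad by repeating the last row padding_len times.
    padded = torch.cat([feat, feat[-1:].repeat(padding_len, 1)], dim=0)
    # Guard mirroring the commit: return the original tensor when there is
    # nothing to pad, so no repeat over an empty dimension reaches the graph.
    return padded if padding_len > 0 else feat

feat = torch.randn(5, 16)
print(pad_feature(feat, 3).shape)  # torch.Size([8, 16])
print(pad_feature(feat, 0).shape)  # torch.Size([5, 16])
```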

File tree

1 file changed: +44 −30 lines


src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 44 additions & 30 deletions
```diff
@@ -482,21 +482,23 @@ def patchify_and_embed(
             ).flatten(0, 2)
             all_cap_pos_ids.append(cap_padded_pos_ids)
             # pad mask
+            cap_pad_mask = torch.cat(
+                [
+                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
             all_cap_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+                cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
             )
+
             # padded feature
             cap_padded_feat = torch.cat(
                 [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                 dim=0,
             )
-            all_cap_feats_out.append(cap_padded_feat)
+            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)
 
             ### Process Image
             C, F, H, W = image.size()
```
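A small self-contained check (not from the repository) of the pad-mask logic in the hunk above: False marks real caption tokens, True marks padded positions, and the same `padding_len > 0` guard keeps the no-padding case on an unpadded tensor.

```python
import torch

def make_pad_mask(ori_len: int, padding_len: int, device: str = "cpu") -> torch.Tensor:
    mask = torch.cat(
        [
            torch.zeros((ori_len,), dtype=torch.bool, device=device),
            torch.ones((padding_len,), dtype=torch.bool, device=device),
        ],
        dim=0,
    )
    # Same guard as the committed code: a plain all-False mask when nothing is padded.
    return mask if padding_len > 0 else torch.zeros((ori_len,), dtype=torch.bool, device=device)

print(make_pad_mask(4, 2))  # tensor([False, False, False, False,  True,  True])
print(make_pad_mask(4, 0))  # tensor([False, False, False, False])
```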
```diff
@@ -515,30 +517,35 @@ def patchify_and_embed(
                 start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                 device=device,
             ).flatten(0, 2)
-            image_padding_pos_ids = (
-                self.create_coordinate_grid(
-                    size=(1, 1, 1),
-                    start=(0, 0, 0),
-                    device=device,
-                )
-                .flatten(0, 2)
-                .repeat(image_padding_len, 1)
+            image_padded_pos_ids = torch.cat(
+                [
+                    image_ori_pos_ids,
+                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
+                    .flatten(0, 2)
+                    .repeat(image_padding_len, 1),
+                ],
+                dim=0,
             )
-            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
-            all_image_pos_ids.append(image_padded_pos_ids)
+            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
             # pad mask
+            image_pad_mask = torch.cat(
+                [
+                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
             all_image_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+                image_pad_mask
+                if image_padding_len > 0
+                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
             )
             # padded feature
-            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
-            all_image_out.append(image_padded_feat)
+            image_padded_feat = torch.cat(
+                [image, image[-1:].repeat(image_padding_len, 1)],
+                dim=0,
+            )
+            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
 
         return (
             all_image_out,
```
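To illustrate why the image branch needs the same fallback, here is a rough stand-in for `create_coordinate_grid` (the real method lives in `transformer_z_image.py`; this version is an assumption built on `torch.meshgrid`). Repeating the single padding coordinate `image_padding_len` times is exactly the operation that becomes a zero-size repeat when `image_padding_len == 0`.

```python
import torch

def coordinate_grid(size, start, device="cpu"):
    # Hypothetical stand-in: one id triple per (frame, row, col) position.
    axes = [torch.arange(s, s + n, device=device) for n, s in zip(size, start)]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)  # (F, H, W, 3)

image_padding_len = 4
pad_ids = coordinate_grid((1, 1, 1), (0, 0, 0)).flatten(0, 2).repeat(image_padding_len, 1)
print(pad_ids.shape)  # torch.Size([4, 3])
# With image_padding_len == 0 this repeat is empty, so the commit appends
# image_ori_pos_ids directly instead of concatenating an empty block.
```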
```diff
@@ -588,10 +595,13 @@ def forward(
         adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
 
         x = pad_sequence(x, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+        # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
+        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1
```
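The Dynamo-related change in this hunk can be reproduced with toy shapes (nothing below comes from the repository): both padded tensors share the same maximum length at runtime, but under torch.compile they carry separate symbolic sizes, and the explicit slice ties them together the way `x_freqs_cis = x_freqs_cis[:, : x.shape[1]]` does in the commit.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

seqlens = [5, 3]  # toy per-item sequence lengths
x = pad_sequence([torch.randn(n, 8) for n in seqlens], batch_first=True, padding_value=0.0)
freqs = pad_sequence([torch.randn(n, 4) for n in seqlens], batch_first=True, padding_value=0.0)

# Eager mode already agrees on the padded length; the slice makes that equality
# explicit so symbolic shape inference does not see two independent sizes.
freqs = freqs[:, : x.shape[1]]
assert freqs.shape[1] == x.shape[1]
print(x.shape, freqs.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 4])
```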
```diff
@@ -605,17 +615,21 @@ def forward(
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
-        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)
 
         cap_feats = torch.cat(cap_feats, dim=0)
         cap_feats = self.cap_embedder(cap_feats)
         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cis = list(
+            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
+        )
 
         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+        # Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
+        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
+
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1
```
