Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 15 additions & 40 deletions src/diffusers/models/transformers/transformer_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ...models.normalization import RMSNorm
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..modeling_outputs import Transformer2DModelOutput


ADALN_EMBED_DIM = 256
Expand All @@ -39,17 +40,9 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
if mid_size is None:
mid_size = out_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size,
mid_size,
bias=True,
),
nn.Linear(frequency_embedding_size, mid_size, bias=True),
nn.SiLU(),
nn.Linear(
mid_size,
out_size,
bias=True,
),
nn.Linear(mid_size, out_size, bias=True),
)

self.frequency_embedding_size = frequency_embedding_size
Expand Down Expand Up @@ -211,9 +204,7 @@ def __init__(

self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
)
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

def forward(
self,
Expand All @@ -230,33 +221,19 @@ def forward(

# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa,
attention_mask=attn_mask,
freqs_cis=freqs_cis,
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)

# FFN block
x = x + gate_mlp * self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x) * scale_mlp,
)
)
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(
self.attention_norm1(x),
attention_mask=attn_mask,
freqs_cis=freqs_cis,
)
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)

# FFN block
x = x + self.ffn_norm2(
self.feed_forward(
self.ffn_norm1(x),
)
)
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

return x

Expand Down Expand Up @@ -404,10 +381,7 @@ def __init__(
]
)
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
self.cap_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(cap_feat_dim, dim, bias=True),
)
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
Expand Down Expand Up @@ -492,10 +466,7 @@ def patchify_and_embed(
)
)
# padded feature
cap_padded_feat = torch.cat(
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
dim=0,
)
cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
all_cap_feats_out.append(cap_padded_feat)

### Process Image
Expand Down Expand Up @@ -557,6 +528,7 @@ def forward(
cap_feats: List[torch.Tensor],
patch_size=2,
f_patch_size=1,
return_dict: bool = True,
):
assert patch_size in self.all_patch_size
assert f_patch_size in self.all_f_patch_size
Expand Down Expand Up @@ -658,4 +630,7 @@ def forward(
unified = list(unified.unbind(dim=0))
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

return x, {}
if not return_dict:
return (x,)

return Transformer2DModelOutput(sample=x)
Comment on lines -661 to +636
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a very safe change?

4 changes: 1 addition & 3 deletions src/diffusers/pipelines/z_image/pipeline_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,7 @@ def __call__(
latent_model_input_list = list(latent_model_input.unbind(dim=0))

model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
)[0]

if apply_cfg:
Expand Down
Loading
Loading