Commit a1f36ee

sayakpaul and dg845 authored
[Z-Image] various small changes, Z-Image transformer tests, etc. (#12741)
* start zimage model tests.
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* Revert "up" (reverts commit bca3e27).
* expand upon compilation failure reason.
* Update tests/models/transformers/test_models_transformer_z_image.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* reinitialize the padding tokens to ones to prevent NaN problems.
* updates
* up
* skipping ZImage DiT tests
* up
* up

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
1 parent d96cbac commit a1f36ee

6 files changed: +424, -80 lines

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 16 additions & 41 deletions
@@ -27,6 +27,7 @@
 from ...models.normalization import RMSNorm
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_dispatch import dispatch_attention_fn
+from ..modeling_outputs import Transformer2DModelOutput


 ADALN_EMBED_DIM = 256
@@ -39,17 +40,9 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
         if mid_size is None:
             mid_size = out_size
         self.mlp = nn.Sequential(
-            nn.Linear(
-                frequency_embedding_size,
-                mid_size,
-                bias=True,
-            ),
+            nn.Linear(frequency_embedding_size, mid_size, bias=True),
             nn.SiLU(),
-            nn.Linear(
-                mid_size,
-                out_size,
-                bias=True,
-            ),
+            nn.Linear(mid_size, out_size, bias=True),
         )

         self.frequency_embedding_size = frequency_embedding_size
@@ -211,9 +204,7 @@ def __init__(

         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
-            )
+            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

     def forward(
         self,
@@ -230,33 +221,19 @@ def forward(

             # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x) * scale_msa,
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
             )
             x = x + gate_msa * self.attention_norm2(attn_out)

             # FFN block
-            x = x + gate_mlp * self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x) * scale_mlp,
-                )
-            )
+            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
         else:
             # Attention block
-            attn_out = self.attention(
-                self.attention_norm1(x),
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
-            )
+            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
             x = x + self.attention_norm2(attn_out)

             # FFN block
-            x = x + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x),
-                )
-            )
+            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

         return x

@@ -404,10 +381,7 @@ def __init__(
             ]
         )
         self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
-        self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps),
-            nn.Linear(cap_feat_dim, dim, bias=True),
-        )
+        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

         self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
         self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -494,11 +468,8 @@ def patchify_and_embed(
             )

             # padded feature
-            cap_padded_feat = torch.cat(
-                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
-                dim=0,
-            )
-            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)
+            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
+            all_cap_feats_out.append(cap_padded_feat)

             ### Process Image
             C, F, H, W = image.size()
@@ -564,6 +535,7 @@ def forward(
         cap_feats: List[torch.Tensor],
         patch_size=2,
         f_patch_size=1,
+        return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size
@@ -672,4 +644,7 @@ def forward(
         unified = list(unified.unbind(dim=0))
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

-        return x, {}
+        if not return_dict:
+            return (x,)
+
+        return Transformer2DModelOutput(sample=x)
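
For orientation: the forward pass shown above previously returned the bare tuple `(x, {})`; it now returns a `Transformer2DModelOutput` by default and a plain `(x,)` tuple when called with `return_dict=False`, matching the convention used by other diffusers transformer models. A minimal sketch of the new contract, using a toy module instead of the real `ZImageTransformer2DModel` (the shapes and the doubling op are made up for illustration):

import torch
from diffusers.models.modeling_outputs import Transformer2DModelOutput


class ToyDenoiser(torch.nn.Module):
    def forward(self, x: torch.Tensor, return_dict: bool = True):
        x = x * 2.0  # stand-in for the actual denoising computation
        if not return_dict:
            return (x,)  # plain tuple; element 0 is the sample, as in the old `return x, {}`
        return Transformer2DModelOutput(sample=x)


model = ToyDenoiser()
latents = torch.randn(1, 4, 8, 8)

out = model(latents)  # default: dataclass-style output
sample = out.sample

out_tuple = model(latents, return_dict=False)  # tuple output, as requested by the pipeline
assert torch.equal(out_tuple[0], sample)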

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 1 addition & 3 deletions
@@ -525,9 +525,7 @@ def __call__(
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

                 model_out_list = self.transformer(
-                    latent_model_input_list,
-                    timestep_model_input,
-                    prompt_embeds_model_input,
+                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
                 )[0]

                 if apply_cfg:
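
A note on the caller side: the pipeline keeps indexing the transformer output with `[0]` because position 0 holds the denoised latents under both conventions. The old forward returned `(x, {})`, and the new one returns `(x,)` when invoked with `return_dict=False`, so the existing indexing is unaffected; a trivial illustration (the placeholder string stands in for the latent tensor):

# Old convention: forward returned `x, {}`, so element 0 was already the sample.
old_style_output = ("denoised_latents", {})
# New convention with return_dict=False: forward returns `(x,)`.
new_style_output = ("denoised_latents",)

assert old_style_output[0] == new_style_output[0]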

tests/lora/test_lora_layers_z_image.py

Lines changed: 134 additions & 12 deletions
@@ -15,17 +15,13 @@
 import sys
 import unittest

+import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

-from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
+from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device


 if is_peft_available():
@@ -34,13 +30,9 @@

 sys.path.append(".")

-from .utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


-@unittest.skip(
-    "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
-    "and torch.empty padding tokens. LoRA functionality works correctly with real models."
-)
 @require_peft_backend
 class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = ZImagePipeline
@@ -127,6 +119,12 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)

         transformer = self.transformer_cls(**self.transformer_kwargs)
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN data values in our testing environment. Fixating them
+        # helps prevent that issue.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)

         if scheduler_cls is None:
@@ -161,3 +159,127 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         }

         return pipeline_components, text_lora_config, denoiser_lora_config
+
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, _ in denoiser.named_modules():
+            if "to_k" in name and "attention" in name and "lora" not in name:
+                module_name_to_rank_update = name.replace(".base_layer.", ".")
+                break
+
+        # change the rank_pattern
+        updated_rank = denoiser_lora_config.r * 2
+        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        # similarly change the alpha_pattern
+        updated_alpha = denoiser_lora_config.lora_alpha * 2
+        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(
+            pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+        )
+
+        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+    @skip_mps
+    def test_lora_fuse_nan(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            possible_tower_names = ["noise_refiner"]
+            filtered_tower_names = [
+                tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+            ]
+            for tower_name in filtered_tower_names:
+                transformer_tower = getattr(pipe.transformer, tower_name)
+                transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+        out = pipe(**inputs)[0]
+
+        self.assertTrue(np.isnan(out).all())
+
+    def test_lora_scale_kwargs_match_fusion(self):
+        super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
+
+    @unittest.skip("Needs to be debugged.")
+    def test_set_adapters_match_attention_kwargs(self):
+        super().test_set_adapters_match_attention_kwargs()
+
+    @unittest.skip("Needs to be debugged.")
+    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+        super().test_simple_inference_with_text_denoiser_lora_and_scale()
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_modify_padding_mode(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_partial_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_and_scale(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_fused(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass
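
The pad-token fix in `get_dummy_components` above addresses a testing pitfall: `torch.empty` allocates uninitialized memory, so `x_pad_token` and `cap_pad_token` can start out containing NaN or Inf values that then propagate through the tiny, randomly initialized test models. Copying a constant in-place keeps the registered `nn.Parameter` objects intact while making their contents deterministic. A standalone sketch of the same pattern (`TinyModel` and `dim=16` are illustrative assumptions, not part of the test suite):

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        # Mirrors ZImageTransformer2DModel: pad tokens come from torch.empty,
        # i.e. whatever bytes happen to be in the freshly allocated memory.
        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))


model = TinyModel()

# Reinitialize in-place under no_grad so the nn.Parameter objects stay registered;
# only their data becomes deterministic.
with torch.no_grad():
    model.x_pad_token.copy_(torch.ones_like(model.x_pad_token))
    model.cap_pad_token.copy_(torch.ones_like(model.cap_pad_token))

assert not torch.isnan(model.x_pad_token).any()
assert not torch.isnan(model.cap_pad_token).any()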

0 commit comments