 import sys
 import unittest
 
+import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
 
-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
 
-from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
+from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
 
 
 if is_peft_available():
     from peft import LoraConfig
 
 
 sys.path.append(".")
 
-from .utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 
 
-@unittest.skip(
-    "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
-    "and torch.empty padding tokens. LoRA functionality works correctly with real models."
-)
 @require_peft_backend
 class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = ZImagePipeline
@@ -127,6 +119,12 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)
 
         transformer = self.transformer_cls(**self.transformer_kwargs)
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`,
+        # which can carry NaN values in our testing environment. Pinning them
+        # to a fixed value prevents that.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)
 
         if scheduler_cls is None:
@@ -161,3 +159,132 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         }
 
         return pipeline_components, text_lora_config, denoiser_lora_config
+
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
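+        # pick one attention `to_k` projection to retarget; `.base_layer.` is stripped in case PEFT already wrapped the module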
+        for name, _ in denoiser.named_modules():
+            if "to_k" in name and "attention" in name and "lora" not in name:
+                module_name_to_rank_update = name.replace(".base_layer.", ".")
+                break
+
+        # change the rank_pattern
+        updated_rank = denoiser_lora_config.r * 2
+        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
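+        # the shared mixin builds its LoRA configs with random init (`init_lora_weights=False`),
+        # so attaching an adapter, or changing its rank, should perturb the output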
+        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        # similarly change the alpha_pattern
+        updated_alpha = denoiser_lora_config.lora_alpha * 2
+        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(
+            pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+        )
+
+        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+    @skip_mps
+    def test_lora_fuse_nan(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
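+            # ZImage groups its transformer blocks into named towers; the `hasattr` filter guards against variants that lack one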
+            possible_tower_names = ["noise_refiner"]
+            filtered_tower_names = [
+                tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+            ]
+            for tower_name in filtered_tower_names:
+                transformer_tower = getattr(pipe.transformer, tower_name)
+                transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
+
+        # with `safe_fusing=True` we should see an error
+        with self.assertRaises(ValueError):
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+        # without it we should not see an error, but the output will be all NaNs
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+        out = pipe(**inputs)[0]
+
+        self.assertTrue(np.isnan(out).all())
+
+    def test_lora_scale_kwargs_match_fusion(self):
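+        # pass wider tolerances (presumably the mixin's `expected_atol`/`expected_rtol`), since fused and scale-kwarg outputs only loosely match for ZImage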
+        super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
+
+    @unittest.skip("Needs to be debugged.")
+    def test_set_adapters_match_attention_kwargs(self):
+        super().test_set_adapters_match_attention_kwargs()
+
+    @unittest.skip("Needs to be debugged.")
+    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+        super().test_simple_inference_with_text_denoiser_lora_and_scale()
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_modify_padding_mode(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_partial_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_and_scale(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_fused(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass