Skip to content

Commit 6f1d669

Browse files
authored
[lora] tests for exclude_modules with Wan VACE (#11843)
* wan vace. * update * update * import problem
1 parent 0e95aa8 commit 6f1d669

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
# Copyright 2025 HuggingFace Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sys
17+
import tempfile
18+
import unittest
19+
20+
import numpy as np
21+
import pytest
22+
import safetensors.torch
23+
import torch
24+
from PIL import Image
25+
from transformers import AutoTokenizer, T5EncoderModel
26+
27+
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
28+
from diffusers.utils.import_utils import is_peft_available
29+
from diffusers.utils.testing_utils import (
30+
floats_tensor,
31+
require_peft_backend,
32+
require_peft_version_greater,
33+
skip_mps,
34+
torch_device,
35+
)
36+
37+
38+
if is_peft_available():
39+
from peft.utils import get_peft_model_state_dict
40+
41+
sys.path.append(".")
42+
43+
from utils import PeftLoraLoaderMixinTests # noqa: E402
44+
45+
46+
@require_peft_backend
47+
@skip_mps
48+
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
49+
pipeline_class = WanVACEPipeline
50+
scheduler_cls = FlowMatchEulerDiscreteScheduler
51+
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
52+
scheduler_kwargs = {}
53+
54+
transformer_kwargs = {
55+
"patch_size": (1, 2, 2),
56+
"num_attention_heads": 2,
57+
"attention_head_dim": 8,
58+
"in_channels": 4,
59+
"out_channels": 4,
60+
"text_dim": 32,
61+
"freq_dim": 16,
62+
"ffn_dim": 16,
63+
"num_layers": 2,
64+
"cross_attn_norm": True,
65+
"qk_norm": "rms_norm_across_heads",
66+
"rope_max_seq_len": 16,
67+
"vace_layers": [0],
68+
"vace_in_channels": 72,
69+
}
70+
transformer_cls = WanVACETransformer3DModel
71+
vae_kwargs = {
72+
"base_dim": 3,
73+
"z_dim": 4,
74+
"dim_mult": [1, 1, 1, 1],
75+
"latents_mean": torch.randn(4).numpy().tolist(),
76+
"latents_std": torch.randn(4).numpy().tolist(),
77+
"num_res_blocks": 1,
78+
"temperal_downsample": [False, True, True],
79+
}
80+
vae_cls = AutoencoderKLWan
81+
has_two_text_encoders = True
82+
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
83+
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
84+
85+
text_encoder_target_modules = ["q", "k", "v", "o"]
86+
87+
@property
88+
def output_shape(self):
89+
return (1, 9, 16, 16, 3)
90+
91+
def get_dummy_inputs(self, with_generator=True):
92+
batch_size = 1
93+
sequence_length = 16
94+
num_channels = 4
95+
num_frames = 9
96+
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
97+
sizes = (4, 4)
98+
height, width = 16, 16
99+
100+
generator = torch.manual_seed(0)
101+
noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
102+
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
103+
video = [Image.new("RGB", (height, width))] * num_frames
104+
mask = [Image.new("L", (height, width), 0)] * num_frames
105+
106+
pipeline_inputs = {
107+
"video": video,
108+
"mask": mask,
109+
"prompt": "",
110+
"num_frames": num_frames,
111+
"num_inference_steps": 1,
112+
"guidance_scale": 6.0,
113+
"height": height,
114+
"width": height,
115+
"max_sequence_length": sequence_length,
116+
"output_type": "np",
117+
}
118+
if with_generator:
119+
pipeline_inputs.update({"generator": generator})
120+
121+
return noise, input_ids, pipeline_inputs
122+
123+
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
124+
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
125+
126+
def test_simple_inference_with_text_denoiser_lora_unfused(self):
127+
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
128+
129+
@unittest.skip("Not supported in Wan VACE.")
130+
def test_simple_inference_with_text_denoiser_block_scale(self):
131+
pass
132+
133+
@unittest.skip("Not supported in Wan VACE.")
134+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
135+
pass
136+
137+
@unittest.skip("Not supported in Wan VACE.")
138+
def test_modify_padding_mode(self):
139+
pass
140+
141+
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
142+
def test_simple_inference_with_partial_text_lora(self):
143+
pass
144+
145+
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
146+
def test_simple_inference_with_text_lora(self):
147+
pass
148+
149+
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
150+
def test_simple_inference_with_text_lora_and_scale(self):
151+
pass
152+
153+
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
154+
def test_simple_inference_with_text_lora_fused(self):
155+
pass
156+
157+
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
158+
def test_simple_inference_with_text_lora_save_load(self):
159+
pass
160+
161+
@pytest.mark.xfail(
162+
condition=True,
163+
reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
164+
strict=True,
165+
)
166+
def test_layerwise_casting_inference_denoiser(self):
167+
super().test_layerwise_casting_inference_denoiser()
168+
169+
@require_peft_version_greater("0.13.2")
170+
def test_lora_exclude_modules_wanvace(self):
171+
scheduler_cls = self.scheduler_classes[0]
172+
exclude_module_name = "vace_blocks.0.proj_out"
173+
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
174+
pipe = self.pipeline_class(**components).to(torch_device)
175+
_, _, inputs = self.get_dummy_inputs(with_generator=False)
176+
177+
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
178+
self.assertTrue(output_no_lora.shape == self.output_shape)
179+
180+
# only supported for `denoiser` now
181+
denoiser_lora_config.target_modules = ["proj_out"]
182+
denoiser_lora_config.exclude_modules = [exclude_module_name]
183+
pipe, _ = self.add_adapters_to_pipeline(
184+
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
185+
)
186+
# The state dict shouldn't contain the modules to be excluded from LoRA.
187+
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
188+
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
189+
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
190+
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
191+
192+
with tempfile.TemporaryDirectory() as tmpdir:
193+
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
194+
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
195+
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
196+
pipe.unload_lora_weights()
197+
198+
# Check in the loaded state dict.
199+
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
200+
self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
201+
self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
202+
203+
# Check in the state dict obtained after loading LoRA.
204+
pipe.load_lora_weights(tmpdir)
205+
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
206+
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
207+
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
208+
209+
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
210+
self.assertTrue(
211+
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
212+
"LoRA should change outputs.",
213+
)
214+
self.assertTrue(
215+
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
216+
"Lora outputs should match.",
217+
)

0 commit comments

Comments
 (0)