
Commit 4b14ddd

add pipeline test
1 parent ea301df commit 4b14ddd

1 file changed: 189 additions, 0 deletions
@@ -0,0 +1,189 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()


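# Fast tests for WanVACEPipeline: every component is a tiny random-weight stand-in, so the suite runs on CPU in seconds.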
class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = WanVACEPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

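    # Tiny, deterministic dummy components; each model is seeded separately so its weights are reproducible.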
    def get_dummy_components(self):
        torch.manual_seed(0)
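        # NOTE: "temperal_downsample" (sic) is the parameter's actual spelling in AutoencoderKLWan.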
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = WanVACETransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=3,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
            vace_layers=[0, 2],
            vace_in_channels=96,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

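    # Conditioning inputs are deliberately trivial: 17 blank RGB frames plus an all-zero single-channel mask.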
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        num_frames = 17
        height = 16
        width = 16

        video = [Image.new("RGB", (height, width))] * num_frames
        mask = [Image.new("L", (height, width), 0)] * num_frames

        inputs = {
            "video": video,
            "mask": mask,
            "prompt": "dance monkey",
            "negative_prompt": "negative",  # TODO
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": 16,
            "width": 16,
            "num_frames": num_frames,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

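    # `expected_video` in the tests below is random noise and the tolerance (1e10) is effectively infinite, so they
    # only check that inference runs end to end and yields the right output shape; real expected slices can come later.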
    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
        expected_video = torch.randn(17, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

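    # `reference_images` can also be a single PIL image rather than a list.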
    def test_inference_with_single_reference_image(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["reference_images"] = Image.new("RGB", (16, 16))
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
        expected_video = torch.randn(17, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

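    # Multiple reference images go in as a nested list, one inner list per prompt.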
    def test_inference_with_multiple_reference_image(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["reference_images"] = [[Image.new("RGB", (16, 16))] * 2]
        video = pipe(**inputs).frames
        generated_video = video[0]

        self.assertEqual(generated_video.shape, (17, 3, 16, 16))
        expected_video = torch.randn(17, 3, 16, 16)
        max_diff = np.abs(generated_video - expected_video).max()
        self.assertLessEqual(max_diff, 1e10)

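    # Tests from the common mixin that this pipeline does not support yet.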
    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip("Errors out because passing multiple prompts at once is not yet supported by this pipeline.")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Batching is not yet supported with this pipeline")
    def test_inference_batch_consistent(self):
        pass

    @unittest.skip("Batching is not yet supported with this pipeline")
    def test_inference_batch_single_identical(self):
        pass
