1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import gc
15+ import tempfile
1616import unittest
1717
18+ import numpy as np
1819import torch
1920from transformers import AutoTokenizer , T5EncoderModel
20- import numpy as np
21- import tempfile
2221
23- from diffusers import AutoencoderKLWan , WanPipeline , WanTransformer3DModel , UniPCMultistepScheduler
22+ from diffusers import AutoencoderKLWan , UniPCMultistepScheduler , WanPipeline , WanTransformer3DModel
2423from diffusers .utils .testing_utils import (
25- backend_empty_cache ,
2624 enable_full_determinism ,
27- require_torch_accelerator ,
28- slow ,
2925 torch_device ,
3026)
3127
3228from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS , TEXT_TO_IMAGE_IMAGE_PARAMS , TEXT_TO_IMAGE_PARAMS
3329from ..test_pipelines_common import PipelineTesterMixin
3430
3531
36-
37-
3832enable_full_determinism ()
3933
4034
@@ -139,7 +133,9 @@ def test_inference(self):
139133 device = "cpu"
140134
141135 components = self .get_dummy_components ()
142- pipe = self .pipeline_class (** components , )
136+ pipe = self .pipeline_class (
137+ ** components ,
138+ )
143139 pipe .to (device )
144140 pipe .set_progress_bar_config (disable = None )
145141
@@ -159,14 +155,13 @@ def test_inference(self):
159155 @unittest .skip ("Test not supported" )
160156 def test_attention_slicing_forward_pass (self ):
161157 pass
162-
163- def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
164158
159+ def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
165160 optional_component = "transformer"
166161
167162 components = self .get_dummy_components ()
168163 components [optional_component ] = None
169- components ["boundary_ratio" ] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
164+ components ["boundary_ratio" ] = 1.0 # for wan 2.2 14B, transformer is not used when boundary_ratio is 1.0
170165
171166 pipe = self .pipeline_class (** components )
172167 for component in pipe .components .values ():
@@ -189,9 +184,10 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
189184 pipe_loaded .to (torch_device )
190185 pipe_loaded .set_progress_bar_config (disable = None )
191186
192- self .assertTrue (getattr (pipe_loaded , "transformer" ) is None ,
193- f"`transformer` did not stay set to None after loading." ,
194- )
187+ self .assertTrue (
188+ getattr (pipe_loaded , "transformer" ) is None ,
189+ "`transformer` did not stay set to None after loading." ,
190+ )
195191
196192 inputs = self .get_dummy_inputs (generator_device )
197193 torch .manual_seed (0 )
@@ -201,7 +197,6 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
201197 self .assertLess (max_diff , expected_max_difference )
202198
203199
204-
205200class Wan225BPipelineFastTests (PipelineTesterMixin , unittest .TestCase ):
206201 pipeline_class = WanPipeline
207202 params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs" }
@@ -230,8 +225,8 @@ def get_dummy_components(self):
230225 out_channels = 12 ,
231226 is_residual = True ,
232227 patch_size = 2 ,
233- latents_mean = [0.0 ] * 48 ,
234- latents_std = [1.0 ] * 48 ,
228+ latents_mean = [0.0 ] * 48 ,
229+ latents_std = [1.0 ] * 48 ,
235230 dim_mult = [1 , 1 , 1 , 1 ],
236231 num_res_blocks = 1 ,
237232 scale_factor_spatial = 16 ,
@@ -295,7 +290,9 @@ def test_inference(self):
295290 device = "cpu"
296291
297292 components = self .get_dummy_components ()
298- pipe = self .pipeline_class (** components , )
293+ pipe = self .pipeline_class (
294+ ** components ,
295+ )
299296 pipe .to (device )
300297 pipe .set_progress_bar_config (disable = None )
301298
@@ -311,7 +308,10 @@ def test_inference(self):
311308
312309 generated_slice = generated_video .flatten ()
313310 generated_slice = torch .cat ([generated_slice [:8 ], generated_slice [- 8 :]])
314- self .assertTrue (torch .allclose (generated_slice , expected_slice , atol = 1e-3 ), f"generated_slice: { generated_slice } , expected_slice: { expected_slice } " )
311+ self .assertTrue (
312+ torch .allclose (generated_slice , expected_slice , atol = 1e-3 ),
313+ f"generated_slice: { generated_slice } , expected_slice: { expected_slice } " ,
314+ )
315315
316316 @unittest .skip ("Test not supported" )
317317 def test_attention_slicing_forward_pass (self ):
@@ -327,7 +327,6 @@ def test_components_function(self):
327327 self .assertTrue (set (pipe .components .keys ()) == set (init_components .keys ()))
328328
329329 def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
330-
331330 optional_component = "transformer_2"
332331
333332 components = self .get_dummy_components ()
@@ -353,16 +352,17 @@ def test_save_load_optional_components(self, expected_max_difference=1e-4):
353352 pipe_loaded .to (torch_device )
354353 pipe_loaded .set_progress_bar_config (disable = None )
355354
356- self .assertTrue (getattr (pipe_loaded , optional_component ) is None ,
357- f"`{ optional_component } ` did not stay set to None after loading." ,
358- )
355+ self .assertTrue (
356+ getattr (pipe_loaded , optional_component ) is None ,
357+ f"`{ optional_component } ` did not stay set to None after loading." ,
358+ )
359359
360360 inputs = self .get_dummy_inputs (generator_device )
361361 torch .manual_seed (0 )
362362 output_loaded = pipe_loaded (** inputs )[0 ]
363363
364364 max_diff = np .abs (output .detach ().cpu ().numpy () - output_loaded .detach ().cpu ().numpy ()).max ()
365365 self .assertLess (max_diff , expected_max_difference )
366-
366+
367367 def test_inference_batch_single_identical (self ):
368- self ._test_inference_batch_single_identical (expected_max_diff = 2e-3 )
368+ self ._test_inference_batch_single_identical (expected_max_diff = 2e-3 )
0 commit comments