Merged
Changes from 3 commits
5 changes: 1 addition & 4 deletions tests/pipelines/test_pipelines_common.py
@@ -1420,10 +1420,7 @@ def test_float16_inference(self, expected_max_diff=5e-2):
    @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
    @require_accelerator
    def test_save_load_float16(self, expected_max_diff=1e-2):
-        components = self.get_dummy_components()
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.to(torch_device).half()
+        components = self.get_dummy_components(dtype=torch.float16)

        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
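As a minimal sketch of the contract this change sets up between the common mixin and each pipeline test class (the method bodies below are illustrative assumptions, not the real test file):

```python
import torch
import torch.nn as nn

class PipelineTesterMixin:
    # Sketch of the shared test path: the mixin now asks each concrete test
    # class for fp16 components instead of blanket-casting them with
    # .to(device).half() afterwards.
    def float16_components(self):
        return self.get_dummy_components(dtype=torch.float16)

class SomePipelineFastTests(PipelineTesterMixin):
    def get_dummy_components(self, dtype=torch.float32):
        # Each test class decides how dtype is applied: modules built
        # in-process are cast explicitly, while pretrained components can
        # forward dtype to from_pretrained so _keep_in_fp32_modules holds.
        return {"transformer": nn.Linear(4, 4).to(dtype=dtype)}

components = SomePipelineFastTests().float16_components()
print(components["transformer"].weight.dtype)  # torch.float16
```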
18 changes: 9 additions & 9 deletions tests/pipelines/wan/test_wan_22.py
@@ -51,19 +51,19 @@ class Wan22PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
supports_dduf = False

-def get_dummy_components(self):
+def get_dummy_components(self, dtype=torch.float32):
Member commented:
Why do we need this?

@kaixuanliu (Contributor Author) replied on Oct 30, 2025:
Please refer to L246-L256 (sorry, I could only find a Chinese version of this explanation). Using the torch.Tensor.to method converts all weights, while passing the torch_dtype parameter to from_pretrained preserves the layers listed in _keep_in_fp32_modules. For the Wan models, every component of pipe ends up in fp16, but that is not the case for pipe_loaded. That is why I override the test_save_load_float16 function separately for the Wan models.
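For illustration, a minimal, self-contained sketch of that difference. The TinyModel class and the load_with_torch_dtype helper are hypothetical stand-ins for the real loading path; only the _keep_in_fp32_modules convention itself comes from the library:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # Mirrors the _keep_in_fp32_modules convention from diffusers/transformers:
    # submodules whose names match these keys stay in fp32 when the model is
    # loaded with torch_dtype, even though everything else is cast down.
    _keep_in_fp32_modules = ["norm"]

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)

def load_with_torch_dtype(dtype: torch.dtype) -> TinyModel:
    # Hypothetical stand-in for from_pretrained(..., torch_dtype=dtype):
    # cast the whole model, then restore the protected modules to fp32.
    model = TinyModel().to(dtype)
    for name, module in model.named_modules():
        if any(key in name for key in TinyModel._keep_in_fp32_modules):
            module.to(torch.float32)
    return model

# nn.Module.to casts *every* parameter, the norm layer included:
cast = TinyModel().to(torch.float16)
print(cast.norm.weight.dtype)    # torch.float16

# torch_dtype-style loading keeps the protected layers in fp32:
loaded = load_with_torch_dtype(torch.float16)
print(loaded.proj.weight.dtype)  # torch.float16
print(loaded.norm.weight.dtype)  # torch.float32
```

In the save/load test this means pipe (built in fp16 directly) and pipe_loaded (restored via from_pretrained) can legitimately disagree on the dtype of the protected layers, which is what the dtype plumbing in this PR works around.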

torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
-)
+).to(dtype=dtype)

torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
-text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

torch.manual_seed(0)
@@ -80,7 +80,7 @@ def get_dummy_components(self):
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
-)
+).to(dtype=dtype)

torch.manual_seed(0)
transformer_2 = WanTransformer3DModel(
@@ -96,7 +96,7 @@ def get_dummy_components(self):
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
-)
+).to(dtype=dtype)

components = {
"transformer": transformer,
@@ -215,7 +215,7 @@ class Wan225BPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_xformers_attention = False
supports_dduf = False

-def get_dummy_components(self):
+def get_dummy_components(self, dtype=torch.float32):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
@@ -231,11 +231,11 @@ def get_dummy_components(self):
scale_factor_spatial=16,
scale_factor_temporal=4,
temperal_downsample=[False, True, True],
-)
+).to(dtype=dtype)

torch.manual_seed(0)
scheduler = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
-text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5", dtype=dtype)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

torch.manual_seed(0)
@@ -252,7 +252,7 @@ def get_dummy_components(self):
cross_attn_norm=True,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
-)
+).to(dtype=dtype)

components = {
"transformer": transformer,