
Commit 028bc19

Optimize dtype handling
Removed unnecessary dtype conversions for timestep and weight.
1 parent: 9e47293 · commit: 028bc19

File tree

1 file changed (+3 −7 lines)

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 3 additions & 7 deletions
@@ -341,7 +341,6 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
         # Upcast the QR orthogonalization operation to FP32
         original_motion_dtype = motion_feat.dtype
         motion_feat = motion_feat.to(weight.dtype)
-        # weight = weight.to(torch.float32)
 
         Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
 
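For context, a minimal sketch of the dtype flow this hunk keeps. The shapes and the final projection are illustrative assumptions, not the module's real logic: the point is that the features are cast up to the fp32 weight's dtype before the QR, so torch.linalg.qr (which does not accept reduced-precision inputs) never sees bf16, and the commented-out separate upcast of `weight` was redundant.

import torch

# Hypothetical shapes; the real module derives them from its config.
motion_feat = torch.randn(4, 512, dtype=torch.bfloat16)  # low-precision features
weight = torch.randn(512, 512)                           # fp32 parameter data

# Cast the features to the weight's dtype so the QR runs in fp32,
# mirroring the surviving lines of the hunk above.
original_motion_dtype = motion_feat.dtype
motion_feat = motion_feat.to(weight.dtype)

# torch.linalg.qr returns (Q, R); keep the orthonormal basis Q.
Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)

projected = motion_feat @ Q                       # fp32 matmul against the basis
projected = projected.to(original_motion_dtype)   # restore the caller's dtype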
@@ -803,12 +802,9 @@ def forward(
         if timestep_seq_len is not None:
             timestep = timestep.unflatten(0, (-1, timestep_seq_len))
 
-        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-        if timestep.dtype != time_embedder_dtype and time_embedder_dtype not in [torch.int8, torch.uint8]:
-            timestep = timestep.to(time_embedder_dtype)
-        if timestep.dtype != encoder_hidden_states.dtype:
-            timestep = timestep.to(encoder_hidden_states.dtype)
-        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        timestep = timestep.to(encoder_hidden_states.dtype)
+
+        temb = self.time_embedder(timestep)
         timestep_proj = self.time_proj(self.act_fn(temb))
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
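A hedged sketch of the simplified timestep path. A toy nn.Sequential stands in for the real time_embedder, and the sizes are assumptions; the premise (per the commit message) is that the embedder's parameters already share the text embeddings' dtype and are not int8-quantized, so one cast replaces the old parameter-dtype probe, the int8 guard, and the trailing .type_as(...):

import torch
import torch.nn as nn

# Stand-ins for the real submodules; sizes are hypothetical.
time_embedder = nn.Sequential(
    nn.Linear(256, 1024), nn.SiLU(), nn.Linear(1024, 1024)
).to(torch.bfloat16)
encoder_hidden_states = torch.randn(1, 77, 1024, dtype=torch.bfloat16)
timestep = torch.randn(1, 256)  # arrives in fp32

# New path: a single cast to the text-embedding dtype is enough, because
# the embedder's parameters already share that dtype.
timestep = timestep.to(encoder_hidden_states.dtype)
temb = time_embedder(timestep)
assert temb.dtype == encoder_hidden_states.dtype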
