
Commit fd26552

samedwardsFM authored and samadwar committed
Fixed dtype mismatch when loading a single file
1 parent 01351be

1 file changed: 4 additions, 0 deletions

src/diffusers/models/transformers/transformer_wan_animate.py

@@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
         # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
         # set to 1, which should be equivalent to a 2D convolution
         expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
+        x = x.to(expanded_kernel.dtype)
         x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)

         # Main Conv2D with scaling
+        x = x.to(self.weight.dtype)
         x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

         # Activation with fused bias, if using
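
For context, a minimal sketch of the failure mode this hunk guards against (not from the commit; the shapes, names, and dtypes are illustrative): F.conv2d raises a RuntimeError when the input and weight dtypes disagree, which can happen when a half-precision activation meets a blur kernel loaded in float32 from a single checkpoint file.

import torch
import torch.nn.functional as F

# Illustrative repro, assuming a float32 depthwise blur kernel (groups=4)
# and a float16 activation, as can occur with single-file loading.
kernel = torch.randn(4, 1, 3, 3)                   # float32 kernel
x = torch.randn(1, 4, 8, 8, dtype=torch.float16)   # float16 input

try:
    F.conv2d(x, kernel, padding=1, groups=4)       # dtype mismatch
except RuntimeError as err:
    print(err)                                     # input/weight dtypes must match

# The commit's remedy: cast the input to the kernel's dtype first.
out = F.conv2d(x.to(kernel.dtype), kernel, padding=1, groups=4)
print(out.dtype)                                   # torch.float32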
@@ -804,6 +806,8 @@ def forward(
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
             timestep = timestep.to(time_embedder_dtype)
+        if timestep.dtype != encoder_hidden_states.dtype:
+            timestep = timestep.to(encoder_hidden_states.dtype)
         temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
         timestep_proj = self.time_proj(self.act_fn(temb))
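
Similarly, a minimal sketch of the second hunk's guard, with standalone tensors standing in for the module state (the int8 value models a quantized time embedder, which is an assumption about the scenario, not stated in the commit): when the existing branch skips the cast, `timestep` can disagree with the dtype of the hidden states.

import torch

# Illustrative values: an int8 time-embedder dtype skips the first cast,
# leaving `timestep` in float32 while the hidden states run in float16.
time_embedder_dtype = torch.int8
encoder_hidden_states = torch.randn(1, 77, 4096, dtype=torch.float16)
timestep = torch.tensor([999.0])                   # float32

if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
    timestep = timestep.to(time_embedder_dtype)

# New fallback from this commit: align with the hidden-states dtype so the
# time embedding matches everything downstream.
if timestep.dtype != encoder_hidden_states.dtype:
    timestep = timestep.to(encoder_hidden_states.dtype)

print(timestep.dtype)                              # torch.float16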
