@@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
166166 # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
167167 # set to 1, which should be equivalent to a 2D convolution
168168 expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
169+ x = x.to(expanded_kernel.dtype)
169170 x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
170171
171172 # Main Conv2D with scaling
173+ x = x.to(self.weight.dtype)
172174 x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
173175
174176 # Activation with fused bias, if using
@@ -804,6 +806,8 @@ def forward(
804806 time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
805807 if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
806808 timestep = timestep.to(time_embedder_dtype)
809+ if timestep.dtype != encoder_hidden_states.dtype:
810+ timestep = timestep.to(encoder_hidden_states.dtype)
807811 temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
808812 timestep_proj = self.time_proj(self.act_fn(temb))
809813
0 commit comments