@@ -162,17 +162,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
162162
163163 model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
164164 _callback_tensor_inputs = ["latents" , "prompt_embeds" , "negative_prompt_embeds" ]
165- _optional_components = ["transformer_2" , "image_encoder" , "image_processor" ]
165+ _optional_components = ["transformer" , " transformer_2" , "image_encoder" , "image_processor" ]
166166
167167 def __init__ (
168168 self ,
169169 tokenizer : AutoTokenizer ,
170170 text_encoder : UMT5EncoderModel ,
171- transformer : WanTransformer3DModel ,
172171 vae : AutoencoderKLWan ,
173172 scheduler : FlowMatchEulerDiscreteScheduler ,
174173 image_processor : CLIPImageProcessor = None ,
175174 image_encoder : CLIPVisionModel = None ,
175+ transformer : WanTransformer3DModel = None ,
176176 transformer_2 : WanTransformer3DModel = None ,
177177 boundary_ratio : Optional [float ] = None ,
178178 expand_timesteps : bool = False ,
@@ -669,12 +669,13 @@ def __call__(
669669 )
670670
671671 # Encode image embedding
672- transformer_dtype = self .transformer .dtype
672+ transformer_dtype = self .transformer .dtype if self . transformer is not None else self . transformer_2 . dtype
673673 prompt_embeds = prompt_embeds .to (transformer_dtype )
674674 if negative_prompt_embeds is not None :
675675 negative_prompt_embeds = negative_prompt_embeds .to (transformer_dtype )
676676
677- if self .config .boundary_ratio is None and not self .config .expand_timesteps :
677+ # only wan 2.1 i2v transformer accepts image_embeds
678+ if self .transformer is not None and self .transformer .config .added_kv_proj_dim is not None :
678679 if image_embeds is None :
679680 if last_image is None :
680681 image_embeds = self .encode_image (image , device )
@@ -709,6 +710,7 @@ def __call__(
709710 last_image ,
710711 )
711712 if self .config .expand_timesteps :
713+ # wan 2.2 5b i2v use firt_frame_mask to mask timesteps
712714 latents , condition , first_frame_mask = latents_outputs
713715 else :
714716 latents , condition = latents_outputs
0 commit comments