@@ -749,7 +749,7 @@ def infer_diffusers_model_type(checkpoint):
749749 elif checkpoint [target_key ].shape [0 ] == 5120 :
750750 model_type = "wan-vace-14B"
751751
752- if CHECKPOINT_KEY_NAMES ["wan-animate " ] in checkpoint :
752+ if CHECKPOINT_KEY_NAMES ["wan_animate " ] in checkpoint :
753753 model_type = "wan-animate-14B"
754754
755755 elif checkpoint [target_key ].shape [0 ] == 1536 :
@@ -3132,13 +3132,62 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
31323132
31333133
31343134def convert_wan_transformer_to_diffusers (checkpoint , ** kwargs ):
def generate_motion_encoder_mappings():
    """Build the rename table for the Wan Animate motion-encoder keys.

    Returns:
        dict[str, str]: Original key fragment -> replacement fragment.
        Five fixed entries cover the direction weight, ``convs.0``,
        ``convs.8`` and ``enc.fc``; ``convs.1`` .. ``convs.7`` are mapped
        onto ``res_blocks.0`` .. ``res_blocks.6`` with five entries each
        (40 entries in total).
    """
    mappings = {
        "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
        "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
        "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
        "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
        "motion_encoder.enc.fc": "motion_encoder.motion_network",
    }

    for i in range(7):
        # Source index is shifted by one: ``convs.0`` is already mapped to
        # ``conv_in`` above, so ``convs.{i + 1}`` becomes ``res_blocks.{i}``.
        conv_idx = i + 1
        src = f"motion_encoder.enc.net_app.convs.{conv_idx}"
        dst = f"motion_encoder.res_blocks.{i}"
        # The keys must not contain any whitespace, otherwise they never
        # match a checkpoint key during substring replacement.
        mappings.update(
            {
                f"{src}.conv1.0.weight": f"{dst}.conv1.weight",
                f"{src}.conv1.1.bias": f"{dst}.conv1.act_fn.bias",
                f"{src}.conv2.1.weight": f"{dst}.conv2.weight",
                f"{src}.conv2.2.bias": f"{dst}.conv2.act_fn.bias",
                f"{src}.skip.1.weight": f"{dst}.conv_skip.weight",
            }
        )

    return mappings
3155+
def generate_face_adapter_mappings():
    """Return the rename table for the Wan Animate face-adapter keys.

    Returns:
        dict[str, str]: Original key fragment -> replacement fragment.
        The fused ``.linear1_kv.`` projection is intentionally absent: it
        cannot be expressed as a plain rename and is not part of this table.
    """
    return {
        "face_adapter.fuser_blocks": "face_adapter",
        ".k_norm.": ".norm_k.",
        ".q_norm.": ".norm_q.",
        ".linear1_q.": ".to_q.",
        ".linear2.": ".to_out.",
        "conv1_local.conv": "conv1_local",
        "conv2.conv": "conv2",
        "conv3.conv": "conv3",
    }
3167+
def split_tensor_handler(key, state_dict, split_pattern, target_keys):
    """Split one fused tensor into two entries along its first dimension.

    The entry stored under ``key`` is removed from ``state_dict`` and
    replaced by two entries whose names are ``key`` with ``split_pattern``
    substituted by ``target_keys[0]`` and ``target_keys[1]`` respectively.

    Args:
        key: Name of the entry to split; expected to contain ``split_pattern``.
        state_dict: Mapping that is modified in place.
        split_pattern: Substring of ``key`` to substitute.
        target_keys: Two replacement substrings; the first receives the
            leading half of the tensor, the second the remainder.

    Raises:
        KeyError: If ``key`` is not present in ``state_dict``.
    """
    tensor = state_dict.pop(key)
    # Integer division: with an odd leading dimension the second part is
    # one row longer than the first.
    split_idx = tensor.shape[0] // 2

    new_key_1 = key.replace(split_pattern, target_keys[0])
    new_key_2 = key.replace(split_pattern, target_keys[1])

    state_dict[new_key_1] = tensor[:split_idx]
    state_dict[new_key_2] = tensor[split_idx:]
3177+
def reshape_bias_handler(key, state_dict):
    """Collapse a motion-encoder conv bias entry to a single axis, in place.

    Only keys containing both ``"motion_encoder.enc.net_app.convs."`` and
    ``".bias"`` are touched; for those, ``state_dict[key]`` is replaced by
    its ``[0, :, 0, 0]`` slice. Every other key is left unchanged.

    Args:
        key: Name of the entry to inspect.
        state_dict: Mapping that is modified in place.
    """
    if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
        # NOTE(review): presumably these biases are stored as (1, C, 1, 1);
        # the index below requires a 4-D tensor - confirm against checkpoint.
        state_dict[key] = state_dict[key][0, :, 0, 0]
3181+
31353182 converted_state_dict = {}
31363183
3184+ # Strip model.diffusion_model prefix
31373185 keys = list (checkpoint .keys ())
31383186 for k in keys :
31393187 if "model.diffusion_model." in k :
31403188 checkpoint [k .replace ("model.diffusion_model." , "" )] = checkpoint .pop (k )
31413189
3190+ # Base transformer mappings
31423191 TRANSFORMER_KEYS_RENAME_DICT = {
31433192 "time_embedding.0" : "condition_embedder.time_embedder.linear_1" ,
31443193 "time_embedding.2" : "condition_embedder.time_embedder.linear_2" ,
@@ -3160,95 +3209,42 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
31603209 "ffn.0" : "ffn.net.0.proj" ,
31613210 "ffn.2" : "ffn.net.2" ,
31623211 # Hack to swap the layer names
3163- # The original model calls the norms in following order: norm1, norm3, norm2
3164- # We convert it to: norm1, norm2, norm3
31653212 "norm2" : "norm__placeholder" ,
31663213 "norm3" : "norm2" ,
31673214 "norm__placeholder" : "norm3" ,
3168- # For the I2V model
3215+ # I2V model
31693216 "img_emb.proj.0" : "condition_embedder.image_embedder.norm1" ,
31703217 "img_emb.proj.1" : "condition_embedder.image_embedder.ff.net.0.proj" ,
31713218 "img_emb.proj.3" : "condition_embedder.image_embedder.ff.net.2" ,
31723219 "img_emb.proj.4" : "condition_embedder.image_embedder.norm2" ,
3173- # For the VACE model
3220+ # VACE model
31743221 "before_proj" : "proj_in" ,
31753222 "after_proj" : "proj_out" ,
3176- # For Wan Animate
3177- "face_adapter.fuser_blocks" : "face_adapter" ,
3178- ".k_norm." : ".norm_k." ,
3179- ".q_norm." : ".norm_q." ,
3180- # Requires tensor split
3181- ".linear1_kv." : [".to_k." , ".to_v." ],
3182- ".linear1_q." : ".to_q." ,
3183- ".linear2." : ".to_out." ,
3184- "conv1_local.conv" : "conv1_local" ,
3185- "conv2.conv" : "conv2" ,
3186- "conv3.conv" : "conv3" ,
3187- "motion_encoder.dec.direction.weight" : "motion_encoder.motion_synthesis_weight" ,
3188- "motion_encoder.enc.net_app.convs.0.0.weight" : "motion_encoder.conv_in.weight" ,
3189- "motion_encoder.enc.net_app.convs.0.1.bias" : "motion_encoder.conv_in.act_fn.bias" ,
3190- "motion_encoder.enc.net_app.convs.8.weight" : "motion_encoder.conv_out.weight" ,
3191- "motion_encoder.enc.fc" : "motion_encoder.motion_network" ,
3192- "motion_encoder.enc.net_app.convs.7.conv1.0.weight" : "motion_encoder.res_blocks.6.conv1.weight" ,
3193- "motion_encoder.enc.net_app.convs.6.conv1.0.weight" : "motion_encoder.res_blocks.5.conv1.weight" ,
3194- "motion_encoder.enc.net_app.convs.5.conv1.0.weight" : "motion_encoder.res_blocks.4.conv1.weight" ,
3195- "motion_encoder.enc.net_app.convs.4.conv1.0.weight" : "motion_encoder.res_blocks.3.conv1.weight" ,
3196- "motion_encoder.enc.net_app.convs.3.conv1.0.weight" : "motion_encoder.res_blocks.2.conv1.weight" ,
3197- "motion_encoder.enc.net_app.convs.2.conv1.0.weight" : "motion_encoder.res_blocks.1.conv1.weight" ,
3198- "motion_encoder.enc.net_app.convs.1.conv1.0.weight" : "motion_encoder.res_blocks.0.conv1.weight" ,
3199- "motion_encoder.enc.net_app.convs.7.conv2.1.weight" : "motion_encoder.res_blocks.6.conv2.weight" ,
3200- "motion_encoder.enc.net_app.convs.6.conv2.1.weight" : "motion_encoder.res_blocks.5.conv2.weight" ,
3201- "motion_encoder.enc.net_app.convs.5.conv2.1.weight" : "motion_encoder.res_blocks.4.conv2.weight" ,
3202- "motion_encoder.enc.net_app.convs.4.conv2.1.weight" : "motion_encoder.res_blocks.3.conv2.weight" ,
3203- "motion_encoder.enc.net_app.convs.3.conv2.1.weight" : "motion_encoder.res_blocks.2.conv2.weight" ,
3204- "motion_encoder.enc.net_app.convs.2.conv2.1.weight" : "motion_encoder.res_blocks.1.conv2.weight" ,
3205- "motion_encoder.enc.net_app.convs.1.conv2.1.weight" : "motion_encoder.res_blocks.0.conv2.weight" ,
3206- "motion_encoder.enc.net_app.convs.7.conv1.1.bias" : "motion_encoder.res_blocks.6.conv1.act_fn.bias" ,
3207- "motion_encoder.enc.net_app.convs.6.conv1.1.bias" : "motion_encoder.res_blocks.5.conv1.act_fn.bias" ,
3208- "motion_encoder.enc.net_app.convs.5.conv1.1.bias" : "motion_encoder.res_blocks.4.conv1.act_fn.bias" ,
3209- "motion_encoder.enc.net_app.convs.4.conv1.1.bias" : "motion_encoder.res_blocks.3.conv1.act_fn.bias" ,
3210- "motion_encoder.enc.net_app.convs.3.conv1.1.bias" : "motion_encoder.res_blocks.2.conv1.act_fn.bias" ,
3211- "motion_encoder.enc.net_app.convs.2.conv1.1.bias" : "motion_encoder.res_blocks.1.conv1.act_fn.bias" ,
3212- "motion_encoder.enc.net_app.convs.1.conv1.1.bias" : "motion_encoder.res_blocks.0.conv1.act_fn.bias" ,
3213- "motion_encoder.enc.net_app.convs.7.conv2.2.bias" : "motion_encoder.res_blocks.6.conv2.act_fn.bias" ,
3214- "motion_encoder.enc.net_app.convs.6.conv2.2.bias" : "motion_encoder.res_blocks.5.conv2.act_fn.bias" ,
3215- "motion_encoder.enc.net_app.convs.5.conv2.2.bias" : "motion_encoder.res_blocks.4.conv2.act_fn.bias" ,
3216- "motion_encoder.enc.net_app.convs.4.conv2.2.bias" : "motion_encoder.res_blocks.3.conv2.act_fn.bias" ,
3217- "motion_encoder.enc.net_app.convs.3.conv2.2.bias" : "motion_encoder.res_blocks.2.conv2.act_fn.bias" ,
3218- "motion_encoder.enc.net_app.convs.2.conv2.2.bias" : "motion_encoder.res_blocks.1.conv2.act_fn.bias" ,
3219- "motion_encoder.enc.net_app.convs.1.conv2.2.bias" : "motion_encoder.res_blocks.0.conv2.act_fn.bias" ,
3220- "motion_encoder.enc.net_app.convs.7.skip.1.weight" : "motion_encoder.res_blocks.6.conv_skip.weight" ,
3221- "motion_encoder.enc.net_app.convs.6.skip.1.weight" : "motion_encoder.res_blocks.5.conv_skip.weight" ,
3222- "motion_encoder.enc.net_app.convs.5.skip.1.weight" : "motion_encoder.res_blocks.4.conv_skip.weight" ,
3223- "motion_encoder.enc.net_app.convs.4.skip.1.weight" : "motion_encoder.res_blocks.3.conv_skip.weight" ,
3224- "motion_encoder.enc.net_app.convs.3.skip.1.weight" : "motion_encoder.res_blocks.2.conv_skip.weight" ,
3225- "motion_encoder.enc.net_app.convs.2.skip.1.weight" : "motion_encoder.res_blocks.1.conv_skip.weight" ,
3226- "motion_encoder.enc.net_app.convs.1.skip.1.weight" : "motion_encoder.res_blocks.0.conv_skip.weight" ,
32273223 }
32283224
3225+ SPECIAL_KEYS_HANDLERS = {}
3226+ if any ("face_adapter" in k for k in checkpoint .keys ()):
3227+ TRANSFORMER_KEYS_RENAME_DICT .update (generate_face_adapter_mappings ())
3228+ SPECIAL_KEYS_HANDLERS [".linear1_kv." ] = (split_tensor_handler , [".to_k." , ".to_v." ])
3229+
3230+ if any ("motion_encoder" in k for k in checkpoint .keys ()):
3231+ TRANSFORMER_KEYS_RENAME_DICT .update (generate_motion_encoder_mappings ())
3232+
32293233 for key in list (checkpoint .keys ()):
3230- new_key = key [:]
3231- extra_key = ""
3232- index = 0
3234+ reshape_bias_handler (key , checkpoint )
3235+
3236+ for key in list (checkpoint .keys ()):
3237+ new_key = key
32333238 for replace_key , rename_key in TRANSFORMER_KEYS_RENAME_DICT .items ():
3234- if isinstance (rename_key , list ):
3235- if replace_key in new_key :
3236- index = int (checkpoint [key ].shape [0 ] / 2 )
3237- new_key = new_key .replace (replace_key , rename_key [0 ])
3238- extra_key = new_key .replace (rename_key [0 ], rename_key [1 ])
3239- else :
3240- new_key = new_key .replace (replace_key , rename_key )
3241- if extra_key != "" :
3242- converted_state_dict [new_key ] = checkpoint [key ][index :]
3243- converted_state_dict [extra_key ] = checkpoint [key ][:index ]
3244- checkpoint .pop (key )
3245- else :
3246- if key == "motion_encoder.enc.net_app.convs.0.1.bias" :
3247- converted_state_dict [new_key ] = checkpoint .pop (key )[0 , :, 0 , 0 ]
3248- elif "motion_encoder.enc.net_app.convs." in key and ".bias" in key :
3249- converted_state_dict [new_key ] = checkpoint .pop (key )[0 , :, 0 , 0 ]
3250- else :
3251- converted_state_dict [new_key ] = checkpoint .pop (key )
3239+ new_key = new_key .replace (replace_key , rename_key )
3240+ converted_state_dict [new_key ] = checkpoint .pop (key )
3241+
3242+ for key in list (converted_state_dict .keys ()):
3243+ for pattern , (handler_fn , target_keys ) in SPECIAL_KEYS_HANDLERS .items ():
3244+ if pattern not in key :
3245+ continue
3246+ handler_fn (key , converted_state_dict , pattern , target_keys )
3247+ break
32523248
32533249 return converted_state_dict
32543250
0 commit comments