Skip to content

Commit 9e47293

Browse files
samedwardsFMsamadwar
authored andcommitted
Improve code readability
1 parent a224fe4 commit 9e47293

File tree

1 file changed

+73
-77
lines changed

1 file changed

+73
-77
lines changed

src/diffusers/loaders/single_file_utils.py

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def infer_diffusers_model_type(checkpoint):
749749
elif checkpoint[target_key].shape[0] == 5120:
750750
model_type = "wan-vace-14B"
751751

752-
if CHECKPOINT_KEY_NAMES["wan-animate"] in checkpoint:
752+
if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
753753
model_type = "wan-animate-14B"
754754

755755
elif checkpoint[target_key].shape[0] == 1536:
@@ -3132,13 +3132,62 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
31323132

31333133

31343134
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
3135+
def generate_motion_encoder_mappings():
3136+
mappings = {
3137+
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
3138+
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
3139+
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
3140+
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
3141+
"motion_encoder.enc.fc": "motion_encoder.motion_network",
3142+
}
3143+
3144+
for i in range(7):
3145+
conv_idx = i + 1
3146+
mappings.update({
3147+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
3148+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
3149+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
3150+
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
3151+
f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
3152+
})
3153+
3154+
return mappings
3155+
3156+
def generate_face_adapter_mappings():
3157+
return {
3158+
"face_adapter.fuser_blocks": "face_adapter",
3159+
".k_norm.": ".norm_k.",
3160+
".q_norm.": ".norm_q.",
3161+
".linear1_q.": ".to_q.",
3162+
".linear2.": ".to_out.",
3163+
"conv1_local.conv": "conv1_local",
3164+
"conv2.conv": "conv2",
3165+
"conv3.conv": "conv3",
3166+
}
3167+
3168+
def split_tensor_handler(key, state_dict, split_pattern, target_keys):
3169+
tensor = state_dict.pop(key)
3170+
split_idx = tensor.shape[0] // 2
3171+
3172+
new_key_1 = key.replace(split_pattern, target_keys[0])
3173+
new_key_2 = key.replace(split_pattern, target_keys[1])
3174+
3175+
state_dict[new_key_1] = tensor[:split_idx]
3176+
state_dict[new_key_2] = tensor[split_idx:]
3177+
3178+
def reshape_bias_handler(key, state_dict):
3179+
if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
3180+
state_dict[key] = state_dict[key][0, :, 0, 0]
3181+
31353182
converted_state_dict = {}
31363183

3184+
# Strip model.diffusion_model prefix
31373185
keys = list(checkpoint.keys())
31383186
for k in keys:
31393187
if "model.diffusion_model." in k:
31403188
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
31413189

3190+
# Base transformer mappings
31423191
TRANSFORMER_KEYS_RENAME_DICT = {
31433192
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
31443193
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3160,95 +3209,42 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
31603209
"ffn.0": "ffn.net.0.proj",
31613210
"ffn.2": "ffn.net.2",
31623211
# Hack to swap the layer names
3163-
# The original model calls the norms in following order: norm1, norm3, norm2
3164-
# We convert it to: norm1, norm2, norm3
31653212
"norm2": "norm__placeholder",
31663213
"norm3": "norm2",
31673214
"norm__placeholder": "norm3",
3168-
# For the I2V model
3215+
# I2V model
31693216
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
31703217
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
31713218
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
31723219
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
3173-
# For the VACE model
3220+
# VACE model
31743221
"before_proj": "proj_in",
31753222
"after_proj": "proj_out",
3176-
# For Wan Animate
3177-
"face_adapter.fuser_blocks": "face_adapter",
3178-
".k_norm.": ".norm_k.",
3179-
".q_norm.": ".norm_q.",
3180-
# Requires tensor split
3181-
".linear1_kv.": [".to_k.", ".to_v."],
3182-
".linear1_q.": ".to_q.",
3183-
".linear2.": ".to_out.",
3184-
"conv1_local.conv": "conv1_local",
3185-
"conv2.conv": "conv2",
3186-
"conv3.conv": "conv3",
3187-
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
3188-
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
3189-
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
3190-
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
3191-
"motion_encoder.enc.fc": "motion_encoder.motion_network",
3192-
"motion_encoder.enc.net_app.convs.7.conv1.0.weight": "motion_encoder.res_blocks.6.conv1.weight",
3193-
"motion_encoder.enc.net_app.convs.6.conv1.0.weight": "motion_encoder.res_blocks.5.conv1.weight",
3194-
"motion_encoder.enc.net_app.convs.5.conv1.0.weight": "motion_encoder.res_blocks.4.conv1.weight",
3195-
"motion_encoder.enc.net_app.convs.4.conv1.0.weight": "motion_encoder.res_blocks.3.conv1.weight",
3196-
"motion_encoder.enc.net_app.convs.3.conv1.0.weight": "motion_encoder.res_blocks.2.conv1.weight",
3197-
"motion_encoder.enc.net_app.convs.2.conv1.0.weight": "motion_encoder.res_blocks.1.conv1.weight",
3198-
"motion_encoder.enc.net_app.convs.1.conv1.0.weight": "motion_encoder.res_blocks.0.conv1.weight",
3199-
"motion_encoder.enc.net_app.convs.7.conv2.1.weight": "motion_encoder.res_blocks.6.conv2.weight",
3200-
"motion_encoder.enc.net_app.convs.6.conv2.1.weight": "motion_encoder.res_blocks.5.conv2.weight",
3201-
"motion_encoder.enc.net_app.convs.5.conv2.1.weight": "motion_encoder.res_blocks.4.conv2.weight",
3202-
"motion_encoder.enc.net_app.convs.4.conv2.1.weight": "motion_encoder.res_blocks.3.conv2.weight",
3203-
"motion_encoder.enc.net_app.convs.3.conv2.1.weight": "motion_encoder.res_blocks.2.conv2.weight",
3204-
"motion_encoder.enc.net_app.convs.2.conv2.1.weight": "motion_encoder.res_blocks.1.conv2.weight",
3205-
"motion_encoder.enc.net_app.convs.1.conv2.1.weight": "motion_encoder.res_blocks.0.conv2.weight",
3206-
"motion_encoder.enc.net_app.convs.7.conv1.1.bias": "motion_encoder.res_blocks.6.conv1.act_fn.bias",
3207-
"motion_encoder.enc.net_app.convs.6.conv1.1.bias": "motion_encoder.res_blocks.5.conv1.act_fn.bias",
3208-
"motion_encoder.enc.net_app.convs.5.conv1.1.bias": "motion_encoder.res_blocks.4.conv1.act_fn.bias",
3209-
"motion_encoder.enc.net_app.convs.4.conv1.1.bias": "motion_encoder.res_blocks.3.conv1.act_fn.bias",
3210-
"motion_encoder.enc.net_app.convs.3.conv1.1.bias": "motion_encoder.res_blocks.2.conv1.act_fn.bias",
3211-
"motion_encoder.enc.net_app.convs.2.conv1.1.bias": "motion_encoder.res_blocks.1.conv1.act_fn.bias",
3212-
"motion_encoder.enc.net_app.convs.1.conv1.1.bias": "motion_encoder.res_blocks.0.conv1.act_fn.bias",
3213-
"motion_encoder.enc.net_app.convs.7.conv2.2.bias": "motion_encoder.res_blocks.6.conv2.act_fn.bias",
3214-
"motion_encoder.enc.net_app.convs.6.conv2.2.bias": "motion_encoder.res_blocks.5.conv2.act_fn.bias",
3215-
"motion_encoder.enc.net_app.convs.5.conv2.2.bias": "motion_encoder.res_blocks.4.conv2.act_fn.bias",
3216-
"motion_encoder.enc.net_app.convs.4.conv2.2.bias": "motion_encoder.res_blocks.3.conv2.act_fn.bias",
3217-
"motion_encoder.enc.net_app.convs.3.conv2.2.bias": "motion_encoder.res_blocks.2.conv2.act_fn.bias",
3218-
"motion_encoder.enc.net_app.convs.2.conv2.2.bias": "motion_encoder.res_blocks.1.conv2.act_fn.bias",
3219-
"motion_encoder.enc.net_app.convs.1.conv2.2.bias": "motion_encoder.res_blocks.0.conv2.act_fn.bias",
3220-
"motion_encoder.enc.net_app.convs.7.skip.1.weight": "motion_encoder.res_blocks.6.conv_skip.weight",
3221-
"motion_encoder.enc.net_app.convs.6.skip.1.weight": "motion_encoder.res_blocks.5.conv_skip.weight",
3222-
"motion_encoder.enc.net_app.convs.5.skip.1.weight": "motion_encoder.res_blocks.4.conv_skip.weight",
3223-
"motion_encoder.enc.net_app.convs.4.skip.1.weight": "motion_encoder.res_blocks.3.conv_skip.weight",
3224-
"motion_encoder.enc.net_app.convs.3.skip.1.weight": "motion_encoder.res_blocks.2.conv_skip.weight",
3225-
"motion_encoder.enc.net_app.convs.2.skip.1.weight": "motion_encoder.res_blocks.1.conv_skip.weight",
3226-
"motion_encoder.enc.net_app.convs.1.skip.1.weight": "motion_encoder.res_blocks.0.conv_skip.weight",
32273223
}
32283224

3225+
SPECIAL_KEYS_HANDLERS = {}
3226+
if any("face_adapter" in k for k in checkpoint.keys()):
3227+
TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
3228+
SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])
3229+
3230+
if any("motion_encoder" in k for k in checkpoint.keys()):
3231+
TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())
3232+
32293233
for key in list(checkpoint.keys()):
3230-
new_key = key[:]
3231-
extra_key = ""
3232-
index = 0
3234+
reshape_bias_handler(key, checkpoint)
3235+
3236+
for key in list(checkpoint.keys()):
3237+
new_key = key
32333238
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
3234-
if isinstance(rename_key, list):
3235-
if replace_key in new_key:
3236-
index = int(checkpoint[key].shape[0] / 2)
3237-
new_key = new_key.replace(replace_key, rename_key[0])
3238-
extra_key = new_key.replace(rename_key[0], rename_key[1])
3239-
else:
3240-
new_key = new_key.replace(replace_key, rename_key)
3241-
if extra_key != "":
3242-
converted_state_dict[new_key] = checkpoint[key][index:]
3243-
converted_state_dict[extra_key] = checkpoint[key][:index]
3244-
checkpoint.pop(key)
3245-
else:
3246-
if key == "motion_encoder.enc.net_app.convs.0.1.bias":
3247-
converted_state_dict[new_key] = checkpoint.pop(key)[0, :, 0, 0]
3248-
elif "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
3249-
converted_state_dict[new_key] = checkpoint.pop(key)[0, :, 0, 0]
3250-
else:
3251-
converted_state_dict[new_key] = checkpoint.pop(key)
3239+
new_key = new_key.replace(replace_key, rename_key)
3240+
converted_state_dict[new_key] = checkpoint.pop(key)
3241+
3242+
for key in list(converted_state_dict.keys()):
3243+
for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
3244+
if pattern not in key:
3245+
continue
3246+
handler_fn(key, converted_state_dict, pattern, target_keys)
3247+
break
32523248

32533249
return converted_state_dict
32543250

0 commit comments

Comments
 (0)