
Commit 79bd7ec

Support more Wan loras (VACE) (#11726)
update
1 parent 9b834f8 commit 79bd7ec


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 51 additions & 36 deletions
@@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     converted_state_dict = {}
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
 
-    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})
+    block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
+    min_block = min(block_numbers)
+    max_block = max(block_numbers)
+
     is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
     lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
     lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
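
Note on the hunk above: instead of counting how many distinct block ids appear in the keys, the converter now recovers the actual numeric range of block indices, so LoRAs that only touch a subset of blocks, or whose numbering does not start at 0, are still iterated correctly. A minimal sketch of the idea, using made-up key names purely for illustration:

    # Hypothetical, sparse LoRA keys for illustration only.
    original_state_dict = {
        "blocks.3.self_attn.q.lora_down.weight": None,
        "blocks.3.self_attn.q.lora_up.weight": None,
        "blocks.7.cross_attn.k.lora_down.weight": None,
    }

    # Previous logic: two distinct block ids -> num_blocks == 2, so
    # range(num_blocks) would visit blocks 0 and 1, which do not exist here.
    num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict if "blocks." in k})

    # New logic: recover the real index range.
    block_numbers = {int(k.split(".")[1]) for k in original_state_dict if k.startswith("blocks.")}
    min_block, max_block = min(block_numbers), max(block_numbers)  # 3 and 7

    for i in range(min_block, max_block + 1):
        ...  # per-block conversion; each key is popped only if it is present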
@@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     # For the `diff_b` keys, we treat them as lora_bias.
     # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
 
-    for i in range(num_blocks):
+    for i in range(min_block, max_block + 1):
         # Self-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
-            )
-            if f"blocks.{i}.self_attn.{o}.diff_b" in original_state_dict:
-                converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.bias"] = original_state_dict.pop(
-                    f"blocks.{i}.self_attn.{o}.diff_b"
-                )
+            original_key = f"blocks.{i}.self_attn.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.attn1.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.self_attn.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.attn1.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.self_attn.{o}.diff_b"
+            converted_key = f"blocks.{i}.attn1.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
         # Cross-attention
         for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-            )
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-            )
-            if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
-                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.diff_b"
-                )
+            original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+            converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+            converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+            original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+            converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
         if is_i2v_lora:
             for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
-                )
-                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                    f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
-                )
-                if f"blocks.{i}.cross_attn.{o}.diff_b" in original_state_dict:
-                    converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.bias"] = original_state_dict.pop(
-                        f"blocks.{i}.cross_attn.{o}.diff_b"
-                    )
+                original_key = f"blocks.{i}.cross_attn.{o}.{lora_down_key}.weight"
+                converted_key = f"blocks.{i}.attn2.{c}.lora_A.weight"
+                if original_key in original_state_dict:
+                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+                original_key = f"blocks.{i}.cross_attn.{o}.{lora_up_key}.weight"
+                converted_key = f"blocks.{i}.attn2.{c}.lora_B.weight"
+                if original_key in original_state_dict:
+                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
+
+                original_key = f"blocks.{i}.cross_attn.{o}.diff_b"
+                converted_key = f"blocks.{i}.attn2.{c}.lora_B.bias"
+                if original_key in original_state_dict:
+                    converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
@@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
             if original_key in original_state_dict:
                 converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
-            if f"blocks.{i}.{o}.diff_b" in original_state_dict:
-                converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.bias"] = original_state_dict.pop(
-                    f"blocks.{i}.{o}.diff_b"
-                )
+            original_key = f"blocks.{i}.{o}.diff_b"
+            converted_key = f"blocks.{i}.ffn.{c}.lora_B.bias"
+            if original_key in original_state_dict:
+                converted_state_dict[converted_key] = original_state_dict.pop(original_key)
 
     # Remaining.
     if original_state_dict:
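
A rough end-to-end sanity check of the converter is sketched below. The state dict and its shapes are fabricated, and the exact output key names (including any "transformer." prefix) depend on the diffusers version, so treat this as an assumption-laden illustration rather than a test from the repository:

    import torch
    from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers

    # Fabricated keys following the patterns handled above; shapes and rank are arbitrary.
    state_dict = {
        "diffusion_model.blocks.5.self_attn.q.lora_down.weight": torch.zeros(4, 1536),
        "diffusion_model.blocks.5.self_attn.q.lora_up.weight": torch.zeros(1536, 4),
        "diffusion_model.blocks.5.ffn.0.lora_down.weight": torch.zeros(4, 1536),
        "diffusion_model.blocks.5.ffn.0.lora_up.weight": torch.zeros(8960, 4),
    }

    converted = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
    # Expect lora_A/lora_B entries for blocks.5.attn1.to_q and blocks.5.ffn.net.0.proj;
    # every input key should be consumed, so the "# Remaining." check at the end
    # of the function finds nothing left over.
    print(sorted(converted.keys()))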
