@@ -1596,7 +1596,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
15961596 converted_state_dict = {}
15971597 original_state_dict = {k [len ("diffusion_model." ) :]: v for k , v in state_dict .items ()}
15981598
1599- num_blocks = len ({k .split ("blocks." )[1 ].split ("." )[0 ] for k in original_state_dict if "blocks." in k })
1599+ block_numbers = {int (k .split ("." )[1 ]) for k in original_state_dict if k .startswith ("blocks." )}
1600+ min_block = min (block_numbers )
1601+ max_block = max (block_numbers )
1602+
16001603 is_i2v_lora = any ("k_img" in k for k in original_state_dict ) and any ("v_img" in k for k in original_state_dict )
16011604 lora_down_key = "lora_A" if any ("lora_A" in k for k in original_state_dict ) else "lora_down"
16021605 lora_up_key = "lora_B" if any ("lora_B" in k for k in original_state_dict ) else "lora_up"
@@ -1622,45 +1625,57 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
16221625 # For the `diff_b` keys, we treat them as lora_bias.
16231626 # https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
16241627
1625- for i in range (num_blocks ):
1628+ for i in range (min_block , max_block + 1 ):
16261629 # Self-attention
16271630 for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1628- converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_A.weight" ] = original_state_dict .pop (
1629- f"blocks.{ i } .self_attn.{ o } .{ lora_down_key } .weight"
1630- )
1631- converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_B.weight" ] = original_state_dict .pop (
1632- f"blocks.{ i } .self_attn.{ o } .{ lora_up_key } .weight"
1633- )
1634- if f"blocks.{ i } .self_attn.{ o } .diff_b" in original_state_dict :
1635- converted_state_dict [f"blocks.{ i } .attn1.{ c } .lora_B.bias" ] = original_state_dict .pop (
1636- f"blocks.{ i } .self_attn.{ o } .diff_b"
1637- )
1631+ original_key = f"blocks.{ i } .self_attn.{ o } .{ lora_down_key } .weight"
1632+ converted_key = f"blocks.{ i } .attn1.{ c } .lora_A.weight"
1633+ if original_key in original_state_dict :
1634+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1635+
1636+ original_key = f"blocks.{ i } .self_attn.{ o } .{ lora_up_key } .weight"
1637+ converted_key = f"blocks.{ i } .attn1.{ c } .lora_B.weight"
1638+ if original_key in original_state_dict :
1639+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1640+
1641+ original_key = f"blocks.{ i } .self_attn.{ o } .diff_b"
1642+ converted_key = f"blocks.{ i } .attn1.{ c } .lora_B.bias"
1643+ if original_key in original_state_dict :
1644+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
16381645
16391646 # Cross-attention
16401647 for o , c in zip (["q" , "k" , "v" , "o" ], ["to_q" , "to_k" , "to_v" , "to_out.0" ]):
1641- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1642- f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1643- )
1644- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1645- f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1646- )
1647- if f"blocks.{ i } .cross_attn.{ o } .diff_b" in original_state_dict :
1648- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.bias" ] = original_state_dict .pop (
1649- f"blocks.{ i } .cross_attn.{ o } .diff_b"
1650- )
1648+ original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1649+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1650+ if original_key in original_state_dict :
1651+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1652+
1653+ original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1654+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1655+ if original_key in original_state_dict :
1656+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1657+
1658+ original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
1659+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
1660+ if original_key in original_state_dict :
1661+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
16511662
16521663 if is_i2v_lora :
16531664 for o , c in zip (["k_img" , "v_img" ], ["add_k_proj" , "add_v_proj" ]):
1654- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_A.weight" ] = original_state_dict .pop (
1655- f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1656- )
1657- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.weight" ] = original_state_dict .pop (
1658- f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1659- )
1660- if f"blocks.{ i } .cross_attn.{ o } .diff_b" in original_state_dict :
1661- converted_state_dict [f"blocks.{ i } .attn2.{ c } .lora_B.bias" ] = original_state_dict .pop (
1662- f"blocks.{ i } .cross_attn.{ o } .diff_b"
1663- )
1665+ original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_down_key } .weight"
1666+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_A.weight"
1667+ if original_key in original_state_dict :
1668+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1669+
1670+ original_key = f"blocks.{ i } .cross_attn.{ o } .{ lora_up_key } .weight"
1671+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.weight"
1672+ if original_key in original_state_dict :
1673+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
1674+
1675+ original_key = f"blocks.{ i } .cross_attn.{ o } .diff_b"
1676+ converted_key = f"blocks.{ i } .attn2.{ c } .lora_B.bias"
1677+ if original_key in original_state_dict :
1678+ converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
16641679
16651680 # FFN
16661681 for o , c in zip (["ffn.0" , "ffn.2" ], ["net.0.proj" , "net.2" ]):
@@ -1674,10 +1689,10 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
16741689 if original_key in original_state_dict :
16751690 converted_state_dict [converted_key ] = original_state_dict .pop (original_key )
16761691
1677- if f"blocks.{ i } .{ o } .diff_b" in original_state_dict :
1678- converted_state_dict [ f"blocks.{ i } .ffn.{ c } .lora_B.bias" ] = original_state_dict . pop (
1679- f"blocks. { i } . { o } .diff_b"
1680- )
1692+ original_key = f"blocks.{ i } .{ o } .diff_b"
1693+ converted_key = f"blocks.{ i } .ffn.{ c } .lora_B.bias"
1694+ if original_key in original_state_dict :
1695+ converted_state_dict [ converted_key ] = original_state_dict . pop ( original_key )
16811696
16821697 # Remaining.
16831698 if original_state_dict :