Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2213,6 +2213,10 @@ def convert_key(key: str) -> str:

state_dict = {convert_key(k): v for k, v in state_dict.items()}

has_default = any("default." in k for k in state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's starting with "default.", then let's be explicit about that:

Suggested change
has_default = any("default." in k for k in state_dict)
has_default = any(k.startswith("default.") for k in state_dict)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"default" isn't the prefix though, e.g.transformer_blocks.0.attn.add_v_proj.lora_A.default.weight

if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
Comment on lines +2217 to +2218
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, that it's done as intended:

Suggested change
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
if has_default:
state_dict = {k[len("default."):]: v for k, v in state_dict.items()}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as here - 'default' is not in the key's prefix, so this won't be the intended behavior in this case


converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4940,7 +4940,8 @@ def lora_state_dict(
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
has_default = any("default." in k for k in state_dict)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comment as #12581 (comment) - 'default' is not in the key's prefix, so this won't be the intended behavior in this case

if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
Expand Down
Loading