Commit 1b91856

SurAyush and sayakpaul authored
Fix examples not loading LoRA adapter weights from checkpoint (huggingface#12690)
* Fix examples not loading LoRA adapter weights from checkpoint
* Updated LoRA saving logic with Accelerate save_model_hook and load_model_hook
* Formatted the changes using ruff
* Import and upcasting changed

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 01e3555 commit 1b91856
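The change replaces the ad-hoc per-checkpoint LoRA export with Accelerate's save/load state hooks, so that `accelerator.save_state()` and `accelerator.load_state()` handle the adapter weights themselves. A minimal sketch of that hook pattern, under the assumption of a bare `Accelerator` with placeholder hook bodies (the real implementations are in the diff below):

```python
# Minimal sketch of the Accelerate pre-hook pattern the commit relies on.
# The hook bodies are placeholders; the actual logic lives in
# examples/text_to_image/train_text_to_image_lora.py (see the diff below).
from accelerate import Accelerator

accelerator = Accelerator()


def save_model_hook(models, weights, output_dir):
    # Called by accelerator.save_state(output_dir) before Accelerate writes
    # model weights, so the script can serialize the LoRA layers itself.
    ...


def load_model_hook(models, input_dir):
    # Called by accelerator.load_state(input_dir) before weights are restored,
    # so the script can load the LoRA adapter back into the PEFT-wrapped UNet.
    ...


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```

Inside the save hook, popping entries from `weights` tells Accelerate not to serialize the corresponding model a second time, which is exactly what the new `save_model_hook` in the diff does.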

File tree

1 file changed: +61 -13


examples/text_to_image/train_text_to_image_lora.py

Lines changed: 61 additions & 13 deletions
@@ -37,7 +37,7 @@
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -46,7 +46,12 @@
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def collate_fn(examples):
         num_workers=args.dataloader_num_workers,
     )
 
+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
@@ -732,6 +787,10 @@ def collate_fn(examples):
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
@@ -906,17 +965,6 @@ def collate_fn(examples):
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
 
-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
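With the hooks registered, `accelerator.save_state(save_path)` writes `pytorch_lora_weights.safetensors` into each `checkpoint-*` directory, and the script's `--resume_from_checkpoint` path (which goes through `accelerator.load_state()`) now restores the adapter instead of silently dropping it. A checkpoint saved this way can also be loaded for inference roughly as below; the base model ID, checkpoint path, and prompt are placeholder assumptions, not part of the commit:

```python
import torch

from diffusers import StableDiffusionPipeline

# Placeholder base model and checkpoint directory; substitute the model and
# --output_dir you actually trained with.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# The save hook's save_lora_weights() call puts pytorch_lora_weights.safetensors
# in the checkpoint folder, so load_lora_weights() can read it back directly.
pipe.load_lora_weights("sd-model-finetuned-lora/checkpoint-500")

image = pipe("a portrait in the training style", num_inference_steps=30).images[0]
image.save("lora_sample.png")
```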
