@@ -27,7 +27,6 @@
 
 import numpy as np
 import torch
-import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -53,6 +52,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -358,7 +358,12 @@ def parse_args(input_args=None):
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
-
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
 
     parser.add_argument(
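Note (editorial, not part of the patch): in PEFT, lora_alpha sets the scaling applied to the low-rank update, so decoupling it from --rank lets users strengthen or weaken the adapter without changing the number of trainable parameters. A minimal illustration with assumed shapes and values, using PEFT's default lora_alpha / r scaling (the same rule that applies to both LoraConfig changes in the hunks below):

# Illustration only; shapes and values are assumed, not taken from the script.
import torch

r, lora_alpha = 4, 8                # e.g. --rank 4 --lora_alpha 8
scaling = lora_alpha / r            # default PEFT scaling; equals 1.0 when alpha == rank (old behavior)
W = torch.randn(64, 64)             # frozen base weight
A = torch.randn(r, 64) * 0.02       # LoRA down-projection ("gaussian" init)
B = torch.zeros(64, r)              # LoRA up-projection, zero-initialized so training starts at W
W_effective = W + scaling * (B @ A)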
@@ -1238,7 +1243,7 @@ def main(args):
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1247,7 +1252,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1264,12 +1269,14 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["text_encoder"] = model
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
@@ -1280,6 +1287,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
 
     def load_model_hook(models, input_dir):
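Note (editorial, not part of the patch): _collate_lora_metadata gathers the adapter configuration (rank, lora_alpha, target modules, and so on) of each module handed to it, and its result is splatted into save_lora_weights so the metadata is stored next to the LoRA state dicts; the exact keys it emits depend on the diffusers version. A rough conceptual sketch with a hypothetical stand-in helper (not the diffusers implementation), assuming the tracked modules expose their PEFT config via a peft_config mapping:

# Hypothetical stand-in for illustration only; key names and internals are assumptions.
def collect_adapter_metadata(modules_to_save):
    metadata = {}
    for name, module in modules_to_save.items():
        if module is not None and hasattr(module, "peft_config"):
            # Persist the adapter config so a loader can rebuild the same rank/alpha scaling.
            metadata[f"{name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
    return metadata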
@@ -1889,23 +1897,27 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer
 
         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None
 
         FluxPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )
 
         # Final inference
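Usage note (editorial, not part of the patch): a LoRA saved this way can be loaded back through the standard loader. A minimal inference sketch, assuming the FLUX.1-dev base model and that --output_dir pointed to ./flux-lora:

# Assumed model id and output path; adjust to your training run.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("./flux-lora")   # directory written by --output_dir
pipe.to("cuda")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("sks_dog.png")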