@@ -1703,7 +1703,8 @@ def lora_state_dict(
                 The subfolder location of a model file within a larger model repository on the Hub or locally.

         """
-        # Load the main state dict first which has the LoRA layers for transformer
+        # Load the main state dict first which has the LoRA layers for either the
+        # transformer or the text encoder, or both.
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -1724,7 +1725,7 @@ def lora_state_dict(
             "framework": "pytorch",
         }

-        state_dict = cls._fetch_state_dict(
+        state_dict = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1739,6 +1740,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )

+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake, please open an issue at https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict

     def load_lora_weights(
@@ -1787,7 +1794,9 @@ def load_lora_weights(

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
-    def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
+    def load_lora_into_transformer(
+        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.

@@ -1801,68 +1810,24 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
         """
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        keys = list(state_dict.keys())
-
-        transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
-        state_dict = {
-            k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
-        }
-
-        if len(state_dict.keys()) > 0:
-            # check with first key if is not in peft format
-            first_key = next(iter(state_dict.keys()))
-            if "lora_A" not in first_key:
-                state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if adapter_name in getattr(transformer, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
-                else:
-                    lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(transformer)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )

-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

     @classmethod
     def save_lora_weights(
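For context, the `dora_scale` filtering added to `lora_state_dict` is a simple key-level pass over the loaded state dict. Below is a minimal, self-contained sketch of that behavior; the key names and tensor shapes are illustrative placeholders, not taken from a real checkpoint:

```python
import torch

# Toy state dict standing in for a loaded LoRA checkpoint; only the presence or
# absence of the "dora_scale" substring in a key matters here.
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 16),
    "transformer.blocks.0.attn.to_q.lora_B.weight": torch.zeros(16, 4),
    "transformer.blocks.0.attn.to_q.lora_magnitude_vector.dora_scale": torch.zeros(16),
}

# Mirrors the patch: warn once if any DoRA scale tensors are present, then drop them.
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
    print("DoRA checkpoint detected; dropping `dora_scale` entries.")
    state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

assert not any("dora_scale" in k for k in state_dict)
```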
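Assuming this mixin is wired into a pipeline in the usual way, the new `low_cpu_mem_usage` flag would typically be exercised through `load_lora_weights`, which forwards it to `load_lora_into_transformer` and ultimately to `transformer.load_lora_adapter`; `peft>=0.13.0` is required when it is enabled. A hedged usage sketch, with placeholder model and LoRA repo ids:

```python
import torch
from diffusers import DiffusionPipeline

# Placeholder ids; substitute a pipeline whose LoRA loader mixin exposes the
# updated `low_cpu_mem_usage` argument.
pipe = DiffusionPipeline.from_pretrained("some-org/some-model", torch_dtype=torch.float16)

# `low_cpu_mem_usage=True` skips initializing the injected LoRA layers with
# random weights before the checkpoint tensors are assigned.
pipe.load_lora_weights(
    "some-user/some-lora",
    adapter_name="my_lora",
    low_cpu_mem_usage=True,
)
```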