@@ -1646,7 +1646,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`AuraFlowTransformer2DModel`]. Specific to [`AuraFlowPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
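Adding `"text_encoder"` to `_lora_loadable_modules` is what lets the shared `LoraBaseMixin` helpers (fusing, unfusing, unloading, adapter switching) reach the text encoder in addition to the transformer. A hedged usage sketch of that effect, assuming an AuraFlow checkpoint and a LoRA file that actually ships text-encoder layers (both paths are placeholders):

```python
import torch
from diffusers import AuraFlowPipeline

# Placeholder identifiers; substitute a real AuraFlow checkpoint and LoRA file.
pipe = AuraFlowPipeline.from_pretrained("path/to/auraflow-checkpoint", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="style")

# With "text_encoder" listed in `_lora_loadable_modules`, component-level
# helpers can be pointed at both modules rather than just the transformer.
pipe.fuse_lora(components=["transformer", "text_encoder"], lora_scale=0.8)
pipe.unfuse_lora(components=["transformer", "text_encoder"])
```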
@@ -1747,11 +1747,13 @@ def lora_state_dict(
 
         return state_dict
 
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
         """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`.
 
         All kwargs are forwarded to `self.lora_state_dict`.
 
@@ -1764,32 +1766,72 @@ def load_lora_weights(
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            _pipeline=self,
-        )
+        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        if len(transformer_state_dict) > 0:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=None,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=None,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
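For orientation, here is a hedged sketch of how the rewritten `load_lora_weights` is called from user code; the model id, LoRA path, and prompt are placeholders, and the LoRA file is assumed to contain both `transformer.`- and `text_encoder.`-prefixed keys:

```python
import torch
from diffusers import AuraFlowPipeline

# Placeholder model id and LoRA path; substitute real ones.
pipe = AuraFlowPipeline.from_pretrained("path/to/auraflow-checkpoint", torch_dtype=torch.float16).to("cuda")

# "transformer."-prefixed keys are routed to load_lora_into_transformer,
# "text_encoder."-prefixed keys to load_lora_into_text_encoder.
pipe.load_lora_weights(
    "path/to/lora.safetensors",
    adapter_name="style",
    low_cpu_mem_usage=True,  # needs peft >= 0.13.0, as checked above
)

# Adapter management then applies across both components.
pipe.set_adapters(["style"], adapter_weights=[0.9])
image = pipe(prompt="placeholder prompt", num_inference_steps=28).images[0]
```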
@@ -1828,6 +1870,158 @@ def load_lora_into_transformer(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+    def load_lora_into_text_encoder(
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
+                additional `text_encoder` to distinguish them from the unet lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added to the output of the regular
+                lora layer.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+        from peft import LoraConfig
+
+        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+        # their prefixes.
+        keys = list(state_dict.keys())
+        prefix = cls.text_encoder_name if prefix is None else prefix
+
+        # Safe prefix to check with.
+        if any(cls.text_encoder_name in key for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+            text_encoder_lora_state_dict = {
+                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+            }
+
+            if len(text_encoder_lora_state_dict) > 0:
+                logger.info(f"Loading {prefix}.")
+                rank = {}
+                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+                # convert state dict
+                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                        rank_key = f"{name}.{module}.lora_B.weight"
+                        if rank_key not in text_encoder_lora_state_dict:
+                            continue
+                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                for name, _ in text_encoder_mlp_modules(text_encoder):
+                    for module in ("fc1", "fc2"):
+                        rank_key = f"{name}.{module}.lora_B.weight"
+                        if rank_key not in text_encoder_lora_state_dict:
+                            continue
+                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                if network_alphas is not None:
+                    alpha_keys = [
+                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                    ]
+                    network_alphas = {
+                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+                if "use_dora" in lora_config_kwargs:
+                    if lora_config_kwargs["use_dora"]:
+                        if is_peft_version("<", "0.9.0"):
+                            raise ValueError(
+                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<", "0.9.0"):
+                            lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                # adapter_name
+                if adapter_name is None:
+                    adapter_name = get_adapter_name(text_encoder)
+
+                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+                # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
+                text_encoder.load_adapter(
+                    adapter_name=adapter_name,
+                    adapter_state_dict=text_encoder_lora_state_dict,
+                    peft_config=lora_config,
+                    **peft_kwargs,
+                )
+
+                # scale LoRA layers with `lora_scale`
+                scale_lora_layers(text_encoder, weight=lora_scale)
+
+                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+                # Offload back.
+                if is_model_cpu_offload:
+                    _pipeline.enable_model_cpu_offload()
+                elif is_sequential_cpu_offload:
+                    _pipeline.enable_sequential_cpu_offload()
+                # Unsafe code />
+
     @classmethod
     def save_lora_weights(
         cls,
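To make the prefix handling above concrete, here is a small self-contained sketch with toy tensors and hypothetical key names, mirroring how `load_lora_weights` splits a combined state dict by prefix and how `load_lora_into_text_encoder` strips its prefix and reads the LoRA rank from the `lora_B` weight shape:

```python
import torch

# Toy LoRA state dict with hypothetical module names, using the
# "transformer." / "text_encoder." prefixes the loader checks for.
state_dict = {
    "transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64),
    "transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": torch.zeros(64, 4),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 32),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(32, 4),
}

# Same prefix-based routing used in load_lora_weights above.
transformer_sd = {k: v for k, v in state_dict.items() if "transformer." in k}
text_encoder_sd = {k: v for k, v in state_dict.items() if "text_encoder." in k}

# The text-encoder loader strips its prefix and infers the LoRA rank from
# lora_B's second dimension, as load_lora_into_text_encoder does above.
stripped = {k.replace("text_encoder.", ""): v for k, v in text_encoder_sd.items()}
rank = {k: v.shape[1] for k, v in stripped.items() if k.endswith("lora_B.weight")}
print(len(transformer_sd), len(text_encoder_sd), rank)
```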