
Commit 912fb8d

Attempt 2 to fix tests.
1 parent 80ac0d4 commit 912fb8d

File tree

2 files changed: +209 −13 lines


src/diffusers/loaders/lora_pipeline.py

Lines changed: 205 additions & 11 deletions
@@ -1646,7 +1646,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`].
     """
 
-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "text_encoder"]
     transformer_name = TRANSFORMER_NAME
 
     @classmethod
@@ -1747,11 +1747,13 @@ def lora_state_dict(
 
         return state_dict
 
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
     ):
         """
-        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+        `self.text_encoder`.
 
         All kwargs are forwarded to `self.lora_state_dict`.
 
@@ -1764,32 +1766,72 @@ def load_lora_weights(
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
-            kwargs (`dict`, *optional*):
-                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
         # if a dict is passed, copy it instead of modifying it inplace
         if isinstance(pretrained_model_name_or_path_or_dict, dict):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
 
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            _pipeline=self,
-        )
+        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        if len(transformer_state_dict) > 0:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=None,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=None,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+            )
 
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
@@ -1828,6 +1870,158 @@ def load_lora_into_transformer
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
+    def load_lora_into_text_encoder(
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `text_encoder`
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The key should be prefixed with an
+                additional `text_encoder` to distinguish between unet lora layers.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            text_encoder (`CLIPTextModel`):
+                The text encoder model to load the LoRA layers into.
+            prefix (`str`):
+                Expected prefix of the `text_encoder` in the `state_dict`.
+            lora_scale (`float`):
+                How much to scale the output of the lora linear layer before it is added with the output of the regular
+                lora layer.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        peft_kwargs = {}
+        if low_cpu_mem_usage:
+            if not is_peft_version(">=", "0.13.1"):
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+                )
+            if not is_transformers_version(">", "4.45.2"):
+                # Note from sayakpaul: It's not in `transformers` stable yet.
+                # https://github.com/huggingface/transformers/pull/33725/
+                raise ValueError(
+                    "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
+                )
+            peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+
+        from peft import LoraConfig
+
+        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
+        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
+        # their prefixes.
+        keys = list(state_dict.keys())
+        prefix = cls.text_encoder_name if prefix is None else prefix
+
+        # Safe prefix to check with.
+        if any(cls.text_encoder_name in key for key in keys):
+            # Load the layers corresponding to text encoder and make necessary adjustments.
+            text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
+            text_encoder_lora_state_dict = {
+                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
+            }
+
+            if len(text_encoder_lora_state_dict) > 0:
+                logger.info(f"Loading {prefix}.")
+                rank = {}
+                text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
+
+                # convert state dict
+                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                        rank_key = f"{name}.{module}.lora_B.weight"
+                        if rank_key not in text_encoder_lora_state_dict:
+                            continue
+                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                for name, _ in text_encoder_mlp_modules(text_encoder):
+                    for module in ("fc1", "fc2"):
+                        rank_key = f"{name}.{module}.lora_B.weight"
+                        if rank_key not in text_encoder_lora_state_dict:
+                            continue
+                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                if network_alphas is not None:
+                    alpha_keys = [
+                        k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
+                    ]
+                    network_alphas = {
+                        k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+                    }
+
+                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
+
+                if "use_dora" in lora_config_kwargs:
+                    if lora_config_kwargs["use_dora"]:
+                        if is_peft_version("<", "0.9.0"):
+                            raise ValueError(
+                                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<", "0.9.0"):
+                            lora_config_kwargs.pop("use_dora")
+
+                if "lora_bias" in lora_config_kwargs:
+                    if lora_config_kwargs["lora_bias"]:
+                        if is_peft_version("<=", "0.13.2"):
+                            raise ValueError(
+                                "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+                            )
+                    else:
+                        if is_peft_version("<=", "0.13.2"):
+                            lora_config_kwargs.pop("lora_bias")
+
+                lora_config = LoraConfig(**lora_config_kwargs)
+
+                # adapter_name
+                if adapter_name is None:
+                    adapter_name = get_adapter_name(text_encoder)
+
+                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+                # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
+                text_encoder.load_adapter(
+                    adapter_name=adapter_name,
+                    adapter_state_dict=text_encoder_lora_state_dict,
+                    peft_config=lora_config,
+                    **peft_kwargs,
+                )
+
+                # scale LoRA layers with `lora_scale`
+                scale_lora_layers(text_encoder, weight=lora_scale)
+
+                text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
+                # Offload back.
+                if is_model_cpu_offload:
+                    _pipeline.enable_model_cpu_offload()
+                elif is_sequential_cpu_offload:
+                    _pipeline.enable_sequential_cpu_offload()
+                # Unsafe code />
+
     @classmethod
     def save_lora_weights(
         cls,
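
For reference, a minimal usage sketch of the extended loader. The LoRA repo id and adapter name below are placeholders (not part of this commit), and fal/AuraFlow is assumed as the base AuraFlow checkpoint:

import torch
from diffusers import AuraFlowPipeline

# Base pipeline; fal/AuraFlow is assumed here as the reference AuraFlow checkpoint.
pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

# load_lora_weights() now splits the state dict by prefix: "transformer." keys are routed to
# load_lora_into_transformer and "text_encoder." keys to load_lora_into_text_encoder.
pipe.load_lora_weights("user/aura-lora", adapter_name="style")  # hypothetical LoRA repo id

image = pipe("a watercolor fox", num_inference_steps=25).images[0]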

tests/lora/test_lora_layers_af.py

Lines changed: 4 additions & 2 deletions
@@ -16,7 +16,7 @@
 import unittest
 
 import torch
-from transformers import AutoTokenizer, T5EncoderModel
+from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel
 
 from diffusers import (
     AuraFlowPipeline,
@@ -59,7 +59,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     }
     transformer_cls = AuraFlowTransformer2DModel
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
-    text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+    text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
+
+    text_encoder_target_modules = ["q", "k", "v", "o"]
 
     vae_kwargs = {
         "sample_size": 32,
