From e50a10820874e4ff4fcf90adfb1d0e14cb7fb5c8 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 17:22:20 +0530 Subject: [PATCH 01/53] Add AuraFlowLoraLoaderMixin --- src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 334 ++++++++++++++++++ .../transformers/auraflow_transformer_2d.py | 25 +- .../pipelines/aura_flow/pipeline_aura_flow.py | 3 +- 4 files changed, 360 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index b59150376599..92c701936eb6 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -64,6 +64,7 @@ def text_encoder_attn_modules(text_encoder): "AmusedLoraLoaderMixin", "StableDiffusionLoraLoaderMixin", "SD3LoraLoaderMixin", + "AuraFlowLoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "LTXVideoLoraLoaderMixin", "LoraLoaderMixin", @@ -95,6 +96,7 @@ def text_encoder_attn_modules(text_encoder): Mochi1LoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, + AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b8c44e480093..79ad45bd3b81 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,6 +1641,337 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t super().unfuse_lora(components=components) +class AuraFlowLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`AuraFlowTransformer2DModel`] + Specific to [`AuraFlowPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) + + + @classmethod + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers): + raise ValueError( + "You must pass `transformer_lora_layers`." + ) + + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + class FluxLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`FluxTransformer2DModel`], @@ -1649,6 +1980,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ +# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially +# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. +class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index b3f29e6b6224..ad03f98338b1 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import torch import torch.nn as nn @@ -32,6 +32,8 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -253,7 +255,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). @@ -451,6 +453,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -463,7 +466,19 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -538,6 +553,10 @@ def custom_forward(*inputs): shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size) ) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 8737b219c833..72a7e08f4fb6 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -24,6 +24,7 @@ from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -104,7 +105,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class AuraFlowPipeline(DiffusionPipeline): +class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): r""" Args: tokenizer (`T5TokenizerFast`): From 658d0586866e2920002d942b097ef96a9b6a70f9 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:46:20 +0530 Subject: [PATCH 02/53] Add comments, remove qkv fusion --- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 7 ++++++- .../transformers/auraflow_transformer_2d.py | 19 +++++++++---------- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 +- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 92c701936eb6..ec233ecf4a70 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -91,12 +91,12 @@ def text_encoder_attn_modules(text_encoder): AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, + AuraFlowLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, SanaLoraLoaderMixin, SD3LoraLoaderMixin, - AuraFlowLoraLoaderMixin, StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, ) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 79ad45bd3b81..1fd0bdc9c395 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1649,9 +1649,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict @validate_hf_hub_args def lora_state_dict( cls, @@ -1742,6 +1742,7 @@ def lora_state_dict( return state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1788,6 +1789,7 @@ def load_lora_weights( @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1866,6 +1868,7 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1913,6 +1916,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1956,6 +1960,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index ad03f98338b1..13fe6309ce6e 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,7 +20,8 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import is_torch_version, logging +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( Attention, @@ -32,8 +33,6 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormZero, FP32LayerNorm -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -344,8 +343,8 @@ def __init__( self.gradient_checkpointing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: @@ -453,7 +452,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -466,18 +465,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 72a7e08f4fb6..5c9c803e36e3 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -18,13 +18,13 @@ from transformers import T5Tokenizer, UMT5EncoderModel from ...image_processor import VaeImageProcessor +from ...loaders import AuraFlowLoraLoaderMixin from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from ...loaders import AuraFlowLoraLoaderMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 4208d09601afd4afdce33f23d3f6921822703f67 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:50:50 +0530 Subject: [PATCH 03/53] Add Tests --- tests/lora/test_lora_layers_af.py | 90 +++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/lora/test_lora_layers_af.py diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py new file mode 100644 index 000000000000..2b050aa74da9 --- /dev/null +++ b/tests/lora/test_lora_layers_af.py @@ -0,0 +1,90 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +from diffusers import ( + FlowMatchEulerDiscreteScheduler, + AuraFlowPipeline, +) +from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device + + +if is_peft_available(): + pass + +sys.path.append(".") + +from utils import PeftLoraLoaderMixinTests # noqa: E402 + + +@require_peft_backend +class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = AuraFlowPipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_kwargs = {} + transformer_kwargs = { + "sample_size": 64, + "patch_size": 2, + "in_channels": 4, + "num_mmdit_layers": 4, + "num_single_dit_layers": 32, + "attention_head_dim": 256, + "num_attention_heads": 12, + "joint_attention_dim": 2048, + "caption_projection_dim": 3072, + "out_channels": 4, + "pos_embed_max_size": 1024, + } + vae_kwargs = { + "sample_size": 1024, + "in_channels": 3, + "out_channels": 3, + "block_out_channels": [ + 128, + 256, + 512, + 512 + ], + "layers_per_block": 2, + "latent_channels": 4, + "norm_num_groups": 32, + "use_quant_conv": True, + "use_post_quant_conv": True, + "shift_factor": None, + "scaling_factor": 0.13025, + } + has_three_text_encoders = False + + @require_torch_gpu + def test_af_lora(self): + """ + Test loading the loras that are saved with the diffusers and peft formats. + Related PR: https://github.com/huggingface/diffusers/pull/8584 + """ + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + lora_model_id = "Warlord-K/gorkem-auraflow-lora" + + lora_filename = "pytorch_lora_weights.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + pipe.unload_lora_weights() + + lora_filename = "lora_peft_format.safetensors" + pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) From 98b19f66aa67bc193e491a85b6ebb4c59587db8c Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Tue, 30 Jul 2024 18:53:49 +0530 Subject: [PATCH 04/53] Add AuraFlowLoraLoaderMixin to documentation --- docs/source/en/api/loaders/lora.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 5dde55ada562..3c1ab7e5b6eb 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi - [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux). - [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox). - [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi). +- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow). - [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`]. - [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more. @@ -41,6 +42,7 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin +<<<<<<< HEAD ## FluxLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin @@ -52,6 +54,11 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## Mochi1LoraLoaderMixin [[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin +======= +## AuraFlowLoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin +>>>>>>> a45d48a56 (Add AuraFlowLoraLoaderMixin to documentation) ## AmusedLoraLoaderMixin From 71f8bace8b616fdbeada27d0e71c0522ad52aa61 Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 01:44:56 +0530 Subject: [PATCH 05/53] Add Suggested changes --- src/diffusers/loaders/lora_pipeline.py | 471 +++++++++++++++++- src/diffusers/loaders/peft.py | 1 + .../transformers/auraflow_transformer_2d.py | 2 +- tests/lora/test_lora_layers_af.py | 65 +-- 4 files changed, 496 insertions(+), 43 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1fd0bdc9c395..4d13134e47b1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,6 +22,7 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, @@ -1747,7 +1748,8 @@ def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. @@ -1787,6 +1789,473 @@ def load_lora_weights( _pipeline=self, ) + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=None, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + transformer (`SD3Transformer2DModel`): + The Transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if len(state_dict.keys()) > 0: + # check with first key if is not in peft format + first_key = next(iter(state_dict.keys())) + if "lora_A" not in first_key: + state_dict = convert_unet_state_dict_to_peft(state_dict) + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + + # Save the model + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer + def fuse_lora( + self, + components: List[str] = ["transformer", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + + + This is an experimental API. + + + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + ) + + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + + + This is an experimental API. + + + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components) + +class AuraFlowLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`AuraFlowTransformer2DModel`] + Specific to [`AuraFlowPipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + + """ + # Load the main state dict first which has the LoRA layers for transformer + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + state_dict = cls._fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + return state_dict + + def load_lora_weights( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index a791a250af08..7a8332bb289a 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -55,6 +55,7 @@ "MochiTransformer3DModel": lambda model_cls, weights: weights, "LTXVideoTransformer3DModel": lambda model_cls, weights: weights, "SanaTransformer2DModel": lambda model_cls, weights: weights, + "AuraFlowTransformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 13fe6309ce6e..5c8e6720ecc2 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -254,7 +254,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 2b050aa74da9..9615249633cd 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,6 +15,8 @@ import sys import unittest +from transformers import AutoTokenizer, T5EncoderModel + from diffusers import ( FlowMatchEulerDiscreteScheduler, AuraFlowPipeline, @@ -31,60 +33,41 @@ @require_peft_backend -class AFLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): +class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 64, - "patch_size": 2, + "patch_size": 1, "in_channels": 4, - "num_mmdit_layers": 4, - "num_single_dit_layers": 32, - "attention_head_dim": 256, - "num_attention_heads": 12, - "joint_attention_dim": 2048, - "caption_projection_dim": 3072, + "num_mmdit_layers": 1, + "num_single_dit_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 1024, + "pos_embed_max_size": 32, } + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + vae_kwargs = { - "sample_size": 1024, + "sample_size": 32, "in_channels": 3, "out_channels": 3, - "block_out_channels": [ - 128, - 256, - 512, - 512 - ], - "layers_per_block": 2, + "block_out_channels": (4,), + "layers_per_block": 1, "latent_channels": 4, - "norm_num_groups": 32, - "use_quant_conv": True, - "use_post_quant_conv": True, + "norm_num_groups": 1, + "use_quant_conv": False, + "use_post_quant_conv": False, "shift_factor": None, "scaling_factor": 0.13025, } - has_three_text_encoders = False - - @require_torch_gpu - def test_af_lora(self): - """ - Test loading the loras that are saved with the diffusers and peft formats. - Related PR: https://github.com/huggingface/diffusers/pull/8584 - """ - components = self.get_dummy_components() - - pipe = self.pipeline_class(**components) - pipe = pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - lora_model_id = "Warlord-K/gorkem-auraflow-lora" - - lora_filename = "pytorch_lora_weights.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) - pipe.unload_lora_weights() - lora_filename = "lora_peft_format.safetensors" - pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + @property + def output_shape(self): + return (1, 64, 64, 3) \ No newline at end of file From 0eee03eefea320d35b7876eefb6df05994d4290d Mon Sep 17 00:00:00 2001 From: Warlord-K Date: Mon, 12 Aug 2024 11:06:37 +0530 Subject: [PATCH 06/53] Change attention_kwargs->joint_attention_kwargs --- src/diffusers/loaders/lora_pipeline.py | 2 +- .../models/transformers/auraflow_transformer_2d.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4d13134e47b1..3ed3ffc71b5c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2125,6 +2125,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -2429,7 +2430,6 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 5c8e6720ecc2..90f4f3fd5d8e 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -452,7 +452,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - attention_kwargs: Optional[Dict[str, Any]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -465,18 +465,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): From 4e4f780e5466c67c1286fed049055490a7dd2cf6 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:17:39 +0100 Subject: [PATCH 07/53] Rebasing derp. --- src/diffusers/loaders/lora_pipeline.py | 476 +------------------------ 1 file changed, 1 insertion(+), 475 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3ed3ffc71b5c..3377ccf10f95 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,479 +1641,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t """ super().unfuse_lora(components=components) - -class AuraFlowLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into [`AuraFlowTransformer2DModel`] - Specific to [`AuraFlowPipeline`]. - """ - - _lora_loadable_modules = ["transformer"] - transformer_name = TRANSFORMER_NAME - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - - - We support loading A1111 formatted LoRA checkpoints in a limited capacity. - - This function is experimental and might change in the future. - - - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - - """ - # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - use_safetensors = kwargs.pop("use_safetensors", None) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - - state_dict = cls._fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - - return state_dict - - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_weights - def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs - ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - ) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=None, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): - """ - This will load the LoRA layers specified in `state_dict` into `transformer`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - transformer (`SD3Transformer2DModel`): - The Transformer model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - See `LoRALinearLayer` for more details. - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. - keys = list(state_dict.keys()) - prefix = cls.text_encoder_name if prefix is None else prefix - - # Safe prefix to check with. - if any(cls.text_encoder_name in key for key in keys): - # Load the layers corresponding to text encoder and make necessary adjustments. - text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [ - k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix - ] - network_alphas = { - k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys - } - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - """ - state_dict = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") - - if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer - def fuse_lora( - self, - components: List[str] = ["transformer", "text_encoder"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - - - This is an experimental API. - - - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names - ) - - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - - - This is an experimental API. - - - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - """ - super().unfuse_lora(components=components) - class AuraFlowLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`AuraFlowTransformer2DModel`] @@ -2338,7 +1865,6 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # Unsafe code /> @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -2386,7 +1912,6 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -2446,6 +1971,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components) + class FluxLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`FluxTransformer2DModel`], From c07d1f51858cf9114b0880eff27a9becfd6d0fa8 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:26:26 +0000 Subject: [PATCH 08/53] fix --- src/diffusers/loaders/lora_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 3377ccf10f95..1171ea5c76d3 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1979,10 +1979,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ - -# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially -# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. -class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME From 1b7f99fb841445fa3c4e5ad91f25f29267aa8c6b Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:27:08 +0000 Subject: [PATCH 09/53] fix --- src/diffusers/loaders/lora_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1171ea5c76d3..376596fe7cb2 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1979,6 +1979,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): Specific to [`StableDiffusion3Pipeline`]. """ + _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME From 875a3e0f908fb286014fea6908c582fed488e820 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:27:28 +0100 Subject: [PATCH 10/53] Quality fixes. --- src/diffusers/loaders/lora_pipeline.py | 2 +- tests/lora/test_lora_layers_af.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 376596fe7cb2..b08ed96e99eb 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1652,7 +1652,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 9615249633cd..b61e33f7eba5 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -18,10 +18,10 @@ from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( - FlowMatchEulerDiscreteScheduler, AuraFlowPipeline, + FlowMatchEulerDiscreteScheduler, ) -from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device +from diffusers.utils.testing_utils import is_peft_available, require_peft_backend if is_peft_available(): @@ -70,4 +70,4 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @property def output_shape(self): - return (1, 64, 64, 3) \ No newline at end of file + return (1, 64, 64, 3) From a242d7ae6915d0de32c42437f22820f502f99cfd Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 13 Dec 2024 16:29:05 +0000 Subject: [PATCH 11/53] make style --- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 8 +++----- .../models/transformers/auraflow_transformer_2d.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index ec233ecf4a70..70612cdea292 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -89,9 +89,9 @@ def text_encoder_attn_modules(text_encoder): from .ip_adapter import IPAdapterMixin from .lora_pipeline import ( AmusedLoraLoaderMixin, + AuraFlowLoraLoaderMixin, CogVideoXLoraLoaderMixin, FluxLoraLoaderMixin, - AuraFlowLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Mochi1LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b08ed96e99eb..a0ae31575327 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1641,10 +1641,10 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "t """ super().unfuse_lora(components=components) + class AuraFlowLoraLoaderMixin(LoraBaseMixin): r""" - Load LoRA layers into [`AuraFlowTransformer2DModel`] - Specific to [`AuraFlowPipeline`]. + Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ _lora_loadable_modules = ["transformer"] @@ -1896,9 +1896,7 @@ def save_lora_weights( state_dict = {} if not (transformer_lora_layers): - raise ValueError( - "You must pass `transformer_lora_layers`." - ) + raise ValueError("You must pass `transformer_lora_layers`.") state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 90f4f3fd5d8e..2fa615a78bb4 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,7 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( From a73df6b94929988f45229efe9c2691d43de1d4ed Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:33:32 +0100 Subject: [PATCH 12/53] `make fix-copies` --- src/diffusers/loaders/lora_pipeline.py | 91 ++++++------------- .../transformers/auraflow_transformer_2d.py | 2 +- 2 files changed, 29 insertions(+), 64 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a0ae31575327..e570dc349a4c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1703,7 +1703,8 @@ def lora_state_dict( The subfolder location of a model file within a larger model repository on the Hub or locally. """ - # Load the main state dict first which has the LoRA layers for transformer + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1724,7 +1725,7 @@ def lora_state_dict( "framework": "pytorch", } - state_dict = cls._fetch_state_dict( + state_dict = _fetch_state_dict( pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, weight_name=weight_name, use_safetensors=use_safetensors, @@ -1739,6 +1740,12 @@ def lora_state_dict( allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + return state_dict def load_lora_weights( @@ -1787,7 +1794,9 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1801,68 +1810,24 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. """ - from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - - keys = list(state_dict.keys()) - - transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] - state_dict = { - k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys - } - - if len(state_dict.keys()) > 0: - # check with first key if is not in peft format - first_key = next(iter(state_dict.keys())) - if "lora_A" not in first_key: - state_dict = convert_unet_state_dict_to_peft(state_dict) - - if adapter_name in getattr(transformer, "peft_config", {}): - raise ValueError( - f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." - ) - - rank = {} - for key, val in state_dict.items(): - if "lora_B" in key: - rank[key] = val.shape[1] - - lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict) - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - lora_config_kwargs.pop("use_dora") - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(transformer) - - # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks - # otherwise loading LoRA weights will lead to an error - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) - if unexpected_keys: - logger.warning( - f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " - f" {unexpected_keys}. " - ) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 2fa615a78bb4..1f93a1fc0d87 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -343,8 +343,8 @@ def __init__( self.gradient_checkpointing = False - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: From 894eac0aa81f27274902e86b9d4f0122cebd3fe3 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 13 Dec 2024 17:39:08 +0100 Subject: [PATCH 13/53] `ruff check --fix` --- src/diffusers/loaders/lora_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e570dc349a4c..64eb72dd860b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -22,7 +22,6 @@ USE_PEFT_BACKEND, convert_state_dict_to_diffusers, convert_state_dict_to_peft, - convert_unet_state_dict_to_peft, deprecate, get_adapter_name, get_peft_kwargs, From 2b364161929981b71349119f891ec5ef8845c16d Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 05:08:54 +0100 Subject: [PATCH 14/53] Attept 1 to fix tests. --- tests/lora/test_lora_layers_af.py | 34 +++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index b61e33f7eba5..1b2835e96427 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -15,13 +15,19 @@ import sys import unittest +import torch from transformers import AutoTokenizer, T5EncoderModel from diffusers import ( AuraFlowPipeline, + AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler, ) -from diffusers.utils.testing_utils import is_peft_available, require_peft_backend +from diffusers.utils.testing_utils import ( + floats_tensor, + is_peft_available, + require_peft_backend, +) if is_peft_available(): @@ -49,8 +55,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "joint_attention_dim": 32, "caption_projection_dim": 32, "out_channels": 4, - "pos_embed_max_size": 32, + "pos_embed_max_size": 64, } + transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" @@ -71,3 +78,26 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): @property def output_shape(self): return (1, 64, 64, 3) + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 10 + num_channels = 4 + sizes = (32, 32) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 4, + "guidance_scale": 0.0, + "height": 8, + "width": 8, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs From 6b762b800c155dc8cb29822c06973c72d8f7e4d2 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 05:57:02 +0100 Subject: [PATCH 15/53] Attept 2 to fix tests. --- src/diffusers/loaders/lora_pipeline.py | 199 +++++++++++++++++++++++-- tests/lora/test_lora_layers_af.py | 6 +- 2 files changed, 193 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 64eb72dd860b..9d2773293780 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1646,7 +1646,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ - _lora_loadable_modules = ["transformer"] + _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME @classmethod @@ -1747,48 +1747,75 @@ def lora_state_dict( return state_dict + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state - dict is loaded into `self.transformer`. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) + is_correct_format = all("lora" in key for key in state_dict.keys()) if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( + self.load_lora_into_unet( state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, + lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -1828,6 +1855,158 @@ def load_lora_into_transformer( low_cpu_mem_usage=low_cpu_mem_usage, ) + @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = cls.text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. + if any(cls.text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [ + k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix + ] + network_alphas = { + k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + @classmethod def save_lora_weights( cls, diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 1b2835e96427..62364380eb83 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import AutoTokenizer, T5EncoderModel +from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel from diffusers import ( AuraFlowPipeline, @@ -59,7 +59,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + + text_encoder_target_modules = ["q", "k", "v", "o"] vae_kwargs = { "sample_size": 32, From bc2a4663fd13e304b48b177b8d364186e3e9d527 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 15 Dec 2024 06:45:30 +0100 Subject: [PATCH 16/53] Attept 3 to fix tests. --- src/diffusers/loaders/lora_pipeline.py | 35 ++++++++++---- .../pipelines/aura_flow/pipeline_aura_flow.py | 46 ++++++++++++++++++- tests/lora/test_lora_layers_af.py | 13 +++++- 3 files changed, 81 insertions(+), 13 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9d2773293780..25efcbbc7964 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1651,7 +1651,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1700,10 +1700,11 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. """ # Load the main state dict first which has the LoRA layers for either of - # transformer and text encoder or both. + # UNet and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1712,6 +1713,7 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False @@ -1738,16 +1740,32 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + return state_dict, network_alphas - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): @@ -1798,10 +1816,9 @@ def load_lora_weights( if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_unet( + self.load_lora_into_transformer( state_dict, - network_alphas=network_alphas, - unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 5c9c803e36e3..cfb002ea7764 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from transformers import T5Tokenizer, UMT5EncoderModel @@ -22,7 +22,14 @@ from ...models import AuraFlowTransformer2DModel, AutoencoderKL from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging, replace_example_docstring +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -125,6 +132,8 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" + transformer_name = "transformer" + text_encoder_name = "text_encoder" def __init__( self, @@ -215,6 +224,7 @@ def encode_prompt( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, max_sequence_length: int = 256, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -241,10 +251,21 @@ def encode_prompt( negative_prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for negative text embeddings. max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ if device is None: device = self._execution_device + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -402,6 +423,7 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -457,6 +479,10 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: @@ -479,6 +505,8 @@ def __call__( negative_prompt_attention_mask, ) + self._joint_attention_kwargs = joint_attention_kwargs + # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -488,6 +516,9 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -511,6 +542,7 @@ def __call__( prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, max_sequence_length=max_sequence_length, + lora_scale=lora_scale, ) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -551,6 +583,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, + joint_attention_kwargs=self.joint_attention_kwargs, )[0] # perform guidance @@ -579,7 +612,16 @@ def __call__( # Offload all models self.maybe_free_model_hooks() + if self.text_encoder is not None: + if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + if not return_dict: return (image,) return ImagePipelineOutput(images=image) + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_af.py index 62364380eb83..247e8f83aab2 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_af.py @@ -16,7 +16,7 @@ import unittest import torch -from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, UMT5EncoderModel +from transformers import AutoTokenizer, UMT5EncoderModel from diffusers import ( AuraFlowPipeline, @@ -41,7 +41,7 @@ @require_peft_backend class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler() + scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_kwargs = {} uses_flow_matching = True transformer_kwargs = { @@ -60,6 +60,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + attention_kwargs_name = "joint_attention_kwargs" text_encoder_target_modules = ["q", "k", "v", "o"] @@ -103,3 +104,11 @@ def get_dummy_inputs(self, with_generator=True): pipeline_inputs.update({"generator": generator}) return noise, input_ids, pipeline_inputs + + @unittest.skip("Not supported in AuraFlow.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Not supported in AuraFlow.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass From 1c7909565a8525b4891e01e2ad7feb6fc301e915 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:23:21 +0100 Subject: [PATCH 17/53] Address review comments. --- src/diffusers/loaders/lora_pipeline.py | 46 ++++++++----------- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 - 2 files changed, 19 insertions(+), 29 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 25efcbbc7964..8632e9c8c106 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1648,10 +1648,11 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME + text_encoder_name = TEXT_ENCODER_NAME @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1700,11 +1701,10 @@ def lora_state_dict( allowed by Git. subfolder (`str`, *optional*, defaults to `""`): The subfolder location of a model file within a larger model repository on the Hub or locally. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. + """ # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. + # transformer and text encoder or both. cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) @@ -1713,7 +1713,6 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) allow_pickle = False @@ -1740,30 +1739,14 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - network_alphas = None - # TODO: replace it with a method from `state_dict_utils` - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. - if unet_config is not None: - # use unet config to remap block numbers - state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - - return state_dict, network_alphas + return state_dict # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer def load_lora_weights( @@ -2025,10 +2008,12 @@ def load_lora_into_text_encoder( # Unsafe code /> @classmethod + # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], - transformer_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -2042,6 +2027,9 @@ def save_lora_weights( Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -2055,10 +2043,14 @@ def save_lora_weights( """ state_dict = {} - if not (transformer_lora_layers): - raise ValueError("You must pass `transformer_lora_layers`.") + if not (transformer_lora_layers or text_encoder_lora_layers): + raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if transformer_lora_layers: + state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + + if text_encoder_lora_layers: + state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) # Save the model cls.write_lora_layers( diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index cfb002ea7764..49c89227a193 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -132,8 +132,6 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin): _optional_components = [] model_cpu_offload_seq = "text_encoder->transformer->vae" - transformer_name = "transformer" - text_encoder_name = "text_encoder" def __init__( self, From 9454e845e6aeb828b9c4560946d1b34d8ec761d7 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 19 Dec 2024 10:25:49 +0100 Subject: [PATCH 18/53] Rebasing derp. --- docs/source/en/api/loaders/lora.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index 3c1ab7e5b6eb..34a5416b1ccc 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -42,7 +42,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse [[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin -<<<<<<< HEAD ## FluxLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin @@ -54,11 +53,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse ## Mochi1LoraLoaderMixin [[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin -======= ## AuraFlowLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin ->>>>>>> a45d48a56 (Add AuraFlowLoraLoaderMixin to documentation) ## AmusedLoraLoaderMixin From 28a4918a470669af58ed3bc7b4a41520cc2564d9 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 7 Jan 2025 03:56:56 +0100 Subject: [PATCH 19/53] Get more tests passing by copying from Flux. Address review comments. --- src/diffusers/loaders/lora_pipeline.py | 359 +++++++++++++++--- .../transformers/auraflow_transformer_2d.py | 12 +- .../pipelines/aura_flow/pipeline_aura_flow.py | 2 +- ...ers_af.py => test_lora_layers_auraflow.py} | 1 - 4 files changed, 320 insertions(+), 54 deletions(-) rename tests/lora/{test_lora_layers_af.py => test_lora_layers_auraflow.py} (98%) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 793dcdbee833..63b459bc3b3e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1649,13 +1649,15 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME + _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + return_alphas: bool = False, **kwargs, ): r""" @@ -1739,21 +1741,57 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - return state_dict + # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. + is_kohya = any(".lora_down.weight" in k for k in state_dict) + if is_kohya: + state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) + # Kohya already takes care of scaling the LoRA parameters with alpha. + return (state_dict, None) if return_alphas else state_dict + + is_xlabs = any("processor" in k for k in state_dict) + if is_xlabs: + state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) + # xlabs doesn't use `alpha`. + return (state_dict, None) if return_alphas else state_dict + + is_bfl_control = any("query_norm.scale" in k for k in state_dict) + if is_bfl_control: + state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) + return (state_dict, None) if return_alphas else state_dict - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_weights with unet->transformer + # For state dicts like + # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA + keys = list(state_dict.keys()) + network_alphas = {} + for k in keys: + if "alpha" in k: + alpha_value = state_dict.get(k) + if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( + alpha_value, float + ): + network_alphas[k] = state_dict.pop(k) + else: + raise ValueError( + f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." + ) + + if return_alphas: + return state_dict, network_alphas + else: + return state_dict + + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. @@ -1761,23 +1799,20 @@ def load_lora_weights( See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is - loaded into `self.unet`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state - dict is loaded into `self.text_encoder`. + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state + dict is loaded into `self.transformer`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1793,35 +1828,241 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs + ) - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: + has_lora_keys = any("lora" in key for key in state_dict.keys()) + + # Flux Control LoRAs also have norm keys + has_norm_keys = any( + norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys + ) + + if not (has_lora_keys or has_norm_keys): raise ValueError("Invalid LoRA checkpoint.") - self.load_lora_into_transformer( - state_dict, - transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, + transformer_lora_state_dict = { + k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k + } + transformer_norm_state_dict = { + k: state_dict.pop(k) + for k in list(state_dict.keys()) + if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) + } + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( + transformer, transformer_lora_state_dict, transformer_norm_state_dict ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name) - if not hasattr(self, "text_encoder") - else self.text_encoder, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, + + if has_param_with_expanded_shape: + logger.info( + "The LoRA weights contain parameters that have different shapes that expected by the transformer. " + "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " + "To get a comprehensive list of parameter names that were modified, enable debug logging." + ) + transformer_lora_state_dict = self._maybe_expand_lora_state_dict( + transformer=transformer, lora_state_dict=transformer_lora_state_dict ) + if len(transformer_lora_state_dict) > 0: + self.load_lora_into_transformer( + transformer_lora_state_dict, + network_alphas=network_alphas, + transformer=transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + + if len(transformer_norm_state_dict) > 0: + transformer._transformer_norm_layers = self._load_norm_into_transformer( + transformer_norm_state_dict, + transformer=transformer, + discard_original_layers=False, + ) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + ) + @classmethod - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_ + def _maybe_expand_transformer_param_shape_or_error_( + cls, + transformer: torch.nn.Module, + lora_state_dict=None, + norm_state_dict=None, + prefix=None, + ) -> bool: + """ + Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and + generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. + """ + state_dict = {} + if lora_state_dict is not None: + state_dict.update(lora_state_dict) + if norm_state_dict is not None: + state_dict.update(norm_state_dict) + + # Remove prefix if present + prefix = prefix or cls.transformer_name + for key in list(state_dict.keys()): + if key.split(".")[0] == prefix: + state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) + + # Expand transformer parameter shapes if they don't match lora + has_param_with_shape_update = False + overwritten_params = {} + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + for name, module in transformer.named_modules(): + if isinstance(module, torch.nn.Linear): + module_weight = module.weight.data + module_bias = module.bias.data if module.bias is not None else None + bias = module_bias is not None + + lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name + lora_A_weight_name = f"{lora_base_name}.lora_A.weight" + lora_B_weight_name = f"{lora_base_name}.lora_B.weight" + if lora_A_weight_name not in state_dict: + continue + + in_features = state_dict[lora_A_weight_name].shape[1] + out_features = state_dict[lora_B_weight_name].shape[0] + + # This means there's no need for an expansion in the params, so we simply skip. + if tuple(module_weight.shape) == (out_features, in_features): + continue + + module_out_features, module_in_features = module_weight.shape + debug_message = "" + if in_features > module_in_features: + debug_message += ( + f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' + f"checkpoint contains higher number of features than expected. The number of input_features will be " + f"expanded from {module_in_features} to {in_features}" + ) + if out_features > module_out_features: + debug_message += ( + ", and the number of output features will be " + f"expanded from {module_out_features} to {out_features}." + ) + else: + debug_message += "." + if debug_message: + logger.debug(debug_message) + + if out_features > module_out_features or in_features > module_in_features: + has_param_with_shape_update = True + parent_module_name, _, current_module_name = name.rpartition(".") + parent_module = transformer.get_submodule(parent_module_name) + + with torch.device("meta"): + expanded_module = torch.nn.Linear( + in_features, out_features, bias=bias, dtype=module_weight.dtype + ) + # Only weights are expanded and biases are not. This is because only the input dimensions + # are changed while the output dimensions remain the same. The shape of the weight tensor + # is (out_features, in_features), while the shape of bias tensor is (out_features,), which + # explains the reason why only weights are expanded. + new_weight = torch.zeros_like( + expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype + ) + slices = tuple(slice(0, dim) for dim in module_weight.shape) + new_weight[slices] = module_weight + tmp_state_dict = {"weight": new_weight} + if module_bias is not None: + tmp_state_dict["bias"] = module_bias + expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) + + setattr(parent_module, current_module_name, expanded_module) + + del tmp_state_dict + + if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: + attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] + new_value = int(expanded_module.weight.data.shape[1]) + old_value = getattr(transformer.config, attribute_name) + setattr(transformer.config, attribute_name, new_value) + logger.info( + f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." + ) + + # For `unload_lora_weights()`. + # TODO: this could lead to more memory overhead if the number of overwritten params + # are large. Should be revisited later and tackled through a `discard_original_layers` arg. + overwritten_params[f"{current_module_name}.weight"] = module_weight + if module_bias is not None: + overwritten_params[f"{current_module_name}.bias"] = module_bias + + if len(overwritten_params) > 0: + transformer._overwritten_params = overwritten_params + + return has_param_with_shape_update + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._maybe_expand_lora_state_dict + def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): + expanded_module_names = set() + transformer_state_dict = transformer.state_dict() + prefix = f"{cls.transformer_name}." + + lora_module_names = [ + key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") + ] + lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] + lora_module_names = sorted(set(lora_module_names)) + transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) + unexpected_modules = set(lora_module_names) - set(transformer_module_names) + if unexpected_modules: + logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") + + is_peft_loaded = getattr(transformer, "peft_config", None) is not None + for k in lora_module_names: + if k in unexpected_modules: + continue + + base_param_name = ( + f"{k.replace(prefix, '')}.base_layer.weight" + if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict + else f"{k.replace(prefix, '')}.weight" + ) + base_weight_param = transformer_state_dict[base_param_name] + lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] + + if base_weight_param.shape[1] > lora_A_param.shape[1]: + shape = (lora_A_param.shape[0], base_weight_param.shape[1]) + expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) + expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) + lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight + expanded_module_names.add(k) + elif base_weight_param.shape[1] < lora_A_param.shape[1]: + raise NotImplementedError( + f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." + ) + + if expanded_module_names: + logger.info( + f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." + ) + + return lora_state_dict + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1831,7 +2072,11 @@ def load_lora_into_transformer( A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - transformer (`SD3Transformer2DModel`): + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + transformer (`AuraFlowTransformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use @@ -1840,20 +2085,23 @@ def load_lora_into_transformer( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) # Load the layers corresponding to transformer. - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=None, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) + keys = list(state_dict.keys()) + transformer_present = any(key.startswith(cls.transformer_name) for key in keys) + if transformer_present: + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder @@ -2062,9 +2310,10 @@ def save_lora_weights( safe_serialization=safe_serialization, ) + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.fuse_lora def fuse_lora( self, - components: List[str] = ["transformer"], + components: List[str] = ["transformer", "text_encoder"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -2101,11 +2350,25 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ + + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if ( + hasattr(transformer, "_transformer_norm_layers") + and isinstance(transformer._transformer_norm_layers, dict) + and len(transformer._transformer_norm_layers.keys()) > 0 + ): + logger.info( + "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " + "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " + "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." + ) + super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). @@ -2119,6 +2382,10 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ + transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer + if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: + transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) + super().unfuse_lora(components=components) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 1f93a1fc0d87..5f69fcfaafd2 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -452,7 +452,7 @@ def forward( hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor = None, timestep: torch.LongTensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: height, width = hidden_states.shape[-2:] @@ -465,18 +465,18 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if joint_attention_kwargs is not None: - joint_attention_kwargs = joint_attention_kwargs.copy() - lora_scale = joint_attention_kwargs.pop("scale", 1.0) + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index c8ab8da2f3ea..a3e5b2840b4b 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -588,7 +588,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, - joint_attention_kwargs=self.joint_attention_kwargs, + attention_kwargs=self.joint_attention_kwargs, )[0] # perform guidance diff --git a/tests/lora/test_lora_layers_af.py b/tests/lora/test_lora_layers_auraflow.py similarity index 98% rename from tests/lora/test_lora_layers_af.py rename to tests/lora/test_lora_layers_auraflow.py index 247e8f83aab2..c41239d8611e 100644 --- a/tests/lora/test_lora_layers_af.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -60,7 +60,6 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): transformer_cls = AuraFlowTransformer2DModel tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" - attention_kwargs_name = "joint_attention_kwargs" text_encoder_target_modules = ["q", "k", "v", "o"] From d6028cdc53d6d30c14022f8045f87c7f59c466c1 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 7 Jan 2025 04:09:35 +0100 Subject: [PATCH 20/53] `joint_attention_kwargs`->`attention_kwargs` --- .../pipelines/aura_flow/pipeline_aura_flow.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index a3e5b2840b4b..4cb0e8ff19db 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -428,7 +428,7 @@ def __call__( max_sequence_length: int = 256, output_type: Optional[str] = "pil", return_dict: bool = True, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[ImagePipelineOutput, Tuple]: r""" Function invoked when calling the pipeline for generation. @@ -484,7 +484,7 @@ def __call__( Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. - joint_attention_kwargs (`dict`, *optional*): + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -510,7 +510,7 @@ def __call__( negative_prompt_attention_mask, ) - self._joint_attention_kwargs = joint_attention_kwargs + self._attention_kwargs = attention_kwargs # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): @@ -522,7 +522,7 @@ def __call__( device = self._execution_device lora_scale = ( - self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None ) # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -588,7 +588,7 @@ def __call__( encoder_hidden_states=prompt_embeds, timestep=timestep, return_dict=False, - attention_kwargs=self.joint_attention_kwargs, + attention_kwargs=self.attention_kwargs, )[0] # perform guidance @@ -631,5 +631,5 @@ def __call__( return ImagePipelineOutput(images=image) @property - def joint_attention_kwargs(self): - return self._joint_attention_kwargs + def attention_kwargs(self): + return self._attention_kwargs From 2d02c2c8d2301efe80ae4fde2475a219c81eabc0 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 7 Jan 2025 04:19:18 +0100 Subject: [PATCH 21/53] Add `lora_scale` property for te LoRAs. --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 4cb0e8ff19db..e873bbf95364 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -633,3 +633,7 @@ def __call__( @property def attention_kwargs(self): return self._attention_kwargs + + @property + def lora_scale(self): + return self._lora_scale From 2b934b458a25560776c7955562b6350d5e97f84f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 7 Jan 2025 05:16:41 +0100 Subject: [PATCH 22/53] Make test better. --- tests/lora/test_lora_layers_auraflow.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index c41239d8611e..33d046e4f207 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -41,8 +41,9 @@ @require_peft_backend class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_cls = FlowMatchEulerDiscreteScheduler() scheduler_kwargs = {} + scheduler_classes = [FlowMatchEulerDiscreteScheduler] uses_flow_matching = True transformer_kwargs = { "sample_size": 64, @@ -54,15 +55,9 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "num_attention_heads": 2, "joint_attention_dim": 32, "caption_projection_dim": 32, - "out_channels": 4, "pos_embed_max_size": 64, } transformer_cls = AuraFlowTransformer2DModel - tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" - text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" - - text_encoder_target_modules = ["q", "k", "v", "o"] - vae_kwargs = { "sample_size": 32, "in_channels": 3, @@ -73,13 +68,16 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): "norm_num_groups": 1, "use_quant_conv": False, "use_post_quant_conv": False, - "shift_factor": None, - "scaling_factor": 0.13025, + "shift_factor": 0.0609, + "scaling_factor": 1.5035, } + tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" + text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + text_encoder_target_modules = ["q", "k", "v", "o"] @property def output_shape(self): - return (1, 64, 64, 3) + return (1, 8, 8, 3) def get_dummy_inputs(self, with_generator=True): batch_size = 1 From 532013f990cafd646d340f58ba599b9ace91d241 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 7 Jan 2025 05:56:36 +0100 Subject: [PATCH 23/53] Remove useless property. --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index e873bbf95364..4cb0e8ff19db 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -633,7 +633,3 @@ def __call__( @property def attention_kwargs(self): return self._attention_kwargs - - @property - def lora_scale(self): - return self._lora_scale From e06d8eb94fd92439543a5f7556c24d4c47efdd18 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:38:37 +0100 Subject: [PATCH 24/53] Skip TE-only tests for AuraFlow. --- src/diffusers/loaders/lora_pipeline.py | 2 +- tests/lora/utils.py | 33 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index d8e506445c56..48a36bba7f03 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1647,7 +1647,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ - _lora_loadable_modules = ["transformer", "text_encoder"] + _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] diff --git a/tests/lora/utils.py b/tests/lora/utils.py index a22f86ad6b89..bbb5f411cc10 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import inspect import os import tempfile @@ -78,6 +79,16 @@ def initialize_dummy_state_dict(state_dict): POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] +def require_te_lora_support(f): + @functools.wraps(f) + def wrapped(self: "PeftLoraLoaderMixinTests", *a, **kw): + if not self.supports_text_encoder_lora: + self.skipTest("Pipeline class doesn't support text encoder LoRA.") + return f(self, *a, **kw) + + return wrapped + + @require_peft_backend class PeftLoraLoaderMixinTests: @@ -273,6 +284,7 @@ def test_simple_inference(self): output_no_lora = pipe(**inputs)[0] self.assertTrue(output_no_lora.shape == self.output_shape) + @require_te_lora_support def test_simple_inference_with_text_lora(self): """ Tests a simple inference with lora attached on the text encoder @@ -434,6 +446,7 @@ def test_low_cpu_mem_usage_with_loading(self): "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", ) + @require_te_lora_support def test_simple_inference_with_text_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + scale argument @@ -490,6 +503,7 @@ def test_simple_inference_with_text_lora_and_scale(self): "Lora + 0 scale should lead to same result as no LoRA", ) + @require_te_lora_support def test_simple_inference_with_text_lora_fused(self): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model @@ -530,6 +544,7 @@ def test_simple_inference_with_text_lora_fused(self): np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) + @require_te_lora_support def test_simple_inference_with_text_lora_unloaded(self): """ Tests a simple inference with lora attached to text encoder, then unloads the lora weights @@ -578,6 +593,7 @@ def test_simple_inference_with_text_lora_unloaded(self): "Fused lora should change the output", ) + @require_te_lora_support def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. @@ -629,6 +645,7 @@ def test_simple_inference_with_text_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) + @require_te_lora_support def test_simple_inference_with_partial_text_lora(self): """ Tests a simple inference with lora attached on the text encoder @@ -797,6 +814,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument @@ -863,6 +881,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): "The scaling parameter has not been correctly restored!", ) + @require_te_lora_support def test_simple_inference_with_text_lora_denoiser_fused(self): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model @@ -916,6 +935,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unloaded(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights @@ -968,6 +988,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): "Fused lora should change the output", ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unfused( self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): @@ -1023,6 +1044,7 @@ def test_simple_inference_with_text_denoiser_lora_unfused( "Fused lora should not change the output", ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1134,6 +1156,7 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_adapters("adapter-1") _ = pipe(**inputs, generator=torch.manual_seed(0))[0] + @require_te_lora_support def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1191,6 +1214,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self): "output with no lora and output with lora disabled should give same results", ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1265,6 +1289,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): with self.assertRaises(ValueError): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) + @require_te_lora_support def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" @@ -1354,6 +1379,7 @@ def all_possible_dict_opts(unet, value): pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error + @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1448,6 +1474,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): "output with no lora and output with lora disabled should give same results", ) + @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1674,6 +1701,7 @@ def test_get_list_adapters(self): self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) @require_peft_version_greater(peft_version="0.6.2") + @require_te_lora_support def test_simple_inference_with_text_lora_denoiser_fused_multi( self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): @@ -1852,6 +1880,7 @@ def test_unexpected_keys_warning(self): self.assertTrue(".diffusers_cat" in cap_logger.out) @unittest.skip("This is failing for now - need to investigate") + @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights @@ -2098,3 +2127,7 @@ def test_correct_lora_configs_with_different_ranks(self): lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + @property + def supports_text_encoder_lora(self): + return len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(self.pipeline_class._lora_loadable_modules)) != 0 From 2b359094b650f199ab3d56c71293b11dc9f50bbf Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 10 Jan 2025 08:02:32 +0100 Subject: [PATCH 25/53] Support LoRA for non-CLIP TEs. --- src/diffusers/loaders/lora_pipeline.py | 136 ++++++++----------------- src/diffusers/models/lora.py | 4 +- src/diffusers/utils/__init__.py | 1 + 3 files changed, 48 insertions(+), 93 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 48a36bba7f03..d08380bd07e5 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -17,9 +17,11 @@ import torch from huggingface_hub.utils import validate_hf_hub_args +from torch import nn from ..utils import ( USE_PEFT_BACKEND, + StateDictType, convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, @@ -388,21 +390,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ @@ -931,21 +925,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ @@ -1440,21 +1426,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ @@ -1647,7 +1625,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ - _lora_loadable_modules = ["transformer"] + _lora_loadable_modules = ["transformer", "text_encoder"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] @@ -2181,21 +2159,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ @@ -2820,21 +2790,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ @@ -3385,21 +3347,13 @@ def load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) + + for name, module in text_encoder.named_modules(): + if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): + rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" + if rank_key in text_encoder_lora_state_dict: + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 4e9e0c07ca75..3b54303584bf 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -38,7 +38,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def text_encoder_attn_modules(text_encoder): +def text_encoder_attn_modules(text_encoder: nn.Module): attn_modules = [] if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): @@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder): return attn_modules -def text_encoder_mlp_modules(text_encoder): +def text_encoder_mlp_modules(text_encoder: nn.Module): mlp_modules = [] if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f8de48ecfc78..ec2ffcec6358 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -116,6 +116,7 @@ ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .state_dict_utils import ( + StateDictType, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, From 7e63330ef132397227b0ffc44838055f0082ee8c Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 14:51:16 +0100 Subject: [PATCH 26/53] Restore LoRA tests. --- tests/lora/utils.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index bbb5f411cc10..e852983a65af 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -79,16 +79,6 @@ def initialize_dummy_state_dict(state_dict): POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"] -def require_te_lora_support(f): - @functools.wraps(f) - def wrapped(self: "PeftLoraLoaderMixinTests", *a, **kw): - if not self.supports_text_encoder_lora: - self.skipTest("Pipeline class doesn't support text encoder LoRA.") - return f(self, *a, **kw) - - return wrapped - - @require_peft_backend class PeftLoraLoaderMixinTests: @@ -284,7 +274,6 @@ def test_simple_inference(self): output_no_lora = pipe(**inputs)[0] self.assertTrue(output_no_lora.shape == self.output_shape) - @require_te_lora_support def test_simple_inference_with_text_lora(self): """ Tests a simple inference with lora attached on the text encoder @@ -446,7 +435,6 @@ def test_low_cpu_mem_usage_with_loading(self): "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", ) - @require_te_lora_support def test_simple_inference_with_text_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + scale argument @@ -503,7 +491,6 @@ def test_simple_inference_with_text_lora_and_scale(self): "Lora + 0 scale should lead to same result as no LoRA", ) - @require_te_lora_support def test_simple_inference_with_text_lora_fused(self): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model @@ -544,7 +531,6 @@ def test_simple_inference_with_text_lora_fused(self): np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) - @require_te_lora_support def test_simple_inference_with_text_lora_unloaded(self): """ Tests a simple inference with lora attached to text encoder, then unloads the lora weights @@ -593,7 +579,6 @@ def test_simple_inference_with_text_lora_unloaded(self): "Fused lora should change the output", ) - @require_te_lora_support def test_simple_inference_with_text_lora_save_load(self): """ Tests a simple usecase where users could use saving utilities for LoRA. @@ -645,7 +630,6 @@ def test_simple_inference_with_text_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) - @require_te_lora_support def test_simple_inference_with_partial_text_lora(self): """ Tests a simple inference with lora attached on the text encoder @@ -814,7 +798,6 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self): "Loading from saved checkpoints should give same results.", ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + Unet + scale argument @@ -881,7 +864,6 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self): "The scaling parameter has not been correctly restored!", ) - @require_te_lora_support def test_simple_inference_with_text_lora_denoiser_fused(self): """ Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model @@ -935,7 +917,6 @@ def test_simple_inference_with_text_lora_denoiser_fused(self): np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output" ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unloaded(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights @@ -988,7 +969,6 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self): "Fused lora should change the output", ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unfused( self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): @@ -1044,7 +1024,6 @@ def test_simple_inference_with_text_denoiser_lora_unfused( "Fused lora should not change the output", ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1156,7 +1135,6 @@ def test_wrong_adapter_name_raises_error(self): pipe.set_adapters("adapter-1") _ = pipe(**inputs, generator=torch.manual_seed(0))[0] - @require_te_lora_support def test_simple_inference_with_text_denoiser_block_scale(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1214,7 +1192,6 @@ def test_simple_inference_with_text_denoiser_block_scale(self): "output with no lora and output with lora disabled should give same results", ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1289,7 +1266,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): with self.assertRaises(ValueError): pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1]) - @require_te_lora_support def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): """Tests that any valid combination of lora block scales can be used in pipe.set_adapter""" @@ -1379,7 +1355,6 @@ def all_possible_dict_opts(unet, value): pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error - @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1474,7 +1449,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self): "output with no lora and output with lora disabled should give same results", ) - @require_te_lora_support def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): """ Tests a simple inference with lora attached to text encoder and unet, attaches @@ -1701,7 +1675,6 @@ def test_get_list_adapters(self): self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) @require_peft_version_greater(peft_version="0.6.2") - @require_te_lora_support def test_simple_inference_with_text_lora_denoiser_fused_multi( self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 ): @@ -1880,7 +1853,6 @@ def test_unexpected_keys_warning(self): self.assertTrue(".diffusers_cat" in cap_logger.out) @unittest.skip("This is failing for now - need to investigate") - @require_te_lora_support def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self): """ Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights From 5620384df363f0967c4dbc72925ac62160524a54 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 14:55:14 +0100 Subject: [PATCH 27/53] Undo adding LoRA support for non-CLIP TEs. --- src/diffusers/loaders/lora_base.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 7ae8570a8893..0c584777affc 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -27,7 +27,6 @@ from ..models.modeling_utils import ModelMixin, load_state_dict from ..utils import ( USE_PEFT_BACKEND, - StateDictType, _get_model_file, convert_state_dict_to_diffusers, convert_state_dict_to_peft, @@ -51,6 +50,7 @@ if is_transformers_available(): from transformers import PreTrainedModel + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer @@ -356,13 +356,21 @@ def _load_lora_into_text_encoder( text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict, original_type=StateDictType.DIFFUSERS) - - for name, module in text_encoder.named_modules(): - if "lora_A" not in name and "lora_B" not in name and isinstance(module, (nn.Linear, nn.Conv2d)): - rank_key = f"{name.removesuffix(".base_layer")}.lora_B.weight" - if rank_key in text_encoder_lora_state_dict: - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] From cd691d3762131c548974d61c90889d49a9d1dc89 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 14:56:00 +0100 Subject: [PATCH 28/53] Undo support for TE in AuraFlow LoRA. --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 70db7c84a65a..c23312153fe7 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1340,7 +1340,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`]. """ - _lora_loadable_modules = ["transformer", "text_encoder"] + _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME text_encoder_name = TEXT_ENCODER_NAME _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] From 0fa5cd5d1406400a6b031db7f5d790de2e59c3e5 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 14:57:44 +0100 Subject: [PATCH 29/53] `make fix-copies` --- src/diffusers/loaders/lora_pipeline.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index c23312153fe7..f11ac10b7077 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1635,10 +1635,17 @@ def _maybe_expand_transformer_param_shape_or_error_( in_features = state_dict[lora_A_weight_name].shape[1] out_features = state_dict[lora_B_weight_name].shape[0] + # Model maybe loaded with different quantization schemes which may flatten the params. + # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models + # preserve weight shape. + module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) + # This means there's no need for an expansion in the params, so we simply skip. - if tuple(module_weight.shape) == (out_features, in_features): + if tuple(module_weight_shape) == (out_features, in_features): continue + # TODO (sayakpaul): We still need to consider if the module we're expanding is + # quantized and handle it accordingly if that is the case. module_out_features, module_in_features = module_weight.shape debug_message = "" if in_features > module_in_features: @@ -1735,13 +1742,16 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): base_weight_param = transformer_state_dict[base_param_name] lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - if base_weight_param.shape[1] > lora_A_param.shape[1]: + # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. + base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) + + if base_module_shape[1] > lora_A_param.shape[1]: shape = (lora_A_param.shape[0], base_weight_param.shape[1]) expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight expanded_module_names.add(k) - elif base_weight_param.shape[1] < lora_A_param.shape[1]: + elif base_module_shape[1] < lora_A_param.shape[1]: raise NotImplementedError( f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." ) From 83e0825a74d8cc089e277321a34a36e423168f24 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 15:09:12 +0100 Subject: [PATCH 30/53] Sync with upstream changes. --- src/diffusers/loaders/lora_pipeline.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index f11ac10b7077..406a380e096d 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1990,6 +1990,29 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * super().unfuse_lora(components=components) + @staticmethod + # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._calculate_module_shape + def _calculate_module_shape( + model: "torch.nn.Module", + base_module: "torch.nn.Linear" = None, + base_weight_param_name: str = None, + ) -> "torch.Size": + def _get_weight_shape(weight: torch.Tensor): + return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape + + if base_module is not None: + return _get_weight_shape(base_module.weight) + elif base_weight_param_name is not None: + if not base_weight_param_name.endswith(".weight"): + raise ValueError( + f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." + ) + module_path = base_weight_param_name.rsplit(".weight", 1)[0] + submodule = get_submodule_by_name(model, module_path) + return _get_weight_shape(submodule.weight) + + raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + class FluxLoraLoaderMixin(LoraBaseMixin): r""" From 12fbd118538f84c0a83dfbe40e1cda857e46ad9f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Sun, 19 Jan 2025 15:17:17 +0100 Subject: [PATCH 31/53] Remove unneeded stuff. --- src/diffusers/loaders/lora_pipeline.py | 415 ++---------------------- tests/lora/test_lora_layers_auraflow.py | 31 +- 2 files changed, 60 insertions(+), 386 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 406a380e096d..836f541601ea 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1342,16 +1342,13 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): _lora_loadable_modules = ["transformer"] transformer_name = TRANSFORMER_NAME - text_encoder_name = TEXT_ENCODER_NAME - _control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"] @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - return_alphas: bool = False, **kwargs, ): r""" @@ -1435,84 +1432,43 @@ def lora_state_dict( user_agent=user_agent, allow_pickle=allow_pickle, ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) if is_dora_scale_present: warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." logger.warning(warn_msg) state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions. - is_kohya = any(".lora_down.weight" in k for k in state_dict) - if is_kohya: - state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) - # Kohya already takes care of scaling the LoRA parameters with alpha. - return (state_dict, None) if return_alphas else state_dict - - is_xlabs = any("processor" in k for k in state_dict) - if is_xlabs: - state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) - # xlabs doesn't use `alpha`. - return (state_dict, None) if return_alphas else state_dict - - is_bfl_control = any("query_norm.scale" in k for k in state_dict) - if is_bfl_control: - state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return (state_dict, None) if return_alphas else state_dict - - # For state dicts like - # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA - keys = list(state_dict.keys()) - network_alphas = {} - for k in keys: - if "alpha" in k: - alpha_value = state_dict.get(k) - if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance( - alpha_value, float - ): - network_alphas[k] = state_dict.pop(k) - else: - raise ValueError( - f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." - ) - - if return_alphas: - return state_dict, network_alphas - else: - return state_dict + return state_dict - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_weights + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - + `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See + [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded into `self.transformer`. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. low_cpu_mem_usage (`bool`, *optional*): - `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) @@ -1522,251 +1478,24 @@ def load_lora_weights( pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs - ) - - has_lora_keys = any("lora" in key for key in state_dict.keys()) - - # Flux Control LoRAs also have norm keys - has_norm_keys = any( - norm_key in key for key in state_dict.keys() for norm_key in self._control_lora_supported_norm_keys - ) + state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - if not (has_lora_keys or has_norm_keys): + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: raise ValueError("Invalid LoRA checkpoint.") - transformer_lora_state_dict = { - k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k - } - transformer_norm_state_dict = { - k: state_dict.pop(k) - for k in list(state_dict.keys()) - if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys) - } - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_( - transformer, transformer_lora_state_dict, transformer_norm_state_dict - ) - - if has_param_with_expanded_shape: - logger.info( - "The LoRA weights contain parameters that have different shapes that expected by the transformer. " - "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. " - "To get a comprehensive list of parameter names that were modified, enable debug logging." - ) - transformer_lora_state_dict = self._maybe_expand_lora_state_dict( - transformer=transformer, lora_state_dict=transformer_lora_state_dict + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) - if len(transformer_lora_state_dict) > 0: - self.load_lora_into_transformer( - transformer_lora_state_dict, - network_alphas=network_alphas, - transformer=transformer, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - if len(transformer_norm_state_dict) > 0: - transformer._transformer_norm_layers = self._load_norm_into_transformer( - transformer_norm_state_dict, - transformer=transformer, - discard_original_layers=False, - ) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_ - def _maybe_expand_transformer_param_shape_or_error_( - cls, - transformer: torch.nn.Module, - lora_state_dict=None, - norm_state_dict=None, - prefix=None, - ) -> bool: - """ - Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and - generalizes things a bit so that any parameter that needs expansion receives appropriate treatement. - """ - state_dict = {} - if lora_state_dict is not None: - state_dict.update(lora_state_dict) - if norm_state_dict is not None: - state_dict.update(norm_state_dict) - - # Remove prefix if present - prefix = prefix or cls.transformer_name - for key in list(state_dict.keys()): - if key.split(".")[0] == prefix: - state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key) - - # Expand transformer parameter shapes if they don't match lora - has_param_with_shape_update = False - overwritten_params = {} - - is_peft_loaded = getattr(transformer, "peft_config", None) is not None - for name, module in transformer.named_modules(): - if isinstance(module, torch.nn.Linear): - module_weight = module.weight.data - module_bias = module.bias.data if module.bias is not None else None - bias = module_bias is not None - - lora_base_name = name.replace(".base_layer", "") if is_peft_loaded else name - lora_A_weight_name = f"{lora_base_name}.lora_A.weight" - lora_B_weight_name = f"{lora_base_name}.lora_B.weight" - if lora_A_weight_name not in state_dict: - continue - - in_features = state_dict[lora_A_weight_name].shape[1] - out_features = state_dict[lora_B_weight_name].shape[0] - - # Model maybe loaded with different quantization schemes which may flatten the params. - # `bitsandbytes`, for example, flatten the weights when using 4bit. 8bit bnb models - # preserve weight shape. - module_weight_shape = cls._calculate_module_shape(model=transformer, base_module=module) - - # This means there's no need for an expansion in the params, so we simply skip. - if tuple(module_weight_shape) == (out_features, in_features): - continue - - # TODO (sayakpaul): We still need to consider if the module we're expanding is - # quantized and handle it accordingly if that is the case. - module_out_features, module_in_features = module_weight.shape - debug_message = "" - if in_features > module_in_features: - debug_message += ( - f'Expanding the nn.Linear input/output features for module="{name}" because the provided LoRA ' - f"checkpoint contains higher number of features than expected. The number of input_features will be " - f"expanded from {module_in_features} to {in_features}" - ) - if out_features > module_out_features: - debug_message += ( - ", and the number of output features will be " - f"expanded from {module_out_features} to {out_features}." - ) - else: - debug_message += "." - if debug_message: - logger.debug(debug_message) - - if out_features > module_out_features or in_features > module_in_features: - has_param_with_shape_update = True - parent_module_name, _, current_module_name = name.rpartition(".") - parent_module = transformer.get_submodule(parent_module_name) - - with torch.device("meta"): - expanded_module = torch.nn.Linear( - in_features, out_features, bias=bias, dtype=module_weight.dtype - ) - # Only weights are expanded and biases are not. This is because only the input dimensions - # are changed while the output dimensions remain the same. The shape of the weight tensor - # is (out_features, in_features), while the shape of bias tensor is (out_features,), which - # explains the reason why only weights are expanded. - new_weight = torch.zeros_like( - expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype - ) - slices = tuple(slice(0, dim) for dim in module_weight.shape) - new_weight[slices] = module_weight - tmp_state_dict = {"weight": new_weight} - if module_bias is not None: - tmp_state_dict["bias"] = module_bias - expanded_module.load_state_dict(tmp_state_dict, strict=True, assign=True) - - setattr(parent_module, current_module_name, expanded_module) - - del tmp_state_dict - - if current_module_name in _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX: - attribute_name = _MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX[current_module_name] - new_value = int(expanded_module.weight.data.shape[1]) - old_value = getattr(transformer.config, attribute_name) - setattr(transformer.config, attribute_name, new_value) - logger.info( - f"Set the {attribute_name} attribute of the model to {new_value} from {old_value}." - ) - - # For `unload_lora_weights()`. - # TODO: this could lead to more memory overhead if the number of overwritten params - # are large. Should be revisited later and tackled through a `discard_original_layers` arg. - overwritten_params[f"{current_module_name}.weight"] = module_weight - if module_bias is not None: - overwritten_params[f"{current_module_name}.bias"] = module_bias - - if len(overwritten_params) > 0: - transformer._overwritten_params = overwritten_params - - return has_param_with_shape_update - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._maybe_expand_lora_state_dict - def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict): - expanded_module_names = set() - transformer_state_dict = transformer.state_dict() - prefix = f"{cls.transformer_name}." - - lora_module_names = [ - key[: -len(".lora_A.weight")] for key in lora_state_dict if key.endswith(".lora_A.weight") - ] - lora_module_names = [name[len(prefix) :] for name in lora_module_names if name.startswith(prefix)] - lora_module_names = sorted(set(lora_module_names)) - transformer_module_names = sorted({name for name, _ in transformer.named_modules()}) - unexpected_modules = set(lora_module_names) - set(transformer_module_names) - if unexpected_modules: - logger.debug(f"Found unexpected modules: {unexpected_modules}. These will be ignored.") - - is_peft_loaded = getattr(transformer, "peft_config", None) is not None - for k in lora_module_names: - if k in unexpected_modules: - continue - - base_param_name = ( - f"{k.replace(prefix, '')}.base_layer.weight" - if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict - else f"{k.replace(prefix, '')}.weight" - ) - base_weight_param = transformer_state_dict[base_param_name] - lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"] - - # TODO (sayakpaul): Handle the cases when we actually need to expand when using quantization. - base_module_shape = cls._calculate_module_shape(model=transformer, base_weight_param_name=base_param_name) - - if base_module_shape[1] > lora_A_param.shape[1]: - shape = (lora_A_param.shape[0], base_weight_param.shape[1]) - expanded_state_dict_weight = torch.zeros(shape, device=base_weight_param.device) - expanded_state_dict_weight[:, : lora_A_param.shape[1]].copy_(lora_A_param) - lora_state_dict[f"{prefix}{k}.lora_A.weight"] = expanded_state_dict_weight - expanded_module_names.add(k) - elif base_module_shape[1] < lora_A_param.shape[1]: - raise NotImplementedError( - f"This LoRA param ({k}.lora_A.weight) has an incompatible shape {lora_A_param.shape}. Please open an issue to file for a feature request - https://github.com/huggingface/diffusers/issues/new." - ) - - if expanded_module_names: - logger.info( - f"The following LoRA modules were zero padded to match the state dict of {cls.transformer_name}: {expanded_module_names}. Please open an issue if you think this was unexpected - https://github.com/huggingface/diffusers/issues/new." - ) - - return lora_state_dict - @classmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->AuraFlowTransformer2DModel + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1776,10 +1505,6 @@ def load_lora_into_transformer( A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). transformer (`AuraFlowTransformer2DModel`): The Transformer model to load the LoRA layers into. adapter_name (`str`, *optional*): @@ -1789,81 +1514,27 @@ def load_lora_into_transformer( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." ) # Load the layers corresponding to transformer. - keys = list(state_dict.keys()) - transformer_present = any(key.startswith(cls.transformer_name) for key in keys) - if transformer_present: - logger.info(f"Loading {cls.transformer_name}.") - transformer.load_lora_adapter( - state_dict, - network_alphas=network_alphas, - adapter_name=adapter_name, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights with unet->transformer def save_lora_weights( cls, save_directory: Union[str, os.PathLike], transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -1877,9 +1548,6 @@ def save_lora_weights( Directory to save LoRA parameters to. Will be created if it doesn't exist. transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `transformer`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful during distributed training and you need to call this function on all processes. In this case, set `is_main_process=True` only on the main @@ -1893,15 +1561,12 @@ def save_lora_weights( """ state_dict = {} - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + if not transformer_lora_layers: + raise ValueError("You must pass `transformer_lora_layers`.") if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) - if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) - # Save the model cls.write_lora_layers( state_dict=state_dict, @@ -1912,7 +1577,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1952,24 +1617,11 @@ def fuse_lora( pipeline.fuse_lora(lora_scale=0.7) ``` """ - - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if ( - hasattr(transformer, "_transformer_norm_layers") - and isinstance(transformer._transformer_norm_layers, dict) - and len(transformer._transformer_norm_layers.keys()) > 0 - ): - logger.info( - "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer " - "as opposed to the LoRA layers that will co-exist separately until the 'fuse_lora()' method is called. That is to say, the normalization layers will always be directly " - "fused into the transformer and can only be unfused if `discard_original_layers=True` is passed." - ) - super().fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of @@ -1983,11 +1635,8 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer - if hasattr(transformer, "_transformer_norm_layers") and transformer._transformer_norm_layers: - transformer.load_state_dict(transformer._transformer_norm_layers, strict=False) - super().unfuse_lora(components=components) @staticmethod diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 33d046e4f207..0f30759b718c 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -41,9 +41,10 @@ @require_peft_backend class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = AuraFlowPipeline - scheduler_cls = FlowMatchEulerDiscreteScheduler() - scheduler_kwargs = {} + scheduler_cls = FlowMatchEulerDiscreteScheduler scheduler_classes = [FlowMatchEulerDiscreteScheduler] + scheduler_kwargs = {} + uses_flow_matching = True transformer_kwargs = { "sample_size": 64, @@ -103,9 +104,33 @@ def get_dummy_inputs(self, with_generator=True): return noise, input_ids, pipeline_inputs @unittest.skip("Not supported in AuraFlow.") - def test_modify_padding_mode(self): + def test_simple_inference_with_text_denoiser_block_scale(self): pass @unittest.skip("Not supported in AuraFlow.") def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): pass + + @unittest.skip("Not supported in AuraFlow.") + def test_modify_padding_mode(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + def test_simple_inference_with_partial_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + def test_simple_inference_with_text_lora(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + def test_simple_inference_with_text_lora_and_scale(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + def test_simple_inference_with_text_lora_fused(self): + pass + + @unittest.skip("Text encoder LoRA is not supported in AuraFlow.") + def test_simple_inference_with_text_lora_save_load(self): + pass From cdd184d44f6a0e2faef795265fda251591fb987c Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 26 Feb 2025 10:41:34 +0100 Subject: [PATCH 32/53] Mirror `Lumina2`. --- src/diffusers/loaders/lora_pipeline.py | 31 +++---------------- .../transformers/auraflow_transformer_2d.py | 3 +- .../pipelines/aura_flow/pipeline_aura_flow.py | 6 +++- 3 files changed, 11 insertions(+), 29 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index dba659836bb7..96528add7871 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1345,7 +1345,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): @classmethod @validate_hf_hub_args - # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict def lora_state_dict( cls, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1530,7 +1530,7 @@ def load_lora_into_transformer( ) @classmethod - # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights with unet->transformer + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights def save_lora_weights( cls, save_directory: Union[str, os.PathLike], @@ -1577,7 +1577,7 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin.fuse_lora + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora def fuse_lora( self, components: List[str] = ["transformer"], @@ -1621,7 +1621,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.Mochi1LoraLoaderMixin.unfuse_lora + # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): r""" Reverses the effect of @@ -1639,29 +1639,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * """ super().unfuse_lora(components=components) - @staticmethod - # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin._calculate_module_shape - def _calculate_module_shape( - model: "torch.nn.Module", - base_module: "torch.nn.Linear" = None, - base_weight_param_name: str = None, - ) -> "torch.Size": - def _get_weight_shape(weight: torch.Tensor): - return weight.quant_state.shape if weight.__class__.__name__ == "Params4bit" else weight.shape - - if base_module is not None: - return _get_weight_shape(base_module.weight) - elif base_weight_param_name is not None: - if not base_weight_param_name.endswith(".weight"): - raise ValueError( - f"Invalid `base_weight_param_name` passed as it does not end with '.weight' {base_weight_param_name=}." - ) - module_path = base_weight_param_name.rsplit(".weight", 1)[0] - submodule = get_submodule_by_name(model, module_path) - return _get_weight_shape(submodule.weight) - - raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") - class FluxLoraLoaderMixin(LoraBaseMixin): r""" diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 8d873fe4bae6..775d1c800bbd 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -254,7 +254,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). @@ -467,6 +467,7 @@ def forward( lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 + if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 8103cc22ed33..580683d11b0b 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -507,6 +507,10 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -544,8 +548,8 @@ def __call__( callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) - self._attention_kwargs = attention_kwargs self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs # 2. Determine batch size. if prompt is not None and isinstance(prompt, str): From ce1939b3740fb1afcfd43428a08fef17f1105112 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 26 Feb 2025 10:59:07 +0100 Subject: [PATCH 33/53] Skip for MPS. --- tests/lora/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 05d04996ca28..55e32443c653 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2194,11 +2194,13 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_fp32 = initialize_pipeline(storage_dtype=None) pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) - pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] + # MPS doesn't support float8 yet. + if torch_device not in {"mps"}: + pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) - pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] + pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") def test_layerwise_casting_peft_input_autocast_denoiser(self): From 3b9e65543a080ab4eb978fe7c9c7f4d0198b1139 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 26 Feb 2025 11:42:01 +0100 Subject: [PATCH 34/53] Address review comments. --- .../transformers/auraflow_transformer_2d.py | 15 ++++++++++++ .../pipelines/aura_flow/pipeline_aura_flow.py | 23 +++++++++---------- tests/lora/utils.py | 1 - 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 775d1c800bbd..9a8ce781386a 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -452,6 +452,21 @@ def forward( attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + height, width = hidden_states.shape[-2:] # Apply patch embedding, timestep embedding, and project the caption embeddings. diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 580683d11b0b..5ca319dd611a 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -271,9 +271,6 @@ def encode_prompt( lora_scale (`float`, *optional*): A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - if device is None: - device = self._execution_device - # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin): @@ -283,6 +280,8 @@ def encode_prompt( if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) + if device is None: + device = self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -366,6 +365,11 @@ def encode_prompt( negative_prompt_embeds = None negative_prompt_attention_mask = None + if self.text_encoder is not None: + if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents @@ -423,6 +427,10 @@ def upcast_vae(self): def guidance_scale(self): return self._guidance_scale + @property + def attention_kwargs(self): + return self._attention_kwargs + @property def num_timesteps(self): return self._num_timesteps @@ -669,16 +677,7 @@ def __call__( # Offload all models self.maybe_free_model_hooks() - if self.text_encoder is not None: - if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - if not return_dict: return (image,) return ImagePipelineOutput(images=image) - - @property - def attention_kwargs(self): - return self._attention_kwargs diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 55e32443c653..38eeed1a99b6 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools import inspect import os import re From c11b14d7296d135081b6dc3f819a8186ee7e01dc Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 27 Feb 2025 15:28:07 +0100 Subject: [PATCH 35/53] Remove duplicated code. --- .../models/transformers/auraflow_transformer_2d.py | 13 ------------- src/diffusers/utils/__init__.py | 1 - 2 files changed, 14 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 9a8ce781386a..e3e11bdcf641 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -477,20 +477,7 @@ def forward( encoder_hidden_states = torch.cat( [self.register_tokens.repeat(encoder_hidden_states.size(0), 1, 1), encoder_hidden_states], dim=1 ) - if attention_kwargs is not None: - attention_kwargs = attention_kwargs.copy() - lora_scale = attention_kwargs.pop("scale", 1.0) - else: - lora_scale = 1.0 - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: - logger.warning( - "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." - ) # MMDiT blocks. for index_block, block in enumerate(self.joint_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b641d79fcfaf..08b1713d0e31 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -117,7 +117,6 @@ ) from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil from .state_dict_utils import ( - StateDictType, convert_all_state_dict_to_peft, convert_state_dict_to_diffusers, convert_state_dict_to_kohya, From 636f01cdf3d770a865be1c64d3d210e1dd301c4c Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 27 Feb 2025 15:43:26 +0100 Subject: [PATCH 36/53] Remove unnecessary code. --- src/diffusers/models/transformers/auraflow_transformer_2d.py | 2 +- tests/lora/test_lora_layers_auraflow.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index e3e11bdcf641..35f5d243e2f7 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( Attention, diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 0f30759b718c..94e158c2642b 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -45,7 +45,6 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): scheduler_classes = [FlowMatchEulerDiscreteScheduler] scheduler_kwargs = {} - uses_flow_matching = True transformer_kwargs = { "sample_size": 64, "patch_size": 1, @@ -74,7 +73,6 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" - text_encoder_target_modules = ["q", "k", "v", "o"] @property def output_shape(self): From 75ba7dacd740a03e931d13db3f6c49a7b342fc2f Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:13:57 +0100 Subject: [PATCH 37/53] Remove repeated docs. --- src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 5ca319dd611a..dff3f4a4a031 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -515,7 +515,7 @@ def __call__( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a plain tuple. - attention_kwargs: + attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). @@ -529,10 +529,6 @@ def __call__( will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. - attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). Examples: From c2daa8a9bd6b43bcea2d7e7d198577a2e56d92fb Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 5 Mar 2025 10:14:39 +0100 Subject: [PATCH 38/53] Propagate attention. --- .../models/transformers/auraflow_transformer_2d.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 35f5d243e2f7..065d1fa5e0cd 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -160,14 +160,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.ff = AuraFlowFeedForward(dim, dim * 4) - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor): + def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None): residual = hidden_states + attention_kwargs = attention_kwargs or {} # Norm + Projection. norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) # Attention. - attn_output = self.attn(hidden_states=norm_hidden_states) + attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs) # Process attention outputs for the `hidden_states`. hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output) @@ -223,10 +224,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): self.ff_context = AuraFlowFeedForward(dim, dim * 4) def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor + self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None, ): residual = hidden_states residual_context = encoder_hidden_states + attention_kwargs = attention_kwargs or {} # Norm + Projection. norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) @@ -236,7 +238,7 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states + hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs, ) # Process attention outputs for the `hidden_states`. @@ -490,7 +492,7 @@ def forward( else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb + hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs, ) # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) @@ -507,7 +509,7 @@ def forward( ) else: - combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb) + combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs) hidden_states = combined_hidden_states[:, encoder_seq_len:] From 8aa2d6909da65ab69440f5b2ee211c4188de0802 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:02:44 +0100 Subject: [PATCH 39/53] Fix TE target modules. --- tests/lora/test_lora_layers_auraflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 94e158c2642b..72d70f7b2120 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -73,6 +73,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): } tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" + text_encoder_target_modules = ["q", "k", "v", "o"] @property def output_shape(self): From b19942f106768d4b6eea1a3679fc17a5874f5dc5 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:07:15 +0100 Subject: [PATCH 40/53] MPS fix for LoRA tests. --- tests/lora/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 38eeed1a99b6..03485c812950 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2224,7 +2224,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): apply_layerwise_casting, ) - storage_dtype = torch.float8_e4m3fn + storage_dtype = torch.float8_e4m3fn if not torch_device == "mps" else torch.bfloat16 compute_dtype = torch.float32 def check_module(denoiser): From 50917570222ca151790fa9191d7d8d0afa9a2b11 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 6 Mar 2025 07:15:19 +0100 Subject: [PATCH 41/53] Unrelated TE LoRA tests fix. --- tests/lora/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 03485c812950..ccab78a6e3c4 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -641,9 +641,9 @@ def test_simple_inference_with_partial_text_lora(self): # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). text_lora_config = LoraConfig( r=4, - rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3}, + rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)}, lora_alpha=4, - target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + target_modules=self.text_encoder_target_modules, init_lora_weights=False, use_dora=False, ) From dee9074b9e5a361314e393341bdb827dd6c79619 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Wed, 26 Mar 2025 03:32:07 +0100 Subject: [PATCH 42/53] Fix AuraFlow LoRA tests by applying to the right denoiser layers. Co-authored-by: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com> --- tests/lora/test_lora_layers_auraflow.py | 1 + tests/lora/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py index 72d70f7b2120..ac1fed608cc8 100644 --- a/tests/lora/test_lora_layers_auraflow.py +++ b/tests/lora/test_lora_layers_auraflow.py @@ -74,6 +74,7 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5" text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5" text_encoder_target_modules = ["q", "k", "v", "o"] + denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"] @property def output_shape(self): diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ccab78a6e3c4..6f74c8abafe1 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -104,6 +104,7 @@ class PeftLoraLoaderMixinTests: vae_kwargs = None text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] + denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] def get_dummy_components(self, scheduler_cls=None, use_dora=False): if self.unet_kwargs and self.transformer_kwargs: @@ -157,7 +158,7 @@ def get_dummy_components(self, scheduler_cls=None, use_dora=False): denoiser_lora_config = LoraConfig( r=rank, lora_alpha=rank, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], + target_modules=self.denoiser_target_modules, init_lora_weights=False, use_dora=use_dora, ) @@ -2040,7 +2041,7 @@ def test_lora_B_bias(self): bias_values = {} denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer for name, module in denoiser.named_modules(): - if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]): + if any(k in name for k in self.denoiser_target_modules): if module.bias is not None: bias_values[name] = module.bias.data.clone() From 65a3bf5f81f2f24073ab90ba2ef254df3662c767 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 8 Apr 2025 04:46:02 +0000 Subject: [PATCH 43/53] Apply style fixes --- .../transformers/auraflow_transformer_2d.py | 26 +++++++++++++++---- .../pipelines/aura_flow/pipeline_aura_flow.py | 4 +-- tests/lora/utils.py | 13 ++++++++-- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index 065d1fa5e0cd..8781424c61d3 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -160,7 +160,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False) self.ff = AuraFlowFeedForward(dim, dim * 4) - def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + attention_kwargs: Optional[Dict[str, Any]] = None, + ): residual = hidden_states attention_kwargs = attention_kwargs or {} @@ -224,7 +229,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim): self.ff_context = AuraFlowFeedForward(dim, dim * 4) def forward( - self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None, + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + temb: torch.FloatTensor, + attention_kwargs: Optional[Dict[str, Any]] = None, ): residual = hidden_states residual_context = encoder_hidden_states @@ -238,7 +247,9 @@ def forward( # Attention. attn_output, context_attn_output = self.attn( - hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs, + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + **attention_kwargs, ) # Process attention outputs for the `hidden_states`. @@ -492,7 +503,10 @@ def forward( else: encoder_hidden_states, hidden_states = block( - hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + attention_kwargs=attention_kwargs, ) # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text) @@ -509,7 +523,9 @@ def forward( ) else: - combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs) + combined_hidden_states = block( + hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs + ) hidden_states = combined_hidden_states[:, encoder_seq_len:] diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index dff3f4a4a031..7c98b3b71c48 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -564,9 +564,7 @@ def __call__( batch_size = prompt_embeds.shape[0] device = self._execution_device - lora_scale = ( - self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None - ) + lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 1fd528397357..76a6aa581590 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2190,7 +2190,14 @@ def test_correct_lora_configs_with_different_ranks(self): @property def supports_text_encoder_lora(self): - return len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(self.pipeline_class._lora_loadable_modules)) != 0 + return ( + len( + {"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection( + self.pipeline_class._lora_loadable_modules + ) + ) + != 0 + ) def test_layerwise_casting_inference_denoiser(self): from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS @@ -2249,7 +2256,9 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) + pipe_float8_e4m3_bf16 = initialize_pipeline( + storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16 + ) pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") From 147a3564323dca494d5ac038ad8af34b1dbb0a62 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 10:46:09 +0530 Subject: [PATCH 44/53] empty commit From 0c91c1a185f675f8ec5535f76e30eff04c363179 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 07:21:27 +0200 Subject: [PATCH 45/53] Fix the repo consistency issues. --- src/diffusers/loaders/lora_pipeline.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1b02dc771fd5..99f00014f9e6 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1611,7 +1611,11 @@ def fuse_lora( ``` """ super().fuse_lora( - components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora @@ -1630,7 +1634,7 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. """ - super().unfuse_lora(components=components) + super().unfuse_lora(components=components, **kwargs) class FluxLoraLoaderMixin(LoraBaseMixin): From e97a83ea7ed53b1ed27834bcbe390bbbc6ab28b3 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 07:33:54 +0200 Subject: [PATCH 46/53] Remove unrelated changes. --- tests/lora/utils.py | 27 +++++++-------------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 76a6aa581590..b865feecf89c 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2188,17 +2188,6 @@ def test_correct_lora_configs_with_different_ranks(self): self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) - @property - def supports_text_encoder_lora(self): - return ( - len( - {"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection( - self.pipeline_class._lora_loadable_modules - ) - ) - != 0 - ) - def test_layerwise_casting_inference_denoiser(self): from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS @@ -2251,15 +2240,13 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_fp32 = initialize_pipeline(storage_dtype=None) pipe_fp32(**inputs, generator=torch.manual_seed(0))[0] - # MPS doesn't support float8 yet. - if torch_device not in {"mps"}: - pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) - pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] + pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) + pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_bf16 = initialize_pipeline( - storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16 - ) - pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] + pipe_float8_e4m3_bf16 = initialize_pipeline( + storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16 + ) + pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") def test_layerwise_casting_peft_input_autocast_denoiser(self): @@ -2284,7 +2271,7 @@ def test_layerwise_casting_peft_input_autocast_denoiser(self): apply_layerwise_casting, ) - storage_dtype = torch.float8_e4m3fn if not torch_device == "mps" else torch.bfloat16 + storage_dtype = torch.float8_e4m3fn compute_dtype = torch.float32 def check_module(denoiser): From a5b78d166f246789f343d638191b0cc5da322402 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 07:36:28 +0200 Subject: [PATCH 47/53] Style. --- tests/lora/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index b865feecf89c..3d9f9043890d 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2243,9 +2243,7 @@ def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32): pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32) pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0] - pipe_float8_e4m3_bf16 = initialize_pipeline( - storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16 - ) + pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16) pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0] @require_peft_version_greater("0.14.0") From dbc84273f1d2a4124ea61334eb4eb8122b9943e4 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 08:37:54 +0200 Subject: [PATCH 48/53] Fix `test_lora_fuse_nan`. --- tests/lora/utils.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 3d9f9043890d..ae9f47c82e2e 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1563,7 +1563,6 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): "output with no lora and output with lora disabled should give same results", ) - @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", @@ -1595,17 +1594,26 @@ def test_lora_fuse_nan(self): ].weight += float("inf") else: named_modules = [name for name, _ in pipe.transformer.named_modules()] - tower_name = ( - "transformer_blocks" - if any(name == "transformer_blocks" for name in named_modules) - else "blocks" - ) - transformer_tower = getattr(pipe.transformer, tower_name) - has_attn1 = any("attn1" in name for name in named_modules) - if has_attn1: - transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") - else: - transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") + possible_tower_names = [ + "transformer_blocks", + "blocks", + "joint_transformer_blocks", + "single_transformer_blocks", + ] + filtered_tower_names = [ + tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) + ] + if len(filtered_tower_names) == 0: + pytest.xfail( + reason=f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." + ) + for tower_name in filtered_tower_names: + transformer_tower = getattr(pipe.transformer, tower_name) + has_attn1 = any("attn1" in name for name in named_modules) + if has_attn1: + transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf") + else: + transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf") # with `safe_fusing=True` we should see an Error with self.assertRaises(ValueError): From ea14465e565c7242599bf98292598c71619d6284 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 8 Apr 2025 13:02:46 +0530 Subject: [PATCH 49/53] fix quality issues. --- tests/lora/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index ae9f47c82e2e..854853188435 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -38,7 +38,6 @@ require_peft_backend, require_peft_version_greater, require_transformers_version_greater, - skip_mps, torch_device, ) From a20d03d96f1cefff39f30de2ee15962655ebbbcd Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 09:43:44 +0200 Subject: [PATCH 50/53] `pytest.xfail` -> `ValueError`. --- tests/lora/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 854853188435..9a3364ef4ab6 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -1603,9 +1603,10 @@ def test_lora_fuse_nan(self): tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name) ] if len(filtered_tower_names) == 0: - pytest.xfail( - reason=f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." + reason = ( + f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}." ) + raise ValueError(reason) for tower_name in filtered_tower_names: transformer_tower = getattr(pipe.transformer, tower_name) has_attn1 = any("attn1" in name for name in named_modules) From fb5f5f7f24f4511a5ce76495149298f623fd6790 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Tue, 8 Apr 2025 11:38:56 +0200 Subject: [PATCH 51/53] Add back `skip_mps`. --- tests/lora/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 9a3364ef4ab6..768dcc6d1a71 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -39,6 +39,7 @@ require_peft_version_greater, require_transformers_version_greater, torch_device, + skip_mps, ) @@ -1562,6 +1563,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self): "output with no lora and output with lora disabled should give same results", ) + @skip_mps @pytest.mark.xfail( condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"), reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.", From 12dc911301ac0998ea57714cd2fce5ab9ae0894f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 8 Apr 2025 12:47:08 +0000 Subject: [PATCH 52/53] Apply style fixes --- tests/lora/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 768dcc6d1a71..ba9ede95a3f1 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -38,8 +38,8 @@ require_peft_backend, require_peft_version_greater, require_transformers_version_greater, - torch_device, skip_mps, + torch_device, ) From bc93160bfc86ef02d41748d27c09d73d58e4a36c Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Fri, 11 Apr 2025 14:34:31 +0200 Subject: [PATCH 53/53] `make fix-copies` --- src/diffusers/loaders/lora_pipeline.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 43bbd781f670..aa508cf87f40 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1753,7 +1753,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1771,6 +1771,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1785,6 +1808,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod