From d3fbd7bbc147a11b74cad2da826c2f236f44b33c Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 17 Sep 2024 16:09:01 +0200
Subject: [PATCH 01/36] [WIP][LoRA] Implement hot-swapping of LoRA

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can offload existing adapters or they can unload them (i.e. delete them). However, they cannot "hotswap" adapters yet, i.e. substitute the weights from one LoRA adapter with the weights of another, without the need to create a separate LoRA adapter.

Generally, hot-swapping may not appear super useful, but when the model is compiled, it is necessary to prevent recompilation. See #9279 for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target exactly the same layers and the "hyper-parameters" of the two adapters should be identical. For instance, the LoRA alpha has to be the same: given that we keep the alpha from the first adapter, the LoRA scaling would otherwise be incorrect for the second adapter. Theoretically, we could override the scaling dict with the alpha values derived from the second adapter's config, but changing the dict will trigger a guard for recompilation, defeating the main purpose of the feature.

I also found that compilation flags can have an impact on whether this works or not. E.g. when passing "reduce-overhead", there will be errors of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592

I don't know enough about compilation to determine whether this is problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which direction to take this. If this PR turns out to be useful, the hot-swapping functions will be added to PEFT itself and can be imported here (or a separate copy can be kept in diffusers to avoid requiring a minimum PEFT version for this feature). Moreover, more tests need to be added to better cover this feature, although we don't necessarily need tests for the hot-swapping functionality itself, since those tests will be added to PEFT. Furthermore, as of now, this is only implemented for the unet. Other pipeline components have yet to implement this feature. Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before putting more time into finalizing it.
---
 src/diffusers/loaders/lora_pipeline.py |  17 +++-
 src/diffusers/loaders/unet.py          | 114 +++++++++++++++++++++++--
 tests/pipelines/test_pipelines.py      |  38 +++++++++
 3 files changed, 161 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7d644d684153..a54b82161ecc 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -63,7 +63,11 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
     text_encoder_name = TEXT_ENCODER_NAME
 
     def load_lora_weights(
-        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name=None,
+        hotswap: bool = False,
+        **kwargs,
     ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
@@ -88,6 +92,7 @@ def load_lora_weights(
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model.
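To make the alpha caveat from the PR description above concrete, here is a small numeric sketch. The rank and alpha values are made up for illustration; the only assumption is the standard PEFT scaling `delta_W = (lora_alpha / r) * (B @ A)`.

```py
# Illustrative only: why the LoRA alpha must match between the two adapters.
r = 8
alpha_first, alpha_second = 8, 16  # hypothetical alphas of the two adapters

scaling_kept = alpha_first / r      # scaling that stays in place after a naive weight swap
scaling_needed = alpha_second / r   # scaling the second adapter was trained with

# The second adapter's outputs would be off by this factor:
assert scaling_needed / scaling_kept == 2.0
```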
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + hotswap TODO """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -109,6 +114,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, + hotswap=hotswap, ) self.load_lora_into_text_encoder( state_dict, @@ -232,7 +238,7 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -250,6 +256,7 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + hotswap TODO """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -263,7 +270,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + hotswap=hotswap, ) @classmethod diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32ace77b6224..970416ea1dd9 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -66,7 +66,7 @@ class UNet2DConditionLoadersMixin: unet_name = UNET_NAME @validate_hf_hub_args - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in @@ -115,6 +115,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + hotswap TODO Example: @@ -209,6 +210,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, + hotswap=hotswap, ) else: raise ValueError( @@ -268,7 +270,7 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. 
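For orientation, this is roughly how the new `hotswap` argument threaded through `load_lora_weights`, `load_lora_into_unet`, and `load_attn_procs` above is meant to be used from a pipeline. The checkpoint paths are placeholders, and per the caveats in the PR description the second LoRA must target the same layers with the same rank and alpha as the first; treat this as a sketch, not the final documented API.

```py
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("path/to/base-model", torch_dtype=torch.float16).to("cuda")

# Load the first adapter normally.
pipe.load_lora_weights("path/to/first-lora", adapter_name="default_0")

# Compiling once is the main motivation: hot-swapping should not trigger recompilation.
pipe.unet = torch.compile(pipe.unet)
image0 = pipe("a prompt").images[0]

# Replace the weights of "default_0" in-place with a compatible second LoRA.
pipe.load_lora_weights("path/to/second-lora", adapter_name="default_0", hotswap=True)
image1 = pipe("a prompt").images[0]
```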
@@ -299,10 +301,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict if len(state_dict_to_be_used) > 0: - if adapter_name in getattr(self, "peft_config", {}): + if adapter_name in getattr(self, "peft_config", {}) and not hotswap: raise ValueError( f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." ) + elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: + raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.") state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) @@ -336,8 +340,108 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + + def _check_hotswap_configs_compatible(config0, config1): + # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be false. E.g. if they + # use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the + # weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these + # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled. + + # TODO: This is a very rough check at the moment and there are probably better ways than to error out + config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"] + config0 = config0.to_dict() + config1 = config1.to_dict() + for key in config_keys_to_check: + val0 = config0[key] + val1 = config1[key] + if val0 != val1: + raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}") + + def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name): + """ + Swap out the LoRA weights from the model with the weights from state_dict. + + It is assumed that the existing adapter and the new adapter are compatible. + + Args: + model: nn.Module + The model with the loaded adapter. + state_dict: dict[str, torch.Tensor] + The state dict of the new adapter, which needs to be compatible (targeting same modules etc.). + adapter_name: Optional[str] + The name of the adapter that should be hot-swapped. + + Raises: + RuntimeError + If the old and the new adapter are not compatible, a RuntimeError is raised. + """ + from operator import attrgetter + + ####################### + # INSERT ADAPTER NAME # + ####################### + + remapped_state_dict = {} + expected_str = adapter_name + "." 
+ for key, val in state_dict.items(): + if expected_str not in key: + prefix, _, suffix = key.rpartition(".") + key = f"{prefix}.{adapter_name}.{suffix}" + remapped_state_dict[key] = val + state_dict = remapped_state_dict + + #################### + # CHECK STATE_DICT # + #################### + + # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise + # hot-swapping is not possible + parameter_prefix = "lora_" # hard-coded for now + is_compiled = hasattr(model, "_orig_mod") + # TODO: there is probably a more precise way to identify the adapter keys + missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)} + unexpected_keys = set() + + # first: dry run, not swapping anything + for key, new_val in state_dict.items(): + try: + old_val = attrgetter(key)(model) + except AttributeError: + unexpected_keys.add(key) + continue + + if is_compiled: + missing_keys.remove("_orig_mod." + key) + else: + missing_keys.remove(key) + + if missing_keys or unexpected_keys: + msg = "Hot swapping the adapter did not succeed." + if missing_keys: + msg += f" Missing keys: {', '.join(sorted(missing_keys))}." + if unexpected_keys: + msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}." + raise RuntimeError(msg) + + ################### + # ACTUAL SWAPPING # + ################### + + for key, new_val in state_dict.items(): + # no need to account for potential _orig_mod in key here, as torch handles that + old_val = attrgetter(key)(model) + old_val.data = new_val.data.to(device=old_val.device) + # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter + # torch.utils.swap_tensors(old_val.data, new_val.data) + + if hotswap: + _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + _hotswap_adapter_from_state_dict(self, state_dict, adapter_name) + # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None + incompatible_keys = None + else: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) if incompatible_keys is not None: # check only for unexpected keys diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index c73a12a4cbf8..7d86bde98bab 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -18,6 +18,7 @@ import os import random import shutil +import subprocess import sys import tempfile import traceback @@ -2014,3 +2015,40 @@ def test_ddpm_ddim_equality_batched(self): # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_images - ddim_images).max() < 1e-1 + + +class TestLoraHotSwapping: + def test_hotswapping_peft_config_incompatible_raises(self): + # TODO + pass + + def test_hotswapping_no_existing_adapter_raises(self): + # TODO + pass + + def test_hotswapping_works(self): + # TODO + pass + + def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): + # TODO: kinda slow, should it get a slow marker? 
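As an aside on the swap loop in `_hotswap_adapter_from_state_dict` above: assigning to `.data` keeps the existing `nn.Parameter` object registered on the module and only replaces its values. A minimal standalone illustration with a plain linear layer, independent of LoRA:

```py
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

old_param = lin.weight
old_param.data = new_weight.to(device=old_param.device)  # same Parameter object, new values

assert lin.weight is old_param
assert torch.equal(lin.weight.data, new_weight)
```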
+ env = {"TORCH_LOGS": "guards,recompiles"} + here = os.path.dirname(__file__) + file_name = os.path.join(here, "run_compiled_model_hotswap.py") + + process = subprocess.Popen( + [sys.executable, file_name], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is not present + assert "__recompiles" not in stderr.decode() From 84bae62e2c03d5393fd39f76a0636fd2dfa14e2a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 18 Sep 2024 12:52:36 +0200 Subject: [PATCH 02/36] Reviewer feedback --- src/diffusers/loaders/lora_pipeline.py | 4 +- src/diffusers/loaders/unet.py | 21 +++- tests/pipelines/run_compiled_model_hotswap.py | 103 ++++++++++++++++++ tests/pipelines/test_pipelines.py | 10 +- 4 files changed, 127 insertions(+), 11 deletions(-) create mode 100644 tests/pipelines/run_compiled_model_hotswap.py diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a54b82161ecc..ad68b5879b01 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -238,7 +238,9 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False): + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, hotswap: bool = False + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 970416ea1dd9..e40fd5f6cd30 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -66,7 +66,12 @@ class UNet2DConditionLoadersMixin: unet_name = UNET_NAME @validate_hf_hub_args - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], hotswap: bool = False, **kwargs): + def load_attn_procs( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + hotswap: bool = False, + **kwargs, + ): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in @@ -270,7 +275,9 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False): + def _process_lora( + self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, hotswap: bool = False + ): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. @@ -306,7 +313,9 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." ) elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: - raise ValueError(f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name.") + raise ValueError( + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name." 
+ ) state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) @@ -340,7 +349,6 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) - def _check_hotswap_configs_compatible(config0, config1): # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be false. E.g. if they # use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the @@ -351,9 +359,10 @@ def _check_hotswap_configs_compatible(config0, config1): config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"] config0 = config0.to_dict() config1 = config1.to_dict() + sentinel = object() for key in config_keys_to_check: - val0 = config0[key] - val1 = config1[key] + val0 = config0.get(key, sentinel) + val1 = config1.get(key, sentinel) if val0 != val1: raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}") diff --git a/tests/pipelines/run_compiled_model_hotswap.py b/tests/pipelines/run_compiled_model_hotswap.py new file mode 100644 index 000000000000..42c9261af9ac --- /dev/null +++ b/tests/pipelines/run_compiled_model_hotswap.py @@ -0,0 +1,103 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import numpy as np +import torch +from peft import LoraConfig + +from diffusers import UNet2DConditionModel +from diffusers.utils.testing_utils import floats_tensor + + +torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def get_small_unet(): + # from UNet2DConditionModelTests + init_dict = { + "block_out_channels": (4, 8), + "norm_num_groups": 4, + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 8, + "attention_head_dim": 2, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 16, + } + model = UNet2DConditionModel(**init_dict) + return model + + +def get_unet_lora_config(): + # from test_models_unet_2d_condition.py + rank = 4 + unet_lora_config = LoraConfig( + r=rank, + lora_alpha=rank, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + return unet_lora_config + + +def get_dummy_input(): + # from UNet2DConditionModelTests + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + +def check_hotswap(hotswap): + dummy_input = get_dummy_input() + unet = get_small_unet() + # lora_config = get_unet_lora_config() + # unet.add_adapter(lora_config) + unet.to(torch_device) + + # Note: When using the compile flag "reduce-overhead", there will be errors of the type + # > input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592 + unet = torch.compile(unet) + + torch.manual_seed(42) + out0 = unet(**dummy_input) + + if hotswap: + unet.load_lora_weights("ybelkada/sd-1.5-pokemon-lora-peft", adapter_name="foo", hotswap=hotswap) + else: + # offloading the old and loading the new adapter will result in recompilation + unet.set_lora_device(adapter_names=["foo"], device="cpu") + unet.load_lora_weights("ybelkada/sd-1.5-pokemon-lora-peft", adapter_name="bar") + + torch.manual_seed(42) + out1 = unet(**dummy_input) + + # sanity check: since it's the same LoRA, the results should be identical + out0, out1 = np.array(out0.images[0]), np.array(out1.images[0]) + assert not (out0 == 0).all() + assert (out0 == out1).all() + + +if __name__ == "__main__": + # check_hotswap(False) will trigger recompilation + check_hotswap(True) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 7d86bde98bab..6f4638f2dac4 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -76,6 +76,7 @@ require_compel, require_flax, require_onnxruntime, + require_peft_backend, require_torch_2, require_torch_gpu, run_test_in_subprocess, @@ -2030,6 +2031,10 @@ def test_hotswapping_works(self): # TODO pass + @slow + @require_torch_2 + @require_torch_gpu + @require_peft_backend def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): # TODO: kinda slow, should it get a slow marker? 
env = {"TORCH_LOGS": "guards,recompiles"} @@ -2037,10 +2042,7 @@ def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): file_name = os.path.join(here, "run_compiled_model_hotswap.py") process = subprocess.Popen( - [sys.executable, file_name], - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE + [sys.executable, file_name], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) # Communicate will read the output and error streams, preventing deadlock From 63ece9d4d0ce92c0a4e63029a7e822da4d1fd6a9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 16 Oct 2024 15:31:45 +0200 Subject: [PATCH 03/36] Reviewer feedback, adjust test --- src/diffusers/loaders/unet.py | 2 - tests/pipelines/run_compiled_model_hotswap.py | 102 +++++++++++++----- tests/pipelines/test_pipelines.py | 14 +-- 3 files changed, 80 insertions(+), 38 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index e40fd5f6cd30..8043300eacaf 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -440,8 +440,6 @@ def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name): # no need to account for potential _orig_mod in key here, as torch handles that old_val = attrgetter(key)(model) old_val.data = new_val.data.to(device=old_val.device) - # TODO: wanted to use swap_tensors but this somehow does not work on nn.Parameter - # torch.utils.swap_tensors(old_val.data, new_val.data) if hotswap: _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) diff --git a/tests/pipelines/run_compiled_model_hotswap.py b/tests/pipelines/run_compiled_model_hotswap.py index 42c9261af9ac..af37e9f015a3 100644 --- a/tests/pipelines/run_compiled_model_hotswap.py +++ b/tests/pipelines/run_compiled_model_hotswap.py @@ -11,13 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""This is a standalone script that checks that we can hotswap a LoRA adapter on a compiles model +By itself, this script is not super interesting but when we collect the compile logs, we can check that hotswapping +does not trigger recompilation. This is done in the TestLoraHotSwapping class in test_pipelines.py. + +Running this script with `check_hotswap(False)` will load the LoRA adapter without hotswapping, which will result in +recompilation. 
+ +""" + +import os +import tempfile -import numpy as np import torch -from peft import LoraConfig +from peft import LoraConfig, get_peft_model_state_dict +from peft.tuners.tuners_utils import BaseTunerLayer -from diffusers import UNet2DConditionModel +from diffusers import StableDiffusionPipeline, UNet2DConditionModel from diffusers.utils.testing_utils import floats_tensor @@ -26,6 +37,7 @@ def get_small_unet(): # from UNet2DConditionModelTests + torch.manual_seed(0) init_dict = { "block_out_channels": (4, 8), "norm_num_groups": 4, @@ -39,7 +51,7 @@ def get_small_unet(): "sample_size": 16, } model = UNet2DConditionModel(**init_dict) - return model + return model.to(torch_device) def get_unet_lora_config(): @@ -68,34 +80,70 @@ def get_dummy_input(): return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} -def check_hotswap(hotswap): - dummy_input = get_dummy_input() - unet = get_small_unet() - # lora_config = get_unet_lora_config() - # unet.add_adapter(lora_config) - unet.to(torch_device) +def get_lora_state_dicts(modules_to_save): + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) + return state_dicts - # Note: When using the compile flag "reduce-overhead", there will be errors of the type - # > input name: arg861_1. data pointer changed from 139647332027392 to 139647331054592 - unet = torch.compile(unet) - torch.manual_seed(42) - out0 = unet(**dummy_input) +def set_lora_device(model, adapter_names, device): + # copied from LoraBaseMixin.set_lora_device + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + for adapter_name in adapter_names: + module.lora_A[adapter_name].to(device) + module.lora_B[adapter_name].to(device) + # this is a param, not a module, so device placement is not in-place -> re-assign + if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: + if adapter_name in module.lora_magnitude_vector: + module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[adapter_name].to( + device + ) - if hotswap: - unet.load_lora_weights("ybelkada/sd-1.5-pokemon-lora-peft", adapter_name="foo", hotswap=hotswap) - else: - # offloading the old and loading the new adapter will result in recompilation - unet.set_lora_device(adapter_names=["foo"], device="cpu") - unet.load_lora_weights("ybelkada/sd-1.5-pokemon-lora-peft", adapter_name="bar") +def check_hotswap(do_hotswap): + dummy_input = get_dummy_input() + unet = get_small_unet() + lora_config = get_unet_lora_config() + unet.add_adapter(lora_config) torch.manual_seed(42) - out1 = unet(**dummy_input) - - # sanity check: since it's the same LoRA, the results should be identical - out0, out1 = np.array(out0.images[0]), np.array(out1.images[0]) - assert not (out0 == 0).all() - assert (out0 == out1).all() + out_base = unet(**dummy_input)["sample"] + # sanity check + assert not (out_base == 0).all() + + with tempfile.TemporaryDirectory() as tmp_dirname: + lora_state_dicts = get_lora_state_dicts({"unet": unet}) + StableDiffusionPipeline.save_lora_weights( + save_directory=tmp_dirname, safe_serialization=True, **lora_state_dicts + ) + del unet + + unet = get_small_unet() + file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors") + unet.load_attn_procs(file_name) + # unet = torch.compile(unet, mode="reduce-overhead") + + torch.manual_seed(42) + out0 = unet(**dummy_input)["sample"] + + # sanity check: 
still same result + atol, rtol = 1e-5, 1e-5 + assert torch.allclose(out_base, out0, atol=atol, rtol=rtol) + + if do_hotswap: + unet.load_attn_procs(file_name, adapter_name="default_0", hotswap=True) + else: + # offloading the old and loading the new adapter will result in recompilation + set_lora_device(unet, adapter_names=["default_0"], device="cpu") + unet.load_attn_procs(file_name, adapter_name="other_name", hotswap=False) + + torch.manual_seed(42) + out1 = unet(**dummy_input)["sample"] + + # sanity check: since it's the same LoRA, the results should be identical + assert torch.allclose(out0, out1, atol=atol, rtol=rtol) if __name__ == "__main__": diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 6f4638f2dac4..1c505bf01d53 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2019,17 +2019,13 @@ def test_ddpm_ddim_equality_batched(self): class TestLoraHotSwapping: - def test_hotswapping_peft_config_incompatible_raises(self): - # TODO - pass + """Test that hotswapping does not result in recompilation. - def test_hotswapping_no_existing_adapter_raises(self): - # TODO - pass + We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively + tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require + recompilation. - def test_hotswapping_works(self): - # TODO - pass + """ @slow @require_torch_2 From c7378ed93073ffd06ee19c27385a11c2d6eabcb4 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 16 Oct 2024 15:44:25 +0200 Subject: [PATCH 04/36] Fix, doc --- tests/pipelines/run_compiled_model_hotswap.py | 2 +- tests/pipelines/test_pipelines.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/run_compiled_model_hotswap.py b/tests/pipelines/run_compiled_model_hotswap.py index af37e9f015a3..cdce553ccd40 100644 --- a/tests/pipelines/run_compiled_model_hotswap.py +++ b/tests/pipelines/run_compiled_model_hotswap.py @@ -123,7 +123,7 @@ def check_hotswap(do_hotswap): unet = get_small_unet() file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors") unet.load_attn_procs(file_name) - # unet = torch.compile(unet, mode="reduce-overhead") + unet = torch.compile(unet, mode="reduce-overhead") torch.manual_seed(42) out0 = unet(**dummy_input)["sample"] diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 484a929f4811..ce12cea31c2e 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2070,6 +2070,9 @@ class TestLoraHotSwapping: tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require recompilation. + The reason why we need to shell out instead of just running the script inside of the test is that shelling out is + required to collect the torch.compile logs. 
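For context, later patches in this series (see PATCH 08 below) drop the subprocess-plus-log-parsing approach in favor of an in-process check. A minimal sketch of that pattern follows; the helper name is illustrative, only `torch._dynamo.config.patch(error_on_recompile=True)` is taken from the later patch.

```py
import torch


def assert_no_recompilation(compiled_model, **inputs):
    # Any recompilation of an already-compiled model becomes an exception,
    # so the test fails directly instead of having to parse TORCH_LOGS output.
    with torch._dynamo.config.patch(error_on_recompile=True):
        return compiled_model(**inputs)
```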
+ """ @slow From 7c67b388a9d49aad6b30c08b845d660e651f3de7 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 16 Oct 2024 15:46:41 +0200 Subject: [PATCH 05/36] Make fix --- src/diffusers/loaders/lora_pipeline.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index edb38ebdb1cd..28a81df9662e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -807,7 +807,14 @@ def lora_state_dict( @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet def load_lora_into_unet( - cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, + state_dict, + network_alphas, + unet, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -829,6 +836,7 @@ def load_lora_into_unet( low_cpu_mem_usage (`boo`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap TODO """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -852,6 +860,7 @@ def load_lora_into_unet( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod From ea12e0dbc3519dcc78673f7ebf1fc066a1cbdbbc Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 16 Oct 2024 18:30:48 +0200 Subject: [PATCH 06/36] Fix for possible g++ error --- tests/pipelines/test_pipelines.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index ce12cea31c2e..08b7805274ab 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2081,7 +2081,8 @@ class TestLoraHotSwapping: @require_peft_backend def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): # TODO: kinda slow, should it get a slow marker? 
- env = {"TORCH_LOGS": "guards,recompiles"} + env = os.environ.copy() + env["TORCH_LOGS"] = "guards,recompiles" here = os.path.dirname(__file__) file_name = os.path.join(here, "run_compiled_model_hotswap.py") From ec4b0d5d19a0d4ebfd5f1efca38649407736b4dc Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 18 Oct 2024 14:56:33 +0200 Subject: [PATCH 07/36] Add test for recompilation w/o hotswapping --- tests/pipelines/run_compiled_model_hotswap.py | 8 +++++--- tests/pipelines/test_pipelines.py | 19 +++++++++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/pipelines/run_compiled_model_hotswap.py b/tests/pipelines/run_compiled_model_hotswap.py index cdce553ccd40..665f266215f3 100644 --- a/tests/pipelines/run_compiled_model_hotswap.py +++ b/tests/pipelines/run_compiled_model_hotswap.py @@ -22,6 +22,7 @@ """ import os +import sys import tempfile import torch @@ -36,7 +37,7 @@ def get_small_unet(): - # from UNet2DConditionModelTests + # from UNet2DConditionModelTests torch.manual_seed(0) init_dict = { "block_out_channels": (4, 8), @@ -147,5 +148,6 @@ def check_hotswap(do_hotswap): if __name__ == "__main__": - # check_hotswap(False) will trigger recompilation - check_hotswap(True) + # check_hotswap(True) does not trigger recompilation + # check_hotswap(False) triggers recompilation + check_hotswap(do_hotswap=sys.argv[1] == "1") diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 08b7805274ab..875a9c88c1d5 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2080,14 +2080,14 @@ class TestLoraHotSwapping: @require_torch_gpu @require_peft_backend def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): - # TODO: kinda slow, should it get a slow marker? env = os.environ.copy() env["TORCH_LOGS"] = "guards,recompiles" here = os.path.dirname(__file__) file_name = os.path.join(here, "run_compiled_model_hotswap.py") + # first test with hotswapping: should not trigger recompilation process = subprocess.Popen( - [sys.executable, file_name], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + [sys.executable, file_name, "1"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) # Communicate will read the output and error streams, preventing deadlock @@ -2099,3 +2099,18 @@ def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): # check that the recompilation message is not present assert "__recompiles" not in stderr.decode() + + # next, contingency check: without hotswapping, we *do* get recompilation + process = subprocess.Popen( + [sys.executable, file_name, "0"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + # Communicate will read the output and error streams, preventing deadlock + stdout, stderr = process.communicate() + exit_code = process.returncode + + # sanity check: + assert exit_code == 0 + + # check that the recompilation message is not present + assert "__recompiles" in stderr.decode() From 488f2f0d49b2ff18ce2ee4db13d6a1e02c31fa70 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 7 Feb 2025 14:48:57 +0100 Subject: [PATCH 08/36] Make hotswap work Requires https://github.com/huggingface/peft/pull/2366 More changes to make hotswapping work. Together with the mentioned PEFT PR, the tests pass for me locally. 
List of changes: - docstring for hotswap - remove code copied from PEFT, import from PEFT now - adjustments to PeftAdapterMixin.load_lora_adapter (unfortunately, some state dict renaming was necessary, LMK if there is a better solution) - adjustments to UNet2DConditionLoadersMixin._process_lora: LMK if this is even necessary or not, I'm unsure what the overall relationship is between this and PeftAdapterMixin.load_lora_adapter - also in UNet2DConditionLoadersMixin._process_lora, I saw that there is no LoRA unloading when loading the adapter fails, so I added it there (in line with what happens in PeftAdapterMixin.load_lora_adapter) - rewritten tests to avoid shelling out, make the test more precise by making sure that the outputs align, parametrize it - also checked the pipeline code mentioned in this comment: https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871; when running this inside the with torch._dynamo.config.patch(error_on_recompile=True) context, there is no error, so I think hotswapping is now working with pipelines. --- src/diffusers/loaders/peft.py | 73 +++++++- src/diffusers/loaders/unet.py | 153 ++++++---------- tests/pipelines/run_compiled_model_hotswap.py | 153 ---------------- tests/pipelines/test_pipelines.py | 168 ++++++++++++++---- 4 files changed, 256 insertions(+), 291 deletions(-) delete mode 100644 tests/pipelines/run_compiled_model_hotswap.py diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 0d26738eec62..d93d881683e9 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -138,7 +138,9 @@ def _optionally_disable_offloading(cls, _pipeline): """ return _func_optionally_disable_offloading(_pipeline=_pipeline) - def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): + def load_lora_adapter( + self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs + ): r""" Loads a LoRA adapter into the underlying model. @@ -182,6 +184,28 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. + This means that, instead of loading an additional adapter, this will take the existing adapter weights + and replace them with the weights of the new adapter. This can be faster and more memory efficient. + However, the main advantage of hotswapping is that when the model is compiled with torch.compile, + loading the new adapter does not require recompilation of the model. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + model = ... # load diffusers model with first LoRA adapter + max_rank = ... 
# the highest rank among all LoRAs that you want to load + prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling + model = torch.compile(model) + model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + ``` + + There are some limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/v0.14.0/en/package_reference/hotswap """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -235,10 +259,15 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys} if len(state_dict) > 0: - if adapter_name in getattr(self, "peft_config", {}): + if adapter_name in getattr(self, "peft_config", {}) and not hotswap: raise ValueError( f"Adapter name {adapter_name} already in use in the model - please select a new adapter name." ) + elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: + raise ValueError( + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " + "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." + ) # check with first key if is not in peft format first_key = next(iter(state_dict.keys())) @@ -296,11 +325,47 @@ def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="trans if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + if hotswap: + try: + from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + except ImportError as exc: + msg = ( + "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " + "from source." + ) + raise ImportError(msg) from exc + + if hotswap: + + def map_state_dict_for_hotswap(sd): + # For hotswapping, we need the adapter name to be present in the state dict keys + new_sd = {} + for k, v in sd.items(): + if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): + k = k[:-7] + f".{adapter_name}.weight" + elif k.endswith("lora_B.bias"): # lora_bias=True option + k = k[:-5] + f".{adapter_name}.bias" + new_sd[k] = v + return new_sd + # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, # we should also delete the `peft_config` associated to the `adapter_name`. try: - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + if hotswap: + state_dict = map_state_dict_for_hotswap(state_dict) + _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + hotswap_adapter_from_state_dict( + model=self, + state_dict=state_dict, + adapter_name=adapter_name, + config=lora_config, + ) + # the hotswap function raises if there are incompatible keys, so if we reach this point we can set + # it to None + incompatible_keys = None + else: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) except Exception as e: # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. 
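For reference, the key remapping that `map_state_dict_for_hotswap` above performs can be sketched standalone as follows; PEFT's `hotswap_adapter_from_state_dict` expects the adapter name embedded in each key. The helper name, module path, and adapter name here are illustrative.

```py
def remap_lora_key(key: str, adapter_name: str) -> str:
    # e.g. "unet.to_q.lora_A.weight" -> "unet.to_q.lora_A.<adapter_name>.weight"
    if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
        return key[: -len(".weight")] + f".{adapter_name}.weight"
    if key.endswith("lora_B.bias"):  # only present when lora_bias=True
        return key[: -len(".bias")] + f".{adapter_name}.bias"
    return key


assert remap_lora_key("unet.to_q.lora_A.weight", "default_0") == "unet.to_q.lora_A.default_0.weight"
```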
if hasattr(self, "peft_config"): diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 46dea6da25da..f6a6e8eb671b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -308,6 +308,7 @@ def _process_lora( raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + from peft.tuners.tuners_utils import BaseTunerLayer keys = list(state_dict.keys()) @@ -333,7 +334,8 @@ def _process_lora( ) elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: raise ValueError( - f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name." + f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " + "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." ) state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) @@ -382,106 +384,59 @@ def _process_lora( if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - def _check_hotswap_configs_compatible(config0, config1): - # To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be false. E.g. if they - # use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with the - # weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these - # values as well, but that's not implemented yet, and it would trigger a re-compilation if the model is compiled. - - # TODO: This is a very rough check at the moment and there are probably better ways than to error out - config_keys_to_check = ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"] - config0 = config0.to_dict() - config1 = config1.to_dict() - sentinel = object() - for key in config_keys_to_check: - val0 = config0.get(key, sentinel) - val1 = config1.get(key, sentinel) - if val0 != val1: - raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}") - - def _hotswap_adapter_from_state_dict(model, state_dict, adapter_name): - """ - Swap out the LoRA weights from the model with the weights from state_dict. - - It is assumed that the existing adapter and the new adapter are compatible. - - Args: - model: nn.Module - The model with the loaded adapter. - state_dict: dict[str, torch.Tensor] - The state dict of the new adapter, which needs to be compatible (targeting same modules etc.). - adapter_name: Optional[str] - The name of the adapter that should be hot-swapped. - - Raises: - RuntimeError - If the old and the new adapter are not compatible, a RuntimeError is raised. - """ - from operator import attrgetter - - ####################### - # INSERT ADAPTER NAME # - ####################### - - remapped_state_dict = {} - expected_str = adapter_name + "." 
- for key, val in state_dict.items(): - if expected_str not in key: - prefix, _, suffix = key.rpartition(".") - key = f"{prefix}.{adapter_name}.{suffix}" - remapped_state_dict[key] = val - state_dict = remapped_state_dict - - #################### - # CHECK STATE_DICT # - #################### - - # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise - # hot-swapping is not possible - parameter_prefix = "lora_" # hard-coded for now - is_compiled = hasattr(model, "_orig_mod") - # TODO: there is probably a more precise way to identify the adapter keys - missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)} - unexpected_keys = set() - - # first: dry run, not swapping anything - for key, new_val in state_dict.items(): - try: - old_val = attrgetter(key)(model) - except AttributeError: - unexpected_keys.add(key) - continue - - if is_compiled: - missing_keys.remove("_orig_mod." + key) - else: - missing_keys.remove(key) - - if missing_keys or unexpected_keys: - msg = "Hot swapping the adapter did not succeed." - if missing_keys: - msg += f" Missing keys: {', '.join(sorted(missing_keys))}." - if unexpected_keys: - msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}." - raise RuntimeError(msg) - - ################### - # ACTUAL SWAPPING # - ################### - - for key, new_val in state_dict.items(): - # no need to account for potential _orig_mod in key here, as torch handles that - old_val = attrgetter(key)(model) - old_val.data = new_val.data.to(device=old_val.device) + if hotswap: + try: + from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + except ImportError as exc: + msg = ( + "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " + "from source." + ) + raise ImportError(msg) from exc if hotswap: - _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) - _hotswap_adapter_from_state_dict(self, state_dict, adapter_name) - # the hotswap function raises if there are incompatible keys, so if we reach this point we can set it to None - incompatible_keys = None - else: - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + def map_state_dict_for_hotswap(sd): + # For hotswapping, we need the adapter name to be present in the state dict keys + new_sd = {} + for k, v in sd.items(): + if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): + k = k[:-7] + f".{adapter_name}.weight" + elif k.endswith("lora_B.bias"): # lora_bias=True option + k = k[:-5] + f".{adapter_name}.bias" + new_sd[k] = v + return new_sd + + # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, + # we should also delete the `peft_config` associated to the `adapter_name`. 
+ try: + if hotswap: + _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + hotswap_adapter_from_state_dict( + model=self, + state_dict=state_dict, + adapter_name=adapter_name, + config=lora_config, + ) + # the hotswap function raises if there are incompatible keys, so if we reach this point we can set + # it to None + incompatible_keys = None + else: + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + except Exception as e: + # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. + if hasattr(self, "peft_config"): + for module in self.modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapters + for active_adapter in active_adapters: + if adapter_name in active_adapter: + module.delete_adapter(adapter_name) + + self.peft_config.pop(adapter_name) + logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") + raise warn_msg = "" if incompatible_keys is not None: diff --git a/tests/pipelines/run_compiled_model_hotswap.py b/tests/pipelines/run_compiled_model_hotswap.py deleted file mode 100644 index 665f266215f3..000000000000 --- a/tests/pipelines/run_compiled_model_hotswap.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""This is a standalone script that checks that we can hotswap a LoRA adapter on a compiles model - -By itself, this script is not super interesting but when we collect the compile logs, we can check that hotswapping -does not trigger recompilation. This is done in the TestLoraHotSwapping class in test_pipelines.py. - -Running this script with `check_hotswap(False)` will load the LoRA adapter without hotswapping, which will result in -recompilation. 
- -""" - -import os -import sys -import tempfile - -import torch -from peft import LoraConfig, get_peft_model_state_dict -from peft.tuners.tuners_utils import BaseTunerLayer - -from diffusers import StableDiffusionPipeline, UNet2DConditionModel -from diffusers.utils.testing_utils import floats_tensor - - -torch_device = "cuda" if torch.cuda.is_available() else "cpu" - - -def get_small_unet(): - # from UNet2DConditionModelTests - torch.manual_seed(0) - init_dict = { - "block_out_channels": (4, 8), - "norm_num_groups": 4, - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 8, - "attention_head_dim": 2, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 1, - "sample_size": 16, - } - model = UNet2DConditionModel(**init_dict) - return model.to(torch_device) - - -def get_unet_lora_config(): - # from test_models_unet_2d_condition.py - rank = 4 - unet_lora_config = LoraConfig( - r=rank, - lora_alpha=rank, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=False, - ) - return unet_lora_config - - -def get_dummy_input(): - # from UNet2DConditionModelTests - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - - -def get_lora_state_dicts(modules_to_save): - state_dicts = {} - for module_name, module in modules_to_save.items(): - if module is not None: - state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) - return state_dicts - - -def set_lora_device(model, adapter_names, device): - # copied from LoraBaseMixin.set_lora_device - for module in model.modules(): - if isinstance(module, BaseTunerLayer): - for adapter_name in adapter_names: - module.lora_A[adapter_name].to(device) - module.lora_B[adapter_name].to(device) - # this is a param, not a module, so device placement is not in-place -> re-assign - if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None: - if adapter_name in module.lora_magnitude_vector: - module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[adapter_name].to( - device - ) - - -def check_hotswap(do_hotswap): - dummy_input = get_dummy_input() - unet = get_small_unet() - lora_config = get_unet_lora_config() - unet.add_adapter(lora_config) - torch.manual_seed(42) - out_base = unet(**dummy_input)["sample"] - # sanity check - assert not (out_base == 0).all() - - with tempfile.TemporaryDirectory() as tmp_dirname: - lora_state_dicts = get_lora_state_dicts({"unet": unet}) - StableDiffusionPipeline.save_lora_weights( - save_directory=tmp_dirname, safe_serialization=True, **lora_state_dicts - ) - del unet - - unet = get_small_unet() - file_name = os.path.join(tmp_dirname, "pytorch_lora_weights.safetensors") - unet.load_attn_procs(file_name) - unet = torch.compile(unet, mode="reduce-overhead") - - torch.manual_seed(42) - out0 = unet(**dummy_input)["sample"] - - # sanity check: still same result - atol, rtol = 1e-5, 1e-5 - assert torch.allclose(out_base, out0, atol=atol, rtol=rtol) - - if do_hotswap: - unet.load_attn_procs(file_name, adapter_name="default_0", hotswap=True) - else: - # offloading the old and loading the new adapter will result in recompilation - 
set_lora_device(unet, adapter_names=["default_0"], device="cpu") - unet.load_attn_procs(file_name, adapter_name="other_name", hotswap=False) - - torch.manual_seed(42) - out1 = unet(**dummy_input)["sample"] - - # sanity check: since it's the same LoRA, the results should be identical - assert torch.allclose(out0, out1, atol=atol, rtol=rtol) - - -if __name__ == "__main__": - # check_hotswap(True) does not trigger recompilation - # check_hotswap(False) triggers recompilation - check_hotswap(do_hotswap=sys.argv[1] == "1") diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 9b534a92c5d3..8cc000541c32 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -18,7 +18,6 @@ import os import random import shutil -import subprocess import sys import tempfile import traceback @@ -2179,54 +2178,153 @@ def test_ddpm_ddim_equality_batched(self): assert np.abs(ddpm_images - ddim_images).max() < 1e-1 -class TestLoraHotSwapping: +class TestLoraHotSwapping(unittest.TestCase): """Test that hotswapping does not result in recompilation. We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require recompilation. - The reason why we need to shell out instead of just running the script inside of the test is that shelling out is - required to collect the torch.compile logs. + See + https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 + for the analogous PEFT test. """ - @slow - @require_torch_2 - @require_torch_gpu - @require_peft_backend - def test_hotswapping_compiled_model_does_not_trigger_recompilation(self): - env = os.environ.copy() - env["TORCH_LOGS"] = "guards,recompiles" - here = os.path.dirname(__file__) - file_name = os.path.join(here, "run_compiled_model_hotswap.py") - - # first test with hotswapping: should not trigger recompilation - process = subprocess.Popen( - [sys.executable, file_name, "1"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + def tearDown(self): + # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, + # there will be recompilation errors, as torch caches the model when run in the same process. 
+ super().tearDown() + torch._dynamo.reset() + gc.collect() + backend_empty_cache(torch_device) - # Communicate will read the output and error streams, preventing deadlock - stdout, stderr = process.communicate() - exit_code = process.returncode + def get_small_unet(self): + # from diffusers UNet2DConditionModelTests + torch.manual_seed(0) + init_dict = { + "block_out_channels": (4, 8), + "norm_num_groups": 4, + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 8, + "attention_head_dim": 2, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 16, + } + model = UNet2DConditionModel(**init_dict) + return model.to(torch_device) + + def get_unet_lora_config(self, lora_rank, lora_alpha): + # from diffusers test_models_unet_2d_condition.py + from peft import LoraConfig + + unet_lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=["to_q", "to_k", "to_v", "to_out.0"], + init_lora_weights=False, + use_dora=False, + ) + return unet_lora_config + + def get_dummy_input(self): + # from UNet2DConditionModelTests + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + def check_hotswap(self, do_compile, rank0, rank1): + """ + Check that hotswapping works on a small unet. + + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. 
+ + """ + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + dummy_input = self.get_dummy_input() + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + lora_config0 = self.get_unet_lora_config(rank0, alpha0) + lora_config1 = self.get_unet_lora_config(rank1, alpha1) + + unet = self.get_small_unet() + unet.add_adapter(lora_config0, adapter_name="adapter0") + with torch.inference_mode(): + output0_before = unet(**dummy_input)["sample"] + + unet.add_adapter(lora_config1, adapter_name="adapter1") + unet.set_adapter("adapter1") + with torch.inference_mode(): + output1_before = unet(**dummy_input)["sample"] # sanity check: - assert exit_code == 0 + tol = 5e-3 + assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) + + with tempfile.TemporaryDirectory() as tmp_dirname: + unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") + unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") + del unet + + unet = self.get_small_unet() + file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") + unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0") + + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + prepare_model_for_compiled_hotswap( + unet, + config={"adapter0": lora_config0, "adapter1": lora_config1}, + target_rank=max_rank, + ) + if do_compile: + unet = torch.compile(unet, mode="reduce-overhead") - # check that the recompilation message is not present - assert "__recompiles" not in stderr.decode() + with torch.inference_mode(): + output0_after = unet(**dummy_input)["sample"] + assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) - # next, contingency check: without hotswapping, we *do* get recompilation - process = subprocess.Popen( - [sys.executable, file_name, "0"], env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True) - # Communicate will read the output and error streams, preventing deadlock - stdout, stderr = process.communicate() - exit_code = process.returncode + # we need to call forward to potentially trigger recompilation + with torch.inference_mode(): + output1_after = unet(**dummy_input)["sample"] + assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) - # sanity check: - assert exit_code == 0 + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + def test_hotswapping_diffusers_model(self, rank0, rank1): + self.check_hotswap(do_compile=False, rank0=rank0, rank1=rank1) - # check that the recompilation message is not present - assert "__recompiles" in stderr.decode() + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + def test_hotswapping_compiled_diffusers_model(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_hotswap(do_compile=True, rank0=rank0, rank1=rank1) From 5ab14604ac688ed2e03866005c6e9ecd351a13e0 Mon Sep 17 
00:00:00 2001 From: Benjamin Bossan Date: Mon, 10 Feb 2025 13:04:35 +0100 Subject: [PATCH 09/36] Address reviewer feedback: - Revert deprecated method - Fix PEFT doc link to main - Don't use private function - Clarify magic numbers - Add pipeline test Moreover: - Extend docstrings - Extend existing test for outputs != 0 - Extend existing test for wrong adapter name --- src/diffusers/loaders/lora_pipeline.py | 69 +++++++++++++++- src/diffusers/loaders/peft.py | 10 +-- src/diffusers/loaders/unet.py | 13 +-- tests/pipelines/test_pipelines.py | 105 ++++++++++++++++++++++++- 4 files changed, 178 insertions(+), 19 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 92d5fbd7edf0..865b95aa641b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -107,7 +107,28 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap TODO + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. + This means that, instead of loading an additional adapter, this will take the existing adapter weights + and replace them with the weights of the new adapter. This can be faster and more memory efficient. + However, the main advantage of hotswapping is that when the model is compiled with torch.compile, + loading the new adapter does not require recompilation of the model. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + model = ... # load diffusers model with first LoRA adapter + max_rank = ... # the highest rank among all LoRAs that you want to load + prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling + model = torch.compile(model) + model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + ``` + + There are some limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -298,7 +319,28 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap TODO + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. + This means that, instead of loading an additional adapter, this will take the existing adapter weights + and replace them with the weights of the new adapter. This can be faster and more memory efficient. + However, the main advantage of hotswapping is that when the model is compiled with torch.compile, + loading the new adapter does not require recompilation of the model. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + model = ... # load diffusers model with first LoRA adapter + max_rank = ... 
# the highest rank among all LoRAs that you want to load + prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling + model = torch.compile(model) + model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + ``` + + There are some limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -747,7 +789,28 @@ def load_lora_into_unet( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. - hotswap TODO + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. + This means that, instead of loading an additional adapter, this will take the existing adapter weights + and replace them with the weights of the new adapter. This can be faster and more memory efficient. + However, the main advantage of hotswapping is that when the model is compiled with torch.compile, + loading the new adapter does not require recompilation of the model. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + model = ... # load diffusers model with first LoRA adapter + max_rank = ... # the highest rank among all LoRAs that you want to load + prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling + model = torch.compile(model) + model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + ``` + + There are some limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index d93d881683e9..84ead88eabfa 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -205,7 +205,7 @@ def load_lora_adapter( ``` There are some limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/v0.14.0/en/package_reference/hotswap + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -327,7 +327,7 @@ def load_lora_adapter( if hotswap: try: - from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict except ImportError as exc: msg = ( "Hotswapping requires PEFT > v0.14. 
Please upgrade PEFT to a higher version or install it " @@ -342,9 +342,9 @@ def map_state_dict_for_hotswap(sd): new_sd = {} for k, v in sd.items(): if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): - k = k[:-7] + f".{adapter_name}.weight" + k = k[: -len(".weight")] + f".{adapter_name}.weight" elif k.endswith("lora_B.bias"): # lora_bias=True option - k = k[:-5] + f".{adapter_name}.bias" + k = k[: -len(".bias")] + f".{adapter_name}.bias" new_sd[k] = v return new_sd @@ -353,7 +353,7 @@ def map_state_dict_for_hotswap(sd): try: if hotswap: state_dict = map_state_dict_for_hotswap(state_dict) - _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) hotswap_adapter_from_state_dict( model=self, state_dict=state_dict, diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index f6a6e8eb671b..8bd7eb16f1c5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -64,12 +64,7 @@ class UNet2DConditionLoadersMixin: unet_name = UNET_NAME @validate_hf_hub_args - def load_attn_procs( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - hotswap: bool = False, - **kwargs, - ): + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be defined in @@ -121,7 +116,6 @@ def load_attn_procs( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap TODO Example: @@ -227,7 +221,6 @@ def load_attn_procs( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, ) else: raise ValueError( @@ -386,7 +379,7 @@ def _process_lora( if hotswap: try: - from peft.utils.hotswap import _check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict except ImportError as exc: msg = ( "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " @@ -411,7 +404,7 @@ def map_state_dict_for_hotswap(sd): # we should also delete the `peft_config` associated to the `adapter_name`. 
try: if hotswap: - _check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) + check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) hotswap_adapter_from_state_dict( model=self, state_dict=state_dict, diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8cc000541c32..c9783d634fa8 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2276,9 +2276,11 @@ def check_hotswap(self, do_compile, rank0, rank1): with torch.inference_mode(): output1_before = unet(**dummy_input)["sample"] - # sanity check: + # sanity checks: tol = 5e-3 assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() with tempfile.TemporaryDirectory() as tmp_dirname: unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") @@ -2311,6 +2313,12 @@ def check_hotswap(self, do_compile, rank0, rank1): output1_after = unet(**dummy_input)["sample"] assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) + # check error when not passing valid adapter name + name = "does-not-exist" + msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" + with self.assertRaisesRegex(ValueError, msg): + unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) + @slow @require_torch_2 @require_torch_accelerator @@ -2328,3 +2336,98 @@ def test_hotswapping_compiled_diffusers_model(self, rank0, rank1): # It's important to add this context to raise an error on recompilation with torch._dynamo.config.patch(error_on_recompile=True): self.check_hotswap(do_compile=True, rank0=rank0, rank1=rank1) + + ############ + # PIPELINE # + ############ + + def get_lora_state_dicts(self, modules_to_save, adapter_name): + from peft import get_peft_model_state_dict + + state_dicts = {} + for module_name, module in modules_to_save.items(): + if module is not None: + state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict( + module, adapter_name=adapter_name + ) + return state_dicts + + def get_dummy_input_pipeline(self): + pipeline_inputs = { + "prompt": "A painting of a squirrel eating a burger", + "num_inference_steps": 5, + "guidance_scale": 6.0, + "output_type": "np", + "return_dict": False, + } + return pipeline_inputs + + def check_pipeline_hotswap(self, rank0, rank1): + # Similar to check_hotswap but more realistic: check a whole pipeline to be closer to how users would use it + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + dummy_input = self.get_dummy_input_pipeline() + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + lora_config0 = self.get_unet_lora_config(rank0, alpha0) + lora_config1 = self.get_unet_lora_config(rank1, alpha1) + + pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0") + output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] + + pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1") + pipeline.unet.set_adapter("adapter1") + output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] + + # sanity check + tol = 1e-3 + assert not np.allclose(output0_before, output1_before, atol=tol, rtol=tol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() + + with 
tempfile.TemporaryDirectory() as tmp_dirname: + lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0") + StableDiffusionPipeline.save_lora_weights( + save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts + ) + lora1_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter1") + StableDiffusionPipeline.save_lora_weights( + save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts + ) + del pipeline + + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors") + + pipeline.load_lora_weights(file_name0) + if rank0 != rank1: + prepare_model_for_compiled_hotswap( + pipeline.unet, + config={"adapter0": lora_config0, "adapter1": lora_config1}, + target_rank=max_rank, + ) + + pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead") + output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] + + # sanity check: still same result + assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol) + + pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0") + output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] + + # sanity check: since it's the same LoRA, the results should be identical + assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol) + + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + def test_hotswapping_compiled_diffusers_pipline(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_pipeline_hotswap(rank0=rank0, rank1=rank1) From bc157e6af3eec39c7ca1d37cb5406c1e346bd5d2 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 10 Feb 2025 14:24:23 +0100 Subject: [PATCH 10/36] Change order of test decorators parameterized.expand seems to ignore skip decorators if added in last place (i.e. innermost decorator). 
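For context, a minimal sketch of the decorator order this patch switches to (hypothetical test class; `unittest.skip` stands in for the `@slow`/`@require_*` decorators used in the real tests):

```py
import unittest

from parameterized import parameterized


class TestDecoratorOrder(unittest.TestCase):
    # parameterized.expand is listed first; per the commit message above, placing it
    # last (innermost) can cause the skip/requirement decorators to be ignored.
    @parameterized.expand([(11, 11), (7, 13), (13, 7)])
    @unittest.skip("stand-in for @slow / @require_torch_2 / @require_torch_accelerator")
    def test_ranks(self, rank0, rank1):
        self.assertIsInstance(rank0, int)
```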
--- tests/pipelines/test_pipelines.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index c9783d634fa8..1cfbd77fd3f4 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2319,19 +2319,19 @@ def check_hotswap(self, do_compile, rank0, rank1): with self.assertRaisesRegex(ValueError, msg): unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa @slow @require_torch_2 @require_torch_accelerator @require_peft_backend - @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_diffusers_model(self, rank0, rank1): self.check_hotswap(do_compile=False, rank0=rank0, rank1=rank1) + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa @slow @require_torch_2 @require_torch_accelerator @require_peft_backend - @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_diffusers_model(self, rank0, rank1): # It's important to add this context to raise an error on recompilation with torch._dynamo.config.patch(error_on_recompile=True): @@ -2422,11 +2422,11 @@ def check_pipeline_hotswap(self, rank0, rank1): # sanity check: since it's the same LoRA, the results should be identical assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol) + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa @slow @require_torch_2 @require_torch_accelerator @require_peft_backend - @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_diffusers_pipline(self, rank0, rank1): # It's important to add this context to raise an error on recompilation with torch._dynamo.config.patch(error_on_recompile=True): From bd1da66b4fd27ae7c8c9874cd8ede8613d91be06 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 11 Feb 2025 16:21:09 +0100 Subject: [PATCH 11/36] Split model and pipeline tests Also increase test coverage by also targeting conv2d layers (support of which was added recently on the PEFT PR). --- tests/models/test_modeling_common.py | 189 ++++++++++++++++++++++++++ tests/pipelines/test_pipelines.py | 191 +++++++-------------------- 2 files changed, 238 insertions(+), 142 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index c3cb082b0ef1..4e5d7d0b68ca 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -56,15 +56,19 @@ from diffusers.utils.hub_utils import _add_variant from diffusers.utils.testing_utils import ( CaptureLogger, + backend_empty_cache, + floats_tensor, get_python_version, is_torch_compile, numpy_cosine_similarity_distance, + require_peft_backend, require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, require_torch_gpu, require_torch_multi_gpu, run_test_in_subprocess, + slow, torch_all_close, torch_device, ) @@ -1519,3 +1523,188 @@ def test_push_to_hub_library_name(self): # Reset repo delete_repo(self.repo_id, token=TOKEN) + + +class TestLoraHotSwappingForModel(unittest.TestCase): + """Test that hotswapping does not result in recompilation on the model directly. 
+ + We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively + tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require + recompilation. + + See + https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252 + for the analogous PEFT test. + + """ + + def tearDown(self): + # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, + # there will be recompilation errors, as torch caches the model when run in the same process. + super().tearDown() + torch._dynamo.reset() + gc.collect() + backend_empty_cache(torch_device) + + def get_small_unet(self): + # from diffusers UNet2DConditionModelTests + torch.manual_seed(0) + init_dict = { + "block_out_channels": (4, 8), + "norm_num_groups": 4, + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 8, + "attention_head_dim": 2, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 16, + } + model = UNet2DConditionModel(**init_dict) + return model.to(torch_device) + + def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules): + # from diffusers test_models_unet_2d_condition.py + from peft import LoraConfig + + unet_lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=target_modules, + init_lora_weights=False, + use_dora=False, + ) + return unet_lora_config + + def get_dummy_input(self): + # from UNet2DConditionModelTests + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): + """ + Check that hotswapping works on a small unet. + + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. 
+ + """ + from peft.utils.hotswap import prepare_model_for_compiled_hotswap + + dummy_input = self.get_dummy_input() + alpha0, alpha1 = rank0, rank1 + max_rank = max([rank0, rank1]) + lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules) + lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules) + + unet = self.get_small_unet() + unet.add_adapter(lora_config0, adapter_name="adapter0") + with torch.inference_mode(): + output0_before = unet(**dummy_input)["sample"] + + unet.add_adapter(lora_config1, adapter_name="adapter1") + unet.set_adapter("adapter1") + with torch.inference_mode(): + output1_before = unet(**dummy_input)["sample"] + + # sanity checks: + tol = 5e-3 + assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) + assert not (output0_before == 0).all() + assert not (output1_before == 0).all() + + with tempfile.TemporaryDirectory() as tmp_dirname: + unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") + unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") + del unet + + unet = self.get_small_unet() + file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") + unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0") + + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + prepare_model_for_compiled_hotswap( + unet, + config={"adapter0": lora_config0, "adapter1": lora_config1}, + target_rank=max_rank, + ) + if do_compile: + unet = torch.compile(unet, mode="reduce-overhead") + + with torch.inference_mode(): + output0_after = unet(**dummy_input)["sample"] + assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) + + unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True) + + # we need to call forward to potentially trigger recompilation + with torch.inference_mode(): + output1_after = unet(**dummy_input)["sample"] + assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) + + # check error when not passing valid adapter name + name = "does-not-exist" + msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" + with self.assertRaisesRegex(ValueError, msg): + unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_model(self, rank0, rank1): + self.check_model_hotswap( + do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] + ) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_model_linear(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + 
@slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["conv", "conv1", "conv2"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "conv"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 1cfbd77fd3f4..145609b37f82 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2178,8 +2178,8 @@ def test_ddpm_ddim_equality_batched(self): assert np.abs(ddpm_images - ddim_images).max() < 1e-1 -class TestLoraHotSwapping(unittest.TestCase): - """Test that hotswapping does not result in recompilation. +class TestLoraHotSwappingForPipeline(unittest.TestCase): + """Test that hotswapping does not result in recompilation in a pipeline. We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively tested there. The goal of this test is specifically to ensure that hotswapping with diffusers does not require @@ -2199,148 +2199,19 @@ def tearDown(self): gc.collect() backend_empty_cache(torch_device) - def get_small_unet(self): - # from diffusers UNet2DConditionModelTests - torch.manual_seed(0) - init_dict = { - "block_out_channels": (4, 8), - "norm_num_groups": 4, - "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), - "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), - "cross_attention_dim": 8, - "attention_head_dim": 2, - "out_channels": 4, - "in_channels": 4, - "layers_per_block": 1, - "sample_size": 16, - } - model = UNet2DConditionModel(**init_dict) - return model.to(torch_device) - - def get_unet_lora_config(self, lora_rank, lora_alpha): + def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules): # from diffusers test_models_unet_2d_condition.py from peft import LoraConfig unet_lora_config = LoraConfig( r=lora_rank, lora_alpha=lora_alpha, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], + target_modules=target_modules, init_lora_weights=False, use_dora=False, ) return unet_lora_config - def get_dummy_input(self): - # from UNet2DConditionModelTests - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - - def check_hotswap(self, do_compile, rank0, rank1): - """ - Check that hotswapping works on a small unet. 
- - Steps: - - create 2 LoRA adapters and save them - - load the first adapter - - hotswap the second adapter - - check that the outputs are correct - - optionally compile the model - - Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would - fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is - fine. - - """ - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - - dummy_input = self.get_dummy_input() - alpha0, alpha1 = rank0, rank1 - max_rank = max([rank0, rank1]) - lora_config0 = self.get_unet_lora_config(rank0, alpha0) - lora_config1 = self.get_unet_lora_config(rank1, alpha1) - - unet = self.get_small_unet() - unet.add_adapter(lora_config0, adapter_name="adapter0") - with torch.inference_mode(): - output0_before = unet(**dummy_input)["sample"] - - unet.add_adapter(lora_config1, adapter_name="adapter1") - unet.set_adapter("adapter1") - with torch.inference_mode(): - output1_before = unet(**dummy_input)["sample"] - - # sanity checks: - tol = 5e-3 - assert not torch.allclose(output0_before, output1_before, atol=tol, rtol=tol) - assert not (output0_before == 0).all() - assert not (output1_before == 0).all() - - with tempfile.TemporaryDirectory() as tmp_dirname: - unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") - unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") - del unet - - unet = self.get_small_unet() - file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") - file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") - unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0") - - if do_compile or (rank0 != rank1): - # no need to prepare if the model is not compiled or if the ranks are identical - prepare_model_for_compiled_hotswap( - unet, - config={"adapter0": lora_config0, "adapter1": lora_config1}, - target_rank=max_rank, - ) - if do_compile: - unet = torch.compile(unet, mode="reduce-overhead") - - with torch.inference_mode(): - output0_after = unet(**dummy_input)["sample"] - assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) - - unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True) - - # we need to call forward to potentially trigger recompilation - with torch.inference_mode(): - output1_after = unet(**dummy_input)["sample"] - assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) - - # check error when not passing valid adapter name - name = "does-not-exist" - msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" - with self.assertRaisesRegex(ValueError, msg): - unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) - - @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend - def test_hotswapping_diffusers_model(self, rank0, rank1): - self.check_hotswap(do_compile=False, rank0=rank0, rank1=rank1) - - @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend - def test_hotswapping_compiled_diffusers_model(self, rank0, rank1): - # It's important to add this context to raise an error on 
recompilation - with torch._dynamo.config.patch(error_on_recompile=True): - self.check_hotswap(do_compile=True, rank0=rank0, rank1=rank1) - - ############ - # PIPELINE # - ############ - def get_lora_state_dicts(self, modules_to_save, adapter_name): from peft import get_peft_model_state_dict @@ -2352,7 +2223,7 @@ def get_lora_state_dicts(self, modules_to_save, adapter_name): ) return state_dicts - def get_dummy_input_pipeline(self): + def get_dummy_input(self): pipeline_inputs = { "prompt": "A painting of a squirrel eating a burger", "num_inference_steps": 5, @@ -2362,21 +2233,23 @@ def get_dummy_input_pipeline(self): } return pipeline_inputs - def check_pipeline_hotswap(self, rank0, rank1): + def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): # Similar to check_hotswap but more realistic: check a whole pipeline to be closer to how users would use it from peft.utils.hotswap import prepare_model_for_compiled_hotswap - dummy_input = self.get_dummy_input_pipeline() + dummy_input = self.get_dummy_input() pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) - lora_config0 = self.get_unet_lora_config(rank0, alpha0) - lora_config1 = self.get_unet_lora_config(rank1, alpha1) + lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules) + lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules) + torch.manual_seed(0) pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0") output0_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] + torch.manual_seed(1) pipeline.unet.add_adapter(lora_config1, adapter_name="adapter1") pipeline.unet.set_adapter("adapter1") output1_before = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] @@ -2403,14 +2276,15 @@ def check_pipeline_hotswap(self, rank0, rank1): file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors") pipeline.load_lora_weights(file_name0) - if rank0 != rank1: + if do_compile or (rank0 != rank1): prepare_model_for_compiled_hotswap( pipeline.unet, config={"adapter0": lora_config0, "adapter1": lora_config1}, target_rank=max_rank, ) + if do_compile: + pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead") - pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead") output0_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] # sanity check: still same result @@ -2427,7 +2301,40 @@ def check_pipeline_hotswap(self, rank0, rank1): @require_torch_2 @require_torch_accelerator @require_peft_backend - def test_hotswapping_compiled_diffusers_pipline(self, rank0, rank1): + def test_hotswapping_pipeline(self, rank0, rank1): + self.check_pipeline_hotswap( + do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] + ) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_pipline_linear(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + 
@require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1): + # It's important to add this context to raise an error on recompilation + target_modules = ["conv", "conv1", "conv2"] + with torch._dynamo.config.patch(error_on_recompile=True): + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa + @slow + @require_torch_2 + @require_torch_accelerator + @require_peft_backend + def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation + target_modules = ["to_q", "conv"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_pipeline_hotswap(rank0=rank0, rank1=rank1) + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) From 119a8edbab6309b3ed6bd5dcf8b235c1e68bfbc6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Feb 2025 11:42:46 +0100 Subject: [PATCH 12/36] Reviewer feedback: Move decorator to test classes ... instead of having them on each test method. --- tests/models/test_modeling_common.py | 21 +++++---------------- tests/pipelines/test_pipelines.py | 21 +++++---------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4e5d7d0b68ca..d40ae3c56dc1 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1525,6 +1525,11 @@ def test_push_to_hub_library_name(self): delete_repo(self.repo_id, token=TOKEN) +@slow +@require_torch_2 +@require_torch_accelerator +@require_peft_backend +@is_torch_compile class TestLoraHotSwappingForModel(unittest.TestCase): """Test that hotswapping does not result in recompilation on the model directly. 
@@ -1667,20 +1672,12 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_model(self, rank0, rank1): self.check_model_hotswap( do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] ) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_model_linear(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "to_k", "to_v", "to_out.0"] @@ -1688,10 +1685,6 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1): self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["conv", "conv1", "conv2"] @@ -1699,10 +1692,6 @@ def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "conv"] diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 145609b37f82..eb6a7fda4873 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2178,6 +2178,11 @@ def test_ddpm_ddim_equality_batched(self): assert np.abs(ddpm_images - ddim_images).max() < 1e-1 +@slow +@require_torch_2 +@require_torch_accelerator +@require_peft_backend +@is_torch_compile class TestLoraHotSwappingForPipeline(unittest.TestCase): """Test that hotswapping does not result in recompilation in a pipeline. 
@@ -2297,20 +2302,12 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): assert np.allclose(output1_before, output1_after, atol=tol, rtol=tol) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_pipeline(self, rank0, rank1): self.check_pipeline_hotswap( do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] ) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_pipline_linear(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "to_k", "to_v", "to_out.0"] @@ -2318,10 +2315,6 @@ def test_hotswapping_compiled_pipline_linear(self, rank0, rank1): self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["conv", "conv1", "conv2"] @@ -2329,10 +2322,6 @@ def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1): self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa - @slow - @require_torch_2 - @require_torch_accelerator - @require_peft_backend def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "conv"] From a715559f16842945db36c4608149e1b841257cf9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 13 Feb 2025 12:07:47 +0100 Subject: [PATCH 13/36] Apply suggestions from code review Co-authored-by: hlky --- src/diffusers/loaders/unet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 8bd7eb16f1c5..14dc369180a5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -394,9 +394,9 @@ def map_state_dict_for_hotswap(sd): new_sd = {} for k, v in sd.items(): if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): - k = k[:-7] + f".{adapter_name}.weight" + k = k[: -len(".weight")] + f".{adapter_name}.weight" elif k.endswith("lora_B.bias"): # lora_bias=True option - k = k[:-5] + f".{adapter_name}.bias" + k = k[: -len(".bias")] + f".{adapter_name}.bias" new_sd[k] = v return new_sd From e40390d98596f773a1f9812375ac8a2cabbdc42f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 13 Feb 2025 14:23:53 +0100 Subject: [PATCH 14/36] Reviewer feedback: version check, TODO comment --- src/diffusers/loaders/peft.py | 6 +++--- src/diffusers/loaders/unet.py | 8 +++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 84ead88eabfa..11d4a5c87ed9 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -326,14 +326,14 @@ def load_lora_adapter( peft_kwargs["low_cpu_mem_usage"] = 
low_cpu_mem_usage if hotswap: - try: + if is_peft_version(">", "0.14.0"): from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict - except ImportError as exc: + else: msg = ( "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " "from source." ) - raise ImportError(msg) from exc + raise ImportError(msg) if hotswap: diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 14dc369180a5..2f2b2f6329eb 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -378,14 +378,14 @@ def _process_lora( peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage if hotswap: - try: + if is_peft_version(">", "0.14.0"): from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict - except ImportError as exc: + else: msg = ( "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " "from source." ) - raise ImportError(msg) from exc + raise ImportError(msg) if hotswap: @@ -418,6 +418,8 @@ def map_state_dict_for_hotswap(sd): inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) except Exception as e: + # TODO: add test in line with: + # https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552 # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. if hasattr(self, "peft_config"): for module in self.modules(): From 1b834ecfef933565d278d7554959725ae72f8e06 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 14 Feb 2025 12:23:34 +0100 Subject: [PATCH 15/36] Add enable_lora_hotswap method --- src/diffusers/loaders/lora_base.py | 14 ++++++++++ src/diffusers/loaders/peft.py | 37 ++++++++++++++++++++++-- src/diffusers/loaders/unet.py | 39 ++++++++++++++++++++++++-- tests/models/test_modeling_common.py | 27 +++++++++++------- tests/pipelines/test_pipelines.py | 42 +++++++++++++++++++++------- 5 files changed, 134 insertions(+), 25 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 0c584777affc..e195894b30be 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -898,3 +898,17 @@ def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. # if _lora_scale has not been set, return 1 return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + def enable_lora_hotswap(self, **kwargs) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`): + The highest rank among all the adapters that will be loaded. 
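As an illustration of where this pipeline-level method fits, here is a sketch of the intended call order, mirroring the pipeline test added in this PR (the checkpoint id comes from that test; the LoRA file paths and target rank are placeholders):

```py
import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to("cuda")
pipe.enable_lora_hotswap(target_rank=8)  # highest rank among the adapters that will be loaded
pipe.load_lora_weights("adapter0/pytorch_lora_weights.safetensors")  # placeholder path
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")

# later: replace the first adapter in place without triggering recompilation
pipe.load_lora_weights("adapter1/pytorch_lora_weights.safetensors", hotswap=True, adapter_name="default_0")
```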
+ """ + for component in self.components.values(): + if hasattr(component, "enable_lora_hotswap"): + component.enable_lora_hotswap(**kwargs) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 11d4a5c87ed9..c1774b51328f 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -121,6 +121,8 @@ class PeftAdapterMixin: """ _hf_peft_config_loaded = False + # kwargs for prepare_model_for_compiled_hotswap, if required + _prepare_lora_hotswap_kwargs: Optional[dict] = None @classmethod # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading @@ -325,9 +327,13 @@ def load_lora_adapter( if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - if hotswap: + if hotswap or (self._prepare_lora_hotswap_kwargs is not None): if is_peft_version(">", "0.14.0"): - from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + from peft.utils.hotswap import ( + check_hotswap_configs_compatible, + hotswap_adapter_from_state_dict, + prepare_model_for_compiled_hotswap, + ) else: msg = ( "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " @@ -366,6 +372,19 @@ def map_state_dict_for_hotswap(sd): else: inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + if self._prepare_lora_hotswap_kwargs is not None: + # For hotswapping of compiled models or adapters with different ranks. + # If the user called enable_lora_hotswap, we need to ensure it is called: + # - after the first adapter was loaded + # - before the model is compiled and the 2nd adapter is being hotswapped in + # Therefore, it needs to be called here + prepare_model_for_compiled_hotswap( + self, config=lora_config, **self._prepare_lora_hotswap_kwargs + ) + # We only want to call prepare_model_for_compiled_hotswap once + self._prepare_lora_hotswap_kwargs = None + except Exception as e: # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. if hasattr(self, "peft_config"): @@ -816,3 +835,17 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): # Pop also the corresponding adapter from the config if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) + + def enable_lora_hotswap(self, target_rank: int) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`): + The highest rank among all the adapters that will be loaded. 
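For the model-level API, a sketch of the corresponding call order, based on the model test in this PR (assuming the UNet can be loaded from the `unet` subfolder of that test repo; LoRA file paths and target rank are placeholders):

```py
import torch

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("hf-internal-testing/tiny-sd-pipe", subfolder="unet").to("cuda")
unet.enable_lora_hotswap(target_rank=8)  # must be called before the first adapter is loaded
unet.load_lora_adapter("adapter0/pytorch_lora_weights.safetensors", adapter_name="adapter0")
unet = torch.compile(unet, mode="reduce-overhead")

# hotswap the second adapter into the first adapter's slot; no recompilation is triggered
unet.load_lora_adapter("adapter1/pytorch_lora_weights.safetensors", adapter_name="adapter0", hotswap=True)
```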
+ """ + if getattr(self, "peft_config", {}): + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank} diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 2f2b2f6329eb..7ffa3b350681 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -15,7 +15,7 @@ from collections import defaultdict from contextlib import nullcontext from pathlib import Path -from typing import Callable, Dict, Union +from typing import Callable, Dict, Optional, Union import safetensors import torch @@ -62,6 +62,8 @@ class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME + # kwargs for prepare_model_for_compiled_hotswap, if required + _prepare_lora_hotswap_kwargs: Optional[dict] = None @validate_hf_hub_args def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): @@ -377,9 +379,13 @@ def _process_lora( if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - if hotswap: + if hotswap or (self._prepare_lora_hotswap_kwargs is not None): if is_peft_version(">", "0.14.0"): - from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict + from peft.utils.hotswap import ( + check_hotswap_configs_compatible, + hotswap_adapter_from_state_dict, + prepare_model_for_compiled_hotswap, + ) else: msg = ( "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " @@ -417,6 +423,19 @@ def map_state_dict_for_hotswap(sd): else: inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) + + if self._prepare_lora_hotswap_kwargs is not None: + # For hotswapping of compiled models or adapters with different ranks. + # If the user called enable_lora_hotswap, we need to ensure it is called: + # - after the first adapter was loaded + # - before the model is compiled and the 2nd adapter is being hotswapped in + # Therefore, it needs to be called here + prepare_model_for_compiled_hotswap( + self, config=lora_config, **self._prepare_lora_hotswap_kwargs + ) + # We only want to call prepare_model_for_compiled_hotswap once + self._prepare_lora_hotswap_kwargs = None + except Exception as e: # TODO: add test in line with: # https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552 @@ -1002,3 +1021,17 @@ def _load_ip_adapter_loras(self, state_dicts): } ) return lora_dicts + + def enable_lora_hotswap(self, target_rank: int) -> None: + """Enables the possibility to hotswap LoRA adapters. + + Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of + the loaded adapters differ. + + Args: + target_rank (`int`): + The highest rank among all the adapters that will be loaded. 
+ """ + if getattr(self, "peft_config", {}): + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank} diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index df21d7bbd88c..c504abb86f83 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1638,10 +1638,8 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is fine. - """ - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - + # create 2 adapters with different ranks and alphas dummy_input = self.get_dummy_input() alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) @@ -1665,22 +1663,21 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): assert not (output1_before == 0).all() with tempfile.TemporaryDirectory() as tmp_dirname: + # save the adapter checkpoints unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0") unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1") del unet + # load the first adapter unet = self.get_small_unet() + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + unet.enable_lora_hotswap(target_rank=max_rank) + file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0") - if do_compile or (rank0 != rank1): - # no need to prepare if the model is not compiled or if the ranks are identical - prepare_model_for_compiled_hotswap( - unet, - config={"adapter0": lora_config0, "adapter1": lora_config1}, - target_rank=max_rank, - ) if do_compile: unet = torch.compile(unet, mode="reduce-overhead") @@ -1688,6 +1685,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): output0_after = unet(**dummy_input)["sample"] assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) + # hotswap the 2nd adapter unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True) # we need to call forward to potentially trigger recompilation @@ -1727,3 +1725,12 @@ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): target_modules = ["to_q", "conv"] with torch._dynamo.config.patch(error_on_recompile=True): self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + def test_enable_lora_hotswap_called_too_late_raises(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + unet = self.get_small_unet() + unet.add_adapter(lora_config) + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") + with self.assertRaisesRegex(RuntimeError, msg): + unet.enable_lora_hotswap(target_rank=32) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index eb6a7fda4873..ce7c60848e88 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ 
-17,6 +17,7 @@ import json import os import random +import re import shutil import sys import tempfile @@ -2239,12 +2240,23 @@ def get_dummy_input(self): return pipeline_inputs def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): - # Similar to check_hotswap but more realistic: check a whole pipeline to be closer to how users would use it - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - + """ + Check that hotswapping works on a pipeline. + + Steps: + - create 2 LoRA adapters and save them + - load the first adapter + - hotswap the second adapter + - check that the outputs are correct + - optionally compile the model + + Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would + fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is + fine. + """ + # create 2 adapters with different ranks and alphas dummy_input = self.get_dummy_input() pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) - alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules) @@ -2266,6 +2278,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): assert not (output1_before == 0).all() with tempfile.TemporaryDirectory() as tmp_dirname: + # save the adapter checkpoints lora0_state_dicts = self.get_lora_state_dicts({"unet": pipeline.unet}, adapter_name="adapter0") StableDiffusionPipeline.save_lora_weights( save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts @@ -2276,17 +2289,16 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): ) del pipeline + # load the first adapter pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + if do_compile or (rank0 != rank1): + # no need to prepare if the model is not compiled or if the ranks are identical + pipeline.enable_lora_hotswap(target_rank=max_rank) + file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors") file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors") pipeline.load_lora_weights(file_name0) - if do_compile or (rank0 != rank1): - prepare_model_for_compiled_hotswap( - pipeline.unet, - config={"adapter0": lora_config0, "adapter1": lora_config1}, - target_rank=max_rank, - ) if do_compile: pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead") @@ -2295,6 +2307,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): # sanity check: still same result assert np.allclose(output0_before, output0_after, atol=tol, rtol=tol) + # hotswap the 2nd adapter pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0") output1_after = pipeline(**dummy_input, generator=torch.manual_seed(0))[0] @@ -2327,3 +2340,12 @@ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1) target_modules = ["to_q", "conv"] with torch._dynamo.config.patch(error_on_recompile=True): self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + + def test_enable_lora_hotswap_called_too_late_raises(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + pipeline = 
StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + pipeline.unet.add_adapter(lora_config) + msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") + with self.assertRaisesRegex(RuntimeError, msg): + pipeline.enable_lora_hotswap(target_rank=32) From 2cd366548217db020dc5cbc90caf33f27ade3995 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 17 Feb 2025 17:49:32 +0100 Subject: [PATCH 16/36] Reviewer feedback: check _lora_loadable_modules --- src/diffusers/loaders/lora_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index e195894b30be..e09b4f4d0d9e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -909,6 +909,6 @@ def enable_lora_hotswap(self, **kwargs) -> None: target_rank (`int`): The highest rank among all the adapters that will be loaded. """ - for component in self.components.values(): - if hasattr(component, "enable_lora_hotswap"): + for key, component in self.components.items(): + if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): component.enable_lora_hotswap(**kwargs) From e735ac26c743b7701243e28fd0ed29887c255b18 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 18 Feb 2025 16:55:25 +0100 Subject: [PATCH 17/36] Revert changes in unet.py --- src/diffusers/loaders/unet.py | 109 ++-------------------------------- 1 file changed, 5 insertions(+), 104 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7ffa3b350681..c68349c36dba 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -15,7 +15,7 @@ from collections import defaultdict from contextlib import nullcontext from pathlib import Path -from typing import Callable, Dict, Optional, Union +from typing import Callable, Dict, Union import safetensors import torch @@ -62,8 +62,6 @@ class UNet2DConditionLoadersMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME - # kwargs for prepare_model_for_compiled_hotswap, if required - _prepare_lora_hotswap_kwargs: Optional[dict] = None @validate_hf_hub_args def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): @@ -283,14 +281,7 @@ def _process_custom_diffusion(self, state_dict): return attn_processors def _process_lora( - self, - state_dict, - unet_identifier_key, - network_alphas, - adapter_name, - _pipeline, - low_cpu_mem_usage, - hotswap: bool = False, + self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage ): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy @@ -303,7 +294,6 @@ def _process_lora( raise ValueError("PEFT backend is required for this method.") from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict - from peft.tuners.tuners_utils import BaseTunerLayer keys = list(state_dict.keys()) @@ -323,15 +313,10 @@ def _process_lora( state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict if len(state_dict_to_be_used) > 0: - if adapter_name in getattr(self, "peft_config", {}) and not hotswap: + if adapter_name in getattr(self, "peft_config", {}): raise ValueError( f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name." 
) - elif adapter_name not in getattr(self, "peft_config", {}) and hotswap: - raise ValueError( - f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. " - "Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping." - ) state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used) @@ -379,78 +364,8 @@ def _process_lora( if is_peft_version(">=", "0.13.1"): peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - if hotswap or (self._prepare_lora_hotswap_kwargs is not None): - if is_peft_version(">", "0.14.0"): - from peft.utils.hotswap import ( - check_hotswap_configs_compatible, - hotswap_adapter_from_state_dict, - prepare_model_for_compiled_hotswap, - ) - else: - msg = ( - "Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it " - "from source." - ) - raise ImportError(msg) - - if hotswap: - - def map_state_dict_for_hotswap(sd): - # For hotswapping, we need the adapter name to be present in the state dict keys - new_sd = {} - for k, v in sd.items(): - if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"): - k = k[: -len(".weight")] + f".{adapter_name}.weight" - elif k.endswith("lora_B.bias"): # lora_bias=True option - k = k[: -len(".bias")] + f".{adapter_name}.bias" - new_sd[k] = v - return new_sd - - # To handle scenarios where we cannot successfully set state dict. If it's unsucessful, - # we should also delete the `peft_config` associated to the `adapter_name`. - try: - if hotswap: - check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) - hotswap_adapter_from_state_dict( - model=self, - state_dict=state_dict, - adapter_name=adapter_name, - config=lora_config, - ) - # the hotswap function raises if there are incompatible keys, so if we reach this point we can set - # it to None - incompatible_keys = None - else: - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) - - if self._prepare_lora_hotswap_kwargs is not None: - # For hotswapping of compiled models or adapters with different ranks. - # If the user called enable_lora_hotswap, we need to ensure it is called: - # - after the first adapter was loaded - # - before the model is compiled and the 2nd adapter is being hotswapped in - # Therefore, it needs to be called here - prepare_model_for_compiled_hotswap( - self, config=lora_config, **self._prepare_lora_hotswap_kwargs - ) - # We only want to call prepare_model_for_compiled_hotswap once - self._prepare_lora_hotswap_kwargs = None - - except Exception as e: - # TODO: add test in line with: - # https://github.com/huggingface/diffusers/pull/10188/files#diff-b544edcc938e163009735ef4fa963abd0a41615c175552160c9e0f94ceb7f552 - # In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`. 
- if hasattr(self, "peft_config"): - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - active_adapters = module.active_adapters - for active_adapter in active_adapters: - if adapter_name in active_adapter: - module.delete_adapter(adapter_name) - - self.peft_config.pop(adapter_name) - logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}") - raise + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) warn_msg = "" if incompatible_keys is not None: @@ -1021,17 +936,3 @@ def _load_ip_adapter_loras(self, state_dicts): } ) return lora_dicts - - def enable_lora_hotswap(self, target_rank: int) -> None: - """Enables the possibility to hotswap LoRA adapters. - - Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of - the loaded adapters differ. - - Args: - target_rank (`int`): - The highest rank among all the adapters that will be loaded. - """ - if getattr(self, "peft_config", {}): - raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") - self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank} From 3a6677ce4e5de03fc2d0296e29880dbfc8ed9469 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 21 Feb 2025 16:39:00 +0100 Subject: [PATCH 18/36] Add possibility to ignore enabled at wrong time --- src/diffusers/loaders/lora_base.py | 6 ++++ src/diffusers/loaders/lora_pipeline.py | 33 ++++++++---------- src/diffusers/loaders/peft.py | 48 ++++++++++++++++++-------- tests/models/test_modeling_common.py | 36 ++++++++++++++++++- tests/pipelines/test_pipelines.py | 36 ++++++++++++++++++- 5 files changed, 124 insertions(+), 35 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 6b53eb535ee1..02be690565ea 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -914,6 +914,12 @@ def enable_lora_hotswap(self, **kwargs) -> None: Args: target_rank (`int`): The highest rank among all the adapters that will be loaded. + check_correct (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing """ for key, component in self.components.items(): if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules): diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 834086daef53..a99fc08b4600 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -118,13 +118,12 @@ def load_lora_weights( to call an additional method before loading the adapter: ```py - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - - model = ... # load diffusers model with first LoRA adapter + pipeline = ... # load diffusers pipeline max_rank = ... 
# the highest rank among all LoRAs that you want to load - prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling - model = torch.compile(model) - model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now ``` There are some limitations to this technique, which are documented here: @@ -330,13 +329,12 @@ def load_lora_into_unet( to call an additional method before loading the adapter: ```py - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - - model = ... # load diffusers model with first LoRA adapter + pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load - prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling - model = torch.compile(model) - model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now ``` There are some limitations to this technique, which are documented here: @@ -800,13 +798,12 @@ def load_lora_into_unet( to call an additional method before loading the adapter: ```py - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - - model = ... # load diffusers model with first LoRA adapter + pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load - prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling - model = torch.compile(model) - model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now ``` There are some limitations to this technique, which are documented here: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 36e57bd8c59f..e294aef6e9d5 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -16,7 +16,7 @@ import os from functools import partial from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Optional, Union import safetensors import torch @@ -144,8 +144,7 @@ def _optionally_disable_offloading(cls, _pipeline): def load_lora_adapter( self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs ): - r""" - Loads a LoRA adapter into the underlying model. + r"""Loads a LoRA adapter into the underlying model. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -194,21 +193,21 @@ def load_lora_adapter( However, the main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new adapter does not require recompilation of the model. - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: + If the model is compiled, or if the new adapter and the old adapter have different ranks and/or LoRA + alphas (i.e. 
scaling), you need to call an additional method before loading the adapter: ```py - from peft.utils.hotswap import prepare_model_for_compiled_hotswap - - model = ... # load diffusers model with first LoRA adapter + pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load - prepare_model_for_compiled_hotswap(model, target_rank=max_rank) # call *before* compiling - model = torch.compile(model) - model.load_lora_adapter(..., hotswap=True) # now hotswap the 2nd adapter + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now ``` There are some limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap + """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer @@ -837,16 +836,35 @@ def delete_adapters(self, adapter_names: Union[List[str], str]): if hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) - def enable_lora_hotswap(self, target_rank: int) -> None: + def enable_lora_hotswap( + self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error" + ) -> None: """Enables the possibility to hotswap LoRA adapters. Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of the loaded adapters differ. Args: - target_rank (`int`): + target_rank (`int`, *optional*, defaults to `128`): The highest rank among all the adapters that will be loaded. + + check_correct (`str`, *optional*, defaults to `"error"`): + How to handle the case when the model is already compiled, which should generally be avoided. The + options are: + - "error" (default): raise an error + - "warn": issue a warning + - "ignore": do nothing """ if getattr(self, "peft_config", {}): - raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") - self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank} + if check_compiled == "error": + raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.") + elif check_compiled == "warn": + logger.warning( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + elif check_compiled != "ignore": + raise ValueError( + f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead." 
+ ) + + self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled} diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index c38fbe6adf50..0a02b40b7128 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -24,6 +24,7 @@ import unittest import unittest.mock as mock import uuid +import warnings from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -1827,7 +1828,7 @@ def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): with torch._dynamo.config.patch(error_on_recompile=True): self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) - def test_enable_lora_hotswap_called_too_late_raises(self): + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): # ensure that enable_lora_hotswap is called before loading the first adapter lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) unet = self.get_small_unet() @@ -1835,3 +1836,36 @@ def test_enable_lora_hotswap_called_too_late_raises(self): msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") with self.assertRaisesRegex(RuntimeError, msg): unet.enable_lora_hotswap(target_rank=32) + + def test_enable_lora_hotswap_called_after_adapter_added_warning(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + from diffusers.loaders.peft import logger + + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + unet = self.get_small_unet() + unet.add_adapter(lora_config) + msg = ( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + with self.assertLogs(logger=logger, level="WARNING") as cm: + unet.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in log for log in cm.output) + + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): + # check possibility to ignore the error/warning + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + unet = self.get_small_unet() + unet.add_adapter(lora_config) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Capture all warnings + unet.enable_lora_hotswap(target_rank=32, check_compiled="warn") + self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") + + def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): + # check that wrong argument value raises an error + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + unet = self.get_small_unet() + unet.add_adapter(lora_config) + msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") + with self.assertRaisesRegex(ValueError, msg): + unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index ce7c60848e88..d4971fb5586f 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -24,6 +24,7 @@ import traceback import unittest import unittest.mock as mock +import warnings import numpy as np import PIL.Image @@ -2341,7 +2342,7 @@ def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1) with torch._dynamo.config.patch(error_on_recompile=True): self.check_pipeline_hotswap(do_compile=True, rank0=rank0, 
rank1=rank1, target_modules=target_modules) - def test_enable_lora_hotswap_called_too_late_raises(self): + def test_enable_lora_hotswap_called_after_adapter_added_raises(self): # ensure that enable_lora_hotswap is called before loading the first adapter lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) @@ -2349,3 +2350,36 @@ def test_enable_lora_hotswap_called_too_late_raises(self): msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.") with self.assertRaisesRegex(RuntimeError, msg): pipeline.enable_lora_hotswap(target_rank=32) + + def test_enable_lora_hotswap_called_after_adapter_added_warns(self): + # ensure that enable_lora_hotswap is called before loading the first adapter + from diffusers.loaders.peft import logger + + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + pipeline.unet.add_adapter(lora_config) + msg = ( + "It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation." + ) + with self.assertLogs(logger=logger, level="WARNING") as cm: + pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn") + assert any(msg in log for log in cm.output) + + def test_enable_lora_hotswap_called_after_adapter_added_ignore(self): + # check possibility to ignore the error/warning + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + pipeline.unet.add_adapter(lora_config) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") # Capture all warnings + pipeline.enable_lora_hotswap(target_rank=32, check_compiled="warn") + self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}") + + def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): + # check that wrong argument value raises an error + lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"]) + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + pipeline.unet.add_adapter(lora_config) + msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") + with self.assertRaisesRegex(ValueError, msg): + pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") From a96f3fd1c4539d4fdf1c0cc2709f2a714b42e1b8 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 21 Feb 2025 16:40:39 +0100 Subject: [PATCH 19/36] Fix docstrings --- src/diffusers/loaders/lora_base.py | 2 +- src/diffusers/loaders/peft.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 02be690565ea..a4c80d385116 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -914,7 +914,7 @@ def enable_lora_hotswap(self, **kwargs) -> None: Args: target_rank (`int`): The highest rank among all the adapters that will be loaded. - check_correct (`str`, *optional*, defaults to `"error"`): + check_compiled (`str`, *optional*, defaults to `"error"`): How to handle the case when the model is already compiled, which should generally be avoided. 
The options are: - "error" (default): raise an error diff --git a/src/diffusers/loaders/peft.py index e294aef6e9d5..da550aadcc5b 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -848,7 +848,7 @@ def enable_lora_hotswap( target_rank (`int`, *optional*, defaults to `128`): The highest rank among all the adapters that will be loaded. - check_correct (`str`, *optional*, defaults to `"error"`): + check_compiled (`str`, *optional*, defaults to `"error"`): How to handle the case when the model is already compiled, which should generally be avoided. The options are: - "error" (default): raise an error From 2c6b435caef34e89f3fb166b91c9f9903881039b Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 16:41:52 +0100 Subject: [PATCH 20/36] Log possible PEFT error, test --- src/diffusers/loaders/peft.py | 16 +++++++++------ tests/models/test_modeling_common.py | 30 +++++++++++++++++++++------- tests/pipelines/test_pipelines.py | 30 +++++++++++++++++++++------- 3 files changed, 56 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index 4d7e4c8d986b..ab5f431b3a9b 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -377,12 +377,16 @@ def map_state_dict_for_hotswap(sd): if hotswap: state_dict = map_state_dict_for_hotswap(state_dict) check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config) - hotswap_adapter_from_state_dict( - model=self, - state_dict=state_dict, - adapter_name=adapter_name, - config=lora_config, - ) + try: + hotswap_adapter_from_state_dict( + model=self, + state_dict=state_dict, + adapter_name=adapter_name, + config=lora_config, + ) + except Exception as e: + logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}") + raise # the hotswap function raises if there are incompatible keys, so if we reach this point we can set # it to None incompatible_keys = None diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 923b4770ac50..28628a433850 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1725,7 +1725,7 @@ def get_dummy_input(self): return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} - def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): + def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None): """ Check that hotswapping works on a small unet. 
@@ -1744,8 +1744,10 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): dummy_input = self.get_dummy_input() alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) - lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules) - lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules) + if target_modules1 is None: + target_modules1 = target_modules0[:] + lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1) unet = self.get_small_unet() unet.add_adapter(lora_config0, adapter_name="adapter0") @@ -1803,7 +1805,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules): @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_model(self, rank0, rank1): self.check_model_hotswap( - do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] + do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"] ) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa @@ -1811,21 +1813,21 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "to_k", "to_v", "to_out.0"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_model_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["conv", "conv1", "conv2"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "conv"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) def test_enable_lora_hotswap_called_after_adapter_added_raises(self): # ensure that enable_lora_hotswap is called before loading the first adapter @@ -1868,3 +1870,17 @@ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") with self.assertRaisesRegex(ValueError, msg): unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + + def test_hotswap_second_adapter_targets_more_layers_raises(self): + # check the error and log + from diffusers.loaders.peft import logger + + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers + target_modules0 = ["to_q"] + target_modules1 = ["to_q", 
"to_k"] + with self.assertRaises(RuntimeError): # peft raises RuntimeError + with self.assertLogs(logger=logger, level="ERROR") as cm: + self.check_model_hotswap( + do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 + ) + assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index d4971fb5586f..8940cddffdb2 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2240,7 +2240,7 @@ def get_dummy_input(self): } return pipeline_inputs - def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): + def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None): """ Check that hotswapping works on a pipeline. @@ -2260,8 +2260,10 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) alpha0, alpha1 = rank0, rank1 max_rank = max([rank0, rank1]) - lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules) - lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules) + if target_modules1 is None: + target_modules1 = target_modules0[:] + lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0) + lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1) torch.manual_seed(0) pipeline.unet.add_adapter(lora_config0, adapter_name="adapter0") @@ -2318,7 +2320,7 @@ def check_pipeline_hotswap(self, do_compile, rank0, rank1, target_modules): @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_pipeline(self, rank0, rank1): self.check_pipeline_hotswap( - do_compile=False, rank0=rank0, rank1=rank1, target_modules=["to_q", "to_k", "to_v", "to_out.0"] + do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"] ) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa @@ -2326,21 +2328,21 @@ def test_hotswapping_compiled_pipline_linear(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "to_k", "to_v", "to_out.0"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_pipline_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["conv", "conv1", "conv2"] with torch._dynamo.config.patch(error_on_recompile=True): - self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_compiled_pipline_both_linear_and_conv2d(self, rank0, rank1): # It's important to add this context to raise an error on recompilation target_modules = ["to_q", "conv"] with torch._dynamo.config.patch(error_on_recompile=True): - 
self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules=target_modules) + self.check_pipeline_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules) def test_enable_lora_hotswap_called_after_adapter_added_raises(self): # ensure that enable_lora_hotswap is called before loading the first adapter @@ -2383,3 +2385,17 @@ def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self): msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.") with self.assertRaisesRegex(ValueError, msg): pipeline.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument") + + def test_hotswap_second_adapter_targets_more_layers_raises(self): + # check the error and log + from diffusers.loaders.peft import logger + + # at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers + target_modules0 = ["to_q"] + target_modules1 = ["to_q", "to_k"] + with self.assertRaises(RuntimeError): # peft raises RuntimeError + with self.assertLogs(logger=logger, level="ERROR") as cm: + self.check_pipeline_hotswap( + do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 + ) + assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) From ccb45f795ac4b3c1f2f8e8f62d54820ef1a9808f Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 16:57:17 +0100 Subject: [PATCH 21/36] Raise helpful error if hotswap not supported I.e. for the text encoder --- src/diffusers/loaders/lora_base.py | 4 +++ src/diffusers/loaders/lora_pipeline.py | 21 ++++++++++++++++ tests/pipelines/test_pipelines.py | 34 ++++++++++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index a4c80d385116..cc9ca5d9040e 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -316,6 +316,7 @@ def _load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -344,6 +345,9 @@ def _load_lora_into_text_encoder( # Safe prefix to check with. if any(text_encoder_name in key for key in keys): + if hotswap: + raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") + # Load the layers corresponding to text encoder and make necessary adjustments. text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] text_encoder_lora_state_dict = { diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a99fc08b4600..825ae1315e18 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -170,6 +170,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -377,6 +378,7 @@ def load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -402,6 +404,24 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. 
Whether to substitute an existing adapter with the newly loaded adapter in-place. + This means that, instead of loading an additional adapter, this will take the existing adapter weights + and replace them with the weights of the new adapter. This can be faster and more memory efficient. + However, the main advantage of hotswapping is that when the model is compiled with torch.compile, + loading the new adapter does not require recompilation of the model. + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + There are some limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -413,6 +433,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8940cddffdb2..97b16ed7b811 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2399,3 +2399,37 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self): do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 ) assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) + + def test_hotswap_component_not_supported_raises(self): + # right now, some components don't support hotswapping, e.g. 
the text_encoder + from peft import LoraConfig + + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + max_rank = 8 + lora_config0 = LoraConfig(target_modules=["q_proj"]) + lora_config1 = LoraConfig(target_modules=["q_proj"]) + + pipeline.text_encoder.add_adapter(lora_config0, adapter_name="adapter0") + pipeline.text_encoder.add_adapter(lora_config1, adapter_name="adapter1") + + with tempfile.TemporaryDirectory() as tmp_dirname: + # save the adapter checkpoints + lora0_state_dicts = self.get_lora_state_dicts({"text_encoder": pipeline.text_encoder}, adapter_name="adapter0") + StableDiffusionPipeline.save_lora_weights( + save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts + ) + lora1_state_dicts = self.get_lora_state_dicts({"text_encoder": pipeline.text_encoder}, adapter_name="adapter1") + StableDiffusionPipeline.save_lora_weights( + save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts + ) + del pipeline + + # load the first adapter + pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) + file_name0 = os.path.join(tmp_dirname, "adapter0", "pytorch_lora_weights.safetensors") + file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors") + + pipeline.load_lora_weights(file_name0) + msg = re.escape("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`") + with self.assertRaisesRegex(ValueError, msg): + pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0") From 09e2ec79cdd5fcd9146d67a1ab83ba56278fb1f7 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 16:59:34 +0100 Subject: [PATCH 22/36] Formatting --- src/diffusers/loaders/lora_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index cc9ca5d9040e..502be35c4425 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -346,7 +346,9 @@ def _load_lora_into_text_encoder( # Safe prefix to check with. if any(text_encoder_name in key for key in keys): if hotswap: - raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.") + raise ValueError( + "At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`." + ) # Load the layers corresponding to text encoder and make necessary adjustments. 
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] From 67ab6bfa84379c9de73e0a66d2cd14bc950a30cc Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 17:04:42 +0100 Subject: [PATCH 23/36] More linter --- tests/pipelines/test_pipelines.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 97b16ed7b811..3d089d6f52d6 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2405,7 +2405,6 @@ def test_hotswap_component_not_supported_raises(self): from peft import LoraConfig pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device) - max_rank = 8 lora_config0 = LoraConfig(target_modules=["q_proj"]) lora_config1 = LoraConfig(target_modules=["q_proj"]) From f03fe6b5cf3426826bce12145bf42e1990500d99 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 17:16:17 +0100 Subject: [PATCH 24/36] More ruff --- tests/pipelines/test_pipelines.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 3d089d6f52d6..7cc09abb4195 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -2413,11 +2413,15 @@ def test_hotswap_component_not_supported_raises(self): with tempfile.TemporaryDirectory() as tmp_dirname: # save the adapter checkpoints - lora0_state_dicts = self.get_lora_state_dicts({"text_encoder": pipeline.text_encoder}, adapter_name="adapter0") + lora0_state_dicts = self.get_lora_state_dicts( + {"text_encoder": pipeline.text_encoder}, adapter_name="adapter0" + ) StableDiffusionPipeline.save_lora_weights( save_directory=os.path.join(tmp_dirname, "adapter0"), safe_serialization=True, **lora0_state_dicts ) - lora1_state_dicts = self.get_lora_state_dicts({"text_encoder": pipeline.text_encoder}, adapter_name="adapter1") + lora1_state_dicts = self.get_lora_state_dicts( + {"text_encoder": pipeline.text_encoder}, adapter_name="adapter1" + ) StableDiffusionPipeline.save_lora_weights( save_directory=os.path.join(tmp_dirname, "adapter1"), safe_serialization=True, **lora1_state_dicts ) @@ -2429,6 +2433,8 @@ def test_hotswap_component_not_supported_raises(self): file_name1 = os.path.join(tmp_dirname, "adapter1", "pytorch_lora_weights.safetensors") pipeline.load_lora_weights(file_name0) - msg = re.escape("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`") + msg = re.escape( + "At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`" + ) with self.assertRaisesRegex(ValueError, msg): pipeline.load_lora_weights(file_name1, hotswap=True, adapter_name="default_0") From 2d407ca471dba425c569ca588738599bee4ea9e5 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 27 Feb 2025 17:33:11 +0100 Subject: [PATCH 25/36] Doc-builder complaint --- src/diffusers/loaders/lora_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 825ae1315e18..1e127e0e5883 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -409,9 +409,9 @@ def load_lora_into_text_encoder( This means that, instead of loading an additional adapter, this will take the existing adapter weights and replace them with the weights of the new adapter. 
This can be faster and more memory efficient. However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: + loading the new adapter does not require recompilation of the model. If the new adapter and the old + adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an additional method + before loading the adapter: ```py pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load From 6b59ecfe7df83323063273830fe80bff2720de48 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 3 Mar 2025 12:36:46 +0100 Subject: [PATCH 26/36] Update docstring: - mention no text encoder support yet - make it clear that LoRA is meant - mention that same adapter name should be passed --- src/diffusers/loaders/lora_pipeline.py | 68 +++++++++++++++----------- src/diffusers/loaders/peft.py | 22 +++++---- 2 files changed, 52 insertions(+), 38 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1e127e0e5883..88c60c27e449 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -83,8 +83,7 @@ def load_lora_weights( hotswap: bool = False, **kwargs, ): - """ - Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. @@ -108,11 +107,12 @@ def load_lora_weights( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. - This means that, instead of loading an additional adapter, this will take the existing adapter weights - and replace them with the weights of the new adapter. This can be faster and more memory efficient. - However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an additional method before loading the adapter: @@ -126,10 +126,12 @@ def load_lora_weights( # optionally compile the model now ``` - There are some limitations to this technique, which are documented here: + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -320,11 +322,12 @@ def load_lora_into_unet( Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. - This means that, instead of loading an additional adapter, this will take the existing adapter weights - and replace them with the weights of the new adapter. This can be faster and more memory efficient. - However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an additional method before loading the adapter: @@ -338,7 +341,8 @@ def load_lora_into_unet( # optionally compile the model now ``` - There are some limitations to this technique, which are documented here: + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: @@ -405,13 +409,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. - This means that, instead of loading an additional adapter, this will take the existing adapter weights - and replace them with the weights of the new adapter. This can be faster and more memory efficient. - However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. If the new adapter and the old - adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an additional method - before loading the adapter: + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. 
When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + ```py pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load @@ -420,7 +428,9 @@ def load_lora_into_text_encoder( pipeline.load_lora_weights(file_name) # optionally compile the model now ``` - There are some limitations to this technique, which are documented here: + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( @@ -809,11 +819,12 @@ def load_lora_into_unet( Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. hotswap : (`bool`, *optional*) - Defaults to `False`. Whether to substitute an existing adapter with the newly loaded adapter in-place. - This means that, instead of loading an additional adapter, this will take the existing adapter weights - and replace them with the weights of the new adapter. This can be faster and more memory efficient. - However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an additional method before loading the adapter: @@ -827,7 +838,8 @@ def load_lora_into_unet( # optionally compile the model now ``` - There are some limitations to this technique, which are documented here: + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index ab5f431b3a9b..f26cf1e9983e 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -148,7 +148,8 @@ def _optionally_disable_offloading(cls, _pipeline): def load_lora_adapter( self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs ): - r"""Loads a LoRA adapter into the underlying model. + r""" + Loads a LoRA adapter into the underlying model. Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -191,14 +192,15 @@ def load_lora_adapter( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. hotswap : (`bool`, *optional*) - Defaults to `False`. 
Whether to substitute an existing adapter with the newly loaded adapter in-place. - This means that, instead of loading an additional adapter, this will take the existing adapter weights - and replace them with the weights of the new adapter. This can be faster and more memory efficient. - However, the main advantage of hotswapping is that when the model is compiled with torch.compile, - loading the new adapter does not require recompilation of the model. + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - If the model is compiled, or if the new adapter and the old adapter have different ranks and/or LoRA - alphas (i.e. scaling), you need to call an additional method before loading the adapter: + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: ```py pipeline = ... # load diffusers pipeline @@ -209,9 +211,9 @@ def load_lora_adapter( # optionally compile the model now ``` - There are some limitations to this technique, which are documented here: + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer From c3c1bdf728c885013cd27f3844ca11632a47cae0 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 5 Mar 2025 16:40:49 +0100 Subject: [PATCH 27/36] Fix error in docstring --- src/diffusers/loaders/lora_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 4df7f5ec73b4..b39827c54b9c 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -409,7 +409,6 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. - hotswap : (`bool`, *optional*) hotswap : (`bool`, *optional*) Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means that, instead of loading an additional adapter, this will take the existing From 387ddf6876ac634008adcdc5cbb7a5f1da018ee7 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Fri, 7 Mar 2025 16:54:54 +0100 Subject: [PATCH 28/36] Update more methods with hotswap argument - SDXL - SD3 - Flux No changes were made to load_lora_into_transformer. 
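
For orientation, the user-facing call is the same across these pipeline classes. Below is a minimal usage sketch, not part of this patch; the checkpoint id and LoRA file names are placeholders:

```python
# Hedged sketch: hot-swapping a second LoRA into an SDXL pipeline.
# "lora_a.safetensors" / "lora_b.safetensors" are placeholder file names.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# first adapter; diffusers names it "default_0" when no adapter_name is passed
pipe.load_lora_weights("lora_a.safetensors")
image_a = pipe("a photo of a cat").images[0]

# swap the new weights into "default_0" in place instead of loading a second adapter
pipe.load_lora_weights("lora_b.safetensors", hotswap=True, adapter_name="default_0")
image_b = pipe("a photo of a cat").images[0]
```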
--- src/diffusers/loaders/lora_pipeline.py | 155 ++++++++++++++++++++++++- 1 file changed, 153 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index b39827c54b9c..ec618c7f3d19 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -880,6 +880,7 @@ def load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -905,6 +906,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -916,6 +940,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1155,7 +1180,11 @@ def lora_state_dict( return state_dict def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name=None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and @@ -1178,6 +1207,26 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. 
When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new + adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an + additional method before loading the adapter: + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ @@ -1224,6 +1273,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} @@ -1237,6 +1287,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1287,6 +1338,7 @@ def load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1312,6 +1364,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1323,6 +1398,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1600,7 +1676,11 @@ def lora_state_dict( return state_dict def load_lora_weights( - self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name=None, + hotswap: bool = False, + **kwargs, ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and @@ -1625,6 +1705,26 @@ def load_lora_weights( low_cpu_mem_usage (`bool`, *optional*): `Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new + adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an + additional method before loading the adapter: + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1706,6 +1806,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1817,6 +1918,7 @@ def load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1842,6 +1944,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. 
However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -1853,6 +1978,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -2312,6 +2438,7 @@ def load_lora_into_text_encoder( adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2337,6 +2464,29 @@ def load_lora_into_text_encoder( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ _load_lora_into_text_encoder( state_dict=state_dict, @@ -2348,6 +2498,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod From dec4d1087e52e6c7e76376bdad02adc35b3a6d0a Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 11 Mar 2025 11:57:23 +0100 Subject: [PATCH 29/36] Add hotswap argument to load_lora_into_transformer For SD3 and Flux. Use shorter docstring for brevity. 
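
For reference, a hedged sketch of the equivalent model-level call that these classmethods forward to; the pipeline and file names are placeholders, following the `pipeline = ...` convention used in the docstring examples:

```python
# Hedged sketch: hot-swapping directly on the transformer of an SD3 or Flux pipeline.
pipe = ...  # placeholder: an already instantiated diffusers pipeline with a `transformer` component
pipe.transformer.load_lora_adapter("lora_a.safetensors", adapter_name="default_0")
# ... generate, then swap the second adapter's weights into "default_0" in place
pipe.transformer.load_lora_adapter("lora_b.safetensors", adapter_name="default_0", hotswap=True)
```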
--- src/diffusers/loaders/lora_pipeline.py | 29 ++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ec618c7f3d19..4e443e477867 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1260,6 +1260,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1292,7 +1293,7 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1310,6 +1311,13 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -1324,6 +1332,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -1786,6 +1795,7 @@ def load_lora_weights( adapter_name=adapter_name, _pipeline=self, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) if len(transformer_norm_state_dict) > 0: @@ -1811,7 +1821,14 @@ def load_lora_weights( @classmethod def load_lora_into_transformer( - cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, + state_dict, + network_alphas, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1833,6 +1850,13 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. 
""" if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -1850,6 +1874,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod From 716f446401a5e781aa7dcdd556ac8a6de2e94d61 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Mar 2025 11:22:52 +0100 Subject: [PATCH 30/36] Extend docstrings --- src/diffusers/loaders/lora_pipeline.py | 44 +++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e16b3662a9bb..1754af3d7ace 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -132,7 +132,6 @@ def load_lora_weights( https://huggingface.co/docs/peft/main/en/package_reference/hotswap kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1200,9 +1199,11 @@ def load_lora_weights( adapter weights and replace them with the weights of the new adapter. This can be faster and more memory efficient. However, the main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. If the new - adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need to call an - additional method before loading the adapter: + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + ```py pipeline = ... # load diffusers pipeline max_rank = ... # the highest rank among all LoRAs that you want to load @@ -1211,6 +1212,7 @@ def load_lora_weights( pipeline.load_lora_weights(file_name) # optionally compile the model now ``` + Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap @@ -1295,7 +1297,23 @@ def load_lora_into_transformer( memory efficient. However, the main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - """ + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap + """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." @@ -1841,6 +1859,22 @@ def load_lora_into_transformer( memory efficient. However, the main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( From 4d821117e4cebdb5eb022146c9e0fa2b65e601d9 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Mar 2025 11:28:14 +0100 Subject: [PATCH 31/36] Add version guards to tests --- tests/models/test_modeling_common.py | 2 ++ tests/pipelines/test_pipelines.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 720d7bf93b22..97088f429cd5 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -63,6 +63,7 @@ is_torch_compile, numpy_cosine_similarity_distance, require_peft_backend, + require_peft_version_greater, require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, @@ -1670,6 +1671,7 @@ def test_push_to_hub_library_name(self): @require_torch_2 @require_torch_accelerator @require_peft_backend +@require_peft_version_greater("0.14.0") @is_torch_compile class TestLoraHotSwappingForModel(unittest.TestCase): """Test that hotswapping does not result in recompilation on the model directly. diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 2f933cf4b813..ae5a12e04ba8 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -81,6 +81,7 @@ require_hf_hub_version_greater, require_onnxruntime, require_peft_backend, + require_peft_version_greater, require_torch_2, require_torch_accelerator, require_transformers_version_greater, @@ -2184,6 +2185,7 @@ def test_ddpm_ddim_equality_batched(self): @require_torch_2 @require_torch_accelerator @require_peft_backend +@require_peft_version_greater("0.14.0") @is_torch_compile class TestLoraHotSwappingForPipeline(unittest.TestCase): """Test that hotswapping does not result in recompilation in a pipeline. 
From 425cb398899260992c45107c7e1a234495d549ba Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Mar 2025 12:24:04 +0100 Subject: [PATCH 32/36] Formatting --- src/diffusers/loaders/lora_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 1754af3d7ace..9781875f3957 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1313,7 +1313,7 @@ def load_lora_into_transformer( Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - """ + """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." From 115c77d22d026416945bc35c3eac59423d438029 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Mar 2025 12:52:40 +0100 Subject: [PATCH 33/36] Fix LoRA loading call to add prefix=None See: https://github.com/huggingface/diffusers/pull/10187#issuecomment-2717571064 --- tests/models/test_modeling_common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 97088f429cd5..4b8d3db422f0 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -1791,7 +1791,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors") file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors") - unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0") + unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) if do_compile: unet = torch.compile(unet, mode="reduce-overhead") @@ -1801,7 +1801,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) # hotswap the 2nd adapter - unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True) + unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) # we need to call forward to potentially trigger recompilation with torch.inference_mode(): @@ -1812,7 +1812,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ name = "does-not-exist" msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name" with self.assertRaisesRegex(ValueError, msg): - unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True) + unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None) @parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa def test_hotswapping_model(self, rank0, rank1): From 5d9075376a753cee1eb28d60e503bea2fc7cddf2 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Wed, 12 Mar 2025 13:41:13 +0100 Subject: [PATCH 34/36] Run make fix-copies --- src/diffusers/loaders/lora_pipeline.py | 241 ++++++++++++++++++++++++- 1 file changed, 232 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 9781875f3957..fae5e1d6d888 
100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2425,7 +2425,14 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): @classmethod # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel def load_lora_into_transformer( - cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, + state_dict, + network_alphas, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2447,6 +2454,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): raise ValueError( @@ -2461,6 +2491,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -2752,7 +2783,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2770,6 +2801,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. 
This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -2784,6 +2838,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3055,7 +3110,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3073,6 +3128,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3087,6 +3165,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3360,7 +3439,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3378,6 +3457,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3392,6 +3494,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3665,7 +3768,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3683,6 +3786,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. 
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -3697,6 +3823,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -3973,7 +4100,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -3991,6 +4118,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4005,6 +4155,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4282,7 +4433,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4300,6 +4451,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4314,6 +4488,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4587,7 +4762,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4605,6 +4780,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. 
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4619,6 +4817,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod @@ -4892,7 +5091,7 @@ def load_lora_weights( @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel def load_lora_into_transformer( - cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -4910,6 +5109,29 @@ def load_lora_into_transformer( low_cpu_mem_usage (`bool`, *optional*): Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + hotswap : (`bool`, *optional*) + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. 
There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap """ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): raise ValueError( @@ -4924,6 +5146,7 @@ def load_lora_into_transformer( adapter_name=adapter_name, _pipeline=_pipeline, low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, ) @classmethod From 366632d06d04e25000b01871213b914eddcbffae Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 17 Mar 2025 14:45:31 +0100 Subject: [PATCH 35/36] Add hot swap documentation to the docs --- .../en/using-diffusers/loading_adapters.md | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index e16c1322e5d1..a3a40c2e1a05 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -194,6 +194,62 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support +### Hot swapping LoRA adapters + +A common use case when serving multiple adapters is to load one adapter first, generate images, then load another adapter, generate more images, load another adapter, etc. This workflow would normally require calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save on memory. Those are quite a few steps. Morever, if the model is compiled using `torch.compile`, performing these steps will result in recompilation, which takes time. + +To better support this common workflow, diffusers offers the option to "hot swap" a LoRA adapter. This requires an adapter to already be loaded. Then, a new adapter can be hot swapped for the existing adapter, i.e. the weights are swapped in-place. This is more convenient, doesn't accumulate memory, and does not require recompilation, at least in some circumstances. + +In general, hot swapping can be accomplished by passing `hotswap=True` when loading the LoRA adapter: + +```python +pipe = ... +# load adapter 1 as normal +pipeline.load_lora_weights(file_name_adapter_1) +# generate some images with adapter 1 +... +# now hot swap the 2nd adapter +pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0") +# generate images with adapter 2 +``` + +Notice that we passed `adapter_name="default_0"`. This is the default adapter name given by diffusers and it is important that we indicate the name of the existing adapter. If you loaded the first adapter under a different name, pass that name instead. + + + +Hot swapping is currently not supported for the text encoder. If the LoRA adapter targets the text encoder, don't use this feature. + + + +Now when it comes to compiled models, the same code as above may also work without triggering recompilation, but only if the second adapter targets the exact same ranks, has the exact same LoRA ranks and also scales. For most adapters, this is not the case. Therefore, it is necessary to go through one more step, as shown in this snippet: + +```python +pipe = ... +# call this extra method +pipe.enable_lora_hotswap(target_rank=max_rank) +# now load adapter 1 +pipe.load_lora_weights(file_name_adapter_1) +# now compile the unet of the pipeline +pipe.unet = torch.compile(pipeline.unet, ...) +# generate some images with adapter 1 +... 
+# now hot swap adapter 2 +pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0") +# generate images with adapter 2 +``` + +By calling the [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method, diffusers makes it possible to hot swap the LoRA adapter without triggering recompilation. For this to work, call the method _before_ loading the first adapter. Also note that, as always, `torch.compile` has to be called _after_ loading the first adapter. + +The `target_rank=max_rank` argument is important to let diffusers know what will be the maximum rank among all LoRA adapters that will be loaded. So if you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. By default, this value is 128. If in doubt, prefer a higher value. + +Even after following these steps, there can be situations that will result in recompilation. Most notably, if the swapped in adapters targets more layers than the initial adapter, recompilation is needed. Try to load the adapter that targets most layers first. Read more about the limitations of hot swapping in the [PEFT documentation on hot swapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter). + + + +To detect if the model was recompiled, move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager. If you detect recompilation despite following all the steps above, please open an issue on the [diffusers GitHub repository](https://github.com/huggingface/diffusers/issues) with a reproducer. + + + ### Kohya and TheLastBen Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way. From b181a47739adb3ac33f8f764ecf3214302064239 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Tue, 18 Mar 2025 11:17:19 +0100 Subject: [PATCH 36/36] Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- .../en/using-diffusers/loading_adapters.md | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md index a3a40c2e1a05..7522996b2424 100644 --- a/docs/source/en/using-diffusers/loading_adapters.md +++ b/docs/source/en/using-diffusers/loading_adapters.md @@ -194,13 +194,13 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support -### Hot swapping LoRA adapters +### Hotswapping LoRA adapters -A common use case when serving multiple adapters is to load one adapter first, generate images, then load another adapter, generate more images, load another adapter, etc. This workflow would normally require calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save on memory. Those are quite a few steps. Morever, if the model is compiled using `torch.compile`, performing these steps will result in recompilation, which takes time. +A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. 
This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time. -To better support this common workflow, diffusers offers the option to "hot swap" a LoRA adapter. This requires an adapter to already be loaded. Then, a new adapter can be hot swapped for the existing adapter, i.e. the weights are swapped in-place. This is more convenient, doesn't accumulate memory, and does not require recompilation, at least in some circumstances. +To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and, in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter. -In general, hot swapping can be accomplished by passing `hotswap=True` when loading the LoRA adapter: +Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter to be swapped (`default_0` is the default adapter name). If you loaded the first adapter with a different name, use that name instead. ```python pipeline = ... @@ -213,15 +213,14 @@ pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="defa # generate images with adapter 2 ``` -Notice that we passed `adapter_name="default_0"`. This is the default adapter name given by diffusers and it is important that we indicate the name of the existing adapter. If you loaded the first adapter under a different name, pass that name instead. -Hot swapping is currently not supported for the text encoder. If the LoRA adapter targets the text encoder, don't use this feature. +Hotswapping is not currently supported for LoRA adapters that target the text encoder. -Now when it comes to compiled models, the same code as above may also work without triggering recompilation, but only if the second adapter targets the exact same ranks, has the exact same LoRA ranks and also scales. For most adapters, this is not the case. Therefore, it is necessary to go through one more step, as shown in this snippet: +For compiled models, it is usually necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation, unless the second adapter targets identical LoRA ranks and scales. Call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and call `torch.compile` _after_ loading the first adapter. ```python pipeline = ... @@ -238,15 +237,13 @@ pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="defa # generate images with adapter 2 ``` -By calling the [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method, diffusers makes it possible to hot swap the LoRA adapter without triggering recompilation. For this to work, call the method _before_ loading the first adapter. Also note that, as always, `torch.compile` has to be called _after_ loading the first adapter. +The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128. 
-The `target_rank=max_rank` argument is important to let diffusers know what will be the maximum rank among all LoRA adapters that will be loaded. So if you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. By default, this value is 128. If in doubt, prefer a higher value. -Even after following these steps, there can be situations that will result in recompilation. Most notably, if the swapped in adapters targets more layers than the initial adapter, recompilation is needed. Try to load the adapter that targets most layers first. Read more about the limitations of hot swapping in the [PEFT documentation on hot swapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter). +However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature. -To detect if the model was recompiled, move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager. If you detect recompilation despite following all the steps above, please open an issue on the [diffusers GitHub repository](https://github.com/huggingface/diffusers/issues) with a reproducer. +Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue in the [Diffusers](https://github.com/huggingface/diffusers/issues) repository with a reproducible example.
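Putting the documented steps together, a minimal end-to-end sketch of the compiled hot-swapping workflow could look as follows. The checkpoint id and the LoRA file names (`adapter_1.safetensors`, `adapter_2.safetensors`) are placeholders, the adapters are assumed to have a rank of at most 16, and the PEFT backend must be installed; this is an illustration of the API described in the documentation above, not code taken from the diffusers repository.

```python
import torch
from diffusers import AutoPipelineForText2Image

# placeholder checkpoint; any Stable Diffusion pipeline that supports LoRA works the same way
pipeline = AutoPipelineForText2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# must be called before the first adapter is loaded;
# target_rank has to cover the largest LoRA rank that will be hot swapped in (assumed <= 16 here)
pipeline.enable_lora_hotswap(target_rank=16)

# load the first adapter (it gets the default name "default_0"), then compile the unet
pipeline.load_lora_weights("adapter_1.safetensors")  # placeholder file name
pipeline.unet = torch.compile(pipeline.unet)

# error_on_recompile raises if the compiled unet is ever recompiled;
# the initial compilation during the first generation does not count as a recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
    image_1 = pipeline("a photo of a cat", num_inference_steps=30).images[0]

    # swap the second adapter into the existing "default_0" slot in place
    pipeline.load_lora_weights("adapter_2.safetensors", hotswap=True, adapter_name="default_0")
    image_2 = pipeline("a photo of a cat", num_inference_steps=30).images[0]
```

If the guarded block raises during the second generation, the two adapters most likely differ in the layers they target or exceed the configured `target_rank`; as noted above, loading the adapter that targets the most layers first can help avoid this.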