
Commit 95e021c

ENH Add support for LoRA hotswapping
LoRA hotswapping has been available in PEFT since 0.15.0. There is already a diffusers integration (huggingface/diffusers#9453), but the transformers integration was still missing this feature. This PR remedies that. Hotswapping allows swapping different LoRA adapters in-place instead of loading multiple adapters and switching between them. Besides saving memory and potentially loading faster, the biggest advantage is that if the model is compiled, we can hotswap without triggering recompilation (loading a separate adapter would require recompilation). There are some caveats to using this feature, most notably that only LoRA is supported. This was fine for diffusers, which only works with LoRA, but the transformers integration works with other PEFT methods too. However, LoRA is by far the most common method, so this should be acceptable for now. This and the other caveats have been documented.
1 parent 1d91a8a commit 95e021c
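In practice, the workflow this commit enables looks roughly like the sketch below. It is a minimal illustration based on the documentation added in this commit; the checkpoint name and adapter paths are placeholders.

```python
from transformers import AutoModelForCausalLM

# Minimal sketch of the hotswapping flow added by this commit.
# "some-base-model" and the adapter paths are placeholders, not real artifacts.
model = AutoModelForCausalLM.from_pretrained("some-base-model")

# Load the first LoRA adapter as usual; it is registered under the name "default".
model.load_adapter("path/to/lora_adapter_1")
# ... generate outputs with adapter 1 ...

# Swap the second adapter's weights into the existing "default" adapter in-place,
# instead of loading it as an additional adapter.
model.load_adapter("path/to/lora_adapter_2", hotswap=True, adapter_name="default")
# ... generate outputs with adapter 2 ...
```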

File tree

3 files changed (+349, -7 lines)


docs/source/en/peft.md

Lines changed: 42 additions & 0 deletions
@@ -151,3 +151,45 @@ model.enable_adapters()
 # disable all adapters
 model.disable_adapters()
 ```
+
+## Hotswapping adapters
+
+A common use case when serving multiple adapters is to load one adapter first, generate output, load another adapter, generate more outputs, load another adapter, etc. This can be inefficient, since each time a new adapter is loaded, new memory is reserved; moreover, if the model is compiled with `torch.compile`, it needs to be re-compiled each time a new adapter is used. When switching frequently, the compilation time may never be amortized.
+
+To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and, in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter. Note that other PEFT methods are not supported yet, only LoRA.
+
+Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter (`"default"` is the default adapter name) to be swapped.
+
+```python
+model = AutoModel.from_pretrained(...)
+# load adapter 1 as normal
+model.load_adapter(file_name_adapter_1)
+# generate outputs with adapter 1
+...
+# now hotswap the 2nd adapter
+model.load_adapter(file_name_adapter_2, hotswap=True, adapter_name="default")
+# generate outputs with adapter 2
+```
+
+For compiled models, it is often necessary to call [`~integrations.peft.PeftAdapterMixin.enable_peft_hotswap`] to avoid recompilation. Call this method _before_ loading the first adapter, while `torch.compile` should be called _after_ loading the first adapter.
+
+```python
+model = AutoModel.from_pretrained(...)
+max_rank = ...  # the highest rank among all LoRAs that you want to load
+# call *before* compiling and loading the LoRA adapter
+model.enable_peft_hotswap(target_rank=max_rank)
+model.load_adapter(file_name_1, adapter_name="default")
+# optionally compile the model now
+model = torch.compile(model, ...)
+output_1 = model(...)
+# now you can hotswap the 2nd adapter, use the same name as for the 1st
+model.load_adapter(file_name_2, adapter_name="default", hotswap=True)
+output_2 = model(...)
+```
+
+The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
+
+However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
+
+> [!TIP]
+> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [PEFT](https://github.com/huggingface/peft/issues) with a reproducible example.
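As a companion to the tip added above, this is one way the recompilation check could be wrapped around the hotswap. It is a sketch only, reusing `model`, `file_name_2`, and the placeholder call style from the documentation example.

```python
import torch

# Sketch: with error_on_recompile=True, dynamo raises instead of silently
# recompiling, so an unexpected recompilation surfaces immediately.
with torch._dynamo.config.patch(error_on_recompile=True):
    model.load_adapter(file_name_2, adapter_name="default", hotswap=True)
    output_2 = model(...)
```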

src/transformers/integrations/peft.py

Lines changed: 134 additions & 7 deletions
@@ -16,7 +16,7 @@
 import inspect
 import re
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from packaging import version
 
@@ -89,6 +89,7 @@ class PeftAdapterMixin:
     """
 
     _hf_peft_config_loaded = False
+    _prepare_peft_hotswap_kwargs: Optional[dict] = None
 
     def load_adapter(
         self,
@@ -104,6 +105,7 @@ def load_adapter(
         adapter_state_dict: Optional[dict[str, "torch.Tensor"]] = None,
         low_cpu_mem_usage: bool = False,
         is_trainable: bool = False,
+        hotswap: bool = False,
         adapter_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
@@ -162,12 +164,52 @@ def load_adapter(
             is_trainable (`bool`, *optional*, defaults to `False`):
                 Whether the adapter should be trainable or not. If `False`, the adapter will be frozen and can only be
                 used for inference.
+            hotswap (`bool`, *optional*, defaults to `False`):
+                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. This means
+                that, instead of loading an additional adapter, this will take the existing adapter weights and replace
+                them with the weights of the new adapter. This can be faster and more memory efficient. However, the
+                main advantage of hotswapping is that when the model is compiled with torch.compile, loading the new
+                adapter does not require recompilation of the model. When using hotswapping, the passed `adapter_name`
+                should be the name of an already loaded adapter.
+
+                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
+                to call an additional method before loading the adapter:
+
+                ```py
+                model = AutoModel.from_pretrained(...)
+                max_rank = ...  # the highest rank among all LoRAs that you want to load
+                # call *before* compiling and loading the LoRA adapter
+                model.enable_peft_hotswap(target_rank=max_rank)
+                model.load_adapter(file_name_1, adapter_name="default")
+                # optionally compile the model now
+                model = torch.compile(model, ...)
+                output_1 = model(...)
+                # now you can hotswap the 2nd adapter, use the same name as for the 1st
+                model.load_adapter(file_name_2, adapter_name="default", hotswap=True)
+                output_2 = model(...)
+                ```
+
+                Note that hotswapping comes with a couple of limitations documented here:
+                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
             adapter_kwargs (`dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
         """
+        from peft import PeftType
+
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
+        if hotswap:
+            min_version_hotswap = "0.15.0"
+            if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+                raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+            if (not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config):
+                raise ValueError(
+                    "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
+                )
+            if any(conf.peft_type != PeftType.LORA for conf in self.peft_config.values()):
+                raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
+
         # peft only supports low_cpu_mem_usage starting from v0.13.0
         peft_load_kwargs = {}
         key_mapping = adapter_kwargs.pop("key_mapping", None) if adapter_kwargs is not None else None
@@ -190,8 +232,12 @@ def load_adapter(
         from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
         from peft.utils import set_peft_model_state_dict
 
-        if self._hf_peft_config_loaded and adapter_name in self.peft_config:
+        if self._hf_peft_config_loaded and (not hotswap) and (adapter_name in self.peft_config):
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+        elif hotswap and ((not self._hf_peft_config_loaded) or (adapter_name not in self.peft_config)):
+            raise ValueError(
+                "To hotswap an adapter, there must already be an existing adapter with the same adapter name."
+            )
 
         if peft_model_id is None and (adapter_state_dict is None and peft_config is None):
             raise ValueError(
@@ -240,8 +286,12 @@ def load_adapter(
             )
         peft_config.inference_mode = not is_trainable
 
-        # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
+        if hotswap and peft_config.peft_type != PeftType.LORA:
+            raise ValueError("Hotswapping is currently only supported for LoRA, please set `hotswap=False`.")
+
+        if not hotswap:
+            # Create and add fresh new adapters into the model, unless the weights are hotswapped
+            inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
 
         if not self._hf_peft_config_loaded:
             self._hf_peft_config_loaded = True
@@ -264,12 +314,49 @@ def load_adapter(
                     # Early exit of the loop
                     if n_replace > 0:
                         break
+
+            # For hotswapping, we need the adapter name to be present in the state dict keys
+            if hotswap:
+                if key.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
+                    new_key = new_key[: -len(".weight")] + f".{adapter_name}.weight"
+                elif key.endswith("lora_B.bias"):  # lora_bias=True option
+                    new_key = new_key[: -len(".bias")] + f".{adapter_name}.bias"
             processed_adapter_state_dict[new_key] = value
 
         # Load state dict
-        incompatible_keys = set_peft_model_state_dict(
-            self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
-        )
+        if not hotswap:
+            incompatible_keys = set_peft_model_state_dict(
+                self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+            )
+
+            if self._prepare_peft_hotswap_kwargs is not None:
+                # For hotswapping of compiled models or adapters with different ranks.
+                # If the user called enable_peft_hotswap, we need to ensure it is called:
+                # - after the first adapter was loaded
+                # - before the model is compiled and the 2nd adapter is being hotswapped in
+                # Therefore, it needs to be called here
+                from peft.utils.hotswap import prepare_model_for_compiled_hotswap
+
+                prepare_model_for_compiled_hotswap(
+                    self, config=peft_config, **self._prepare_peft_hotswap_kwargs
+                )
+                # We only want to call prepare_model_for_compiled_hotswap once
+                self._prepare_peft_hotswap_kwargs = None
+        else:
+            from peft.utils.hotswap import check_hotswap_configs_compatible, hotswap_adapter_from_state_dict
+
+            check_hotswap_configs_compatible(self.peft_config[adapter_name], peft_config)
+            try:
+                hotswap_adapter_from_state_dict(
+                    model=self,
+                    state_dict=processed_adapter_state_dict,
+                    adapter_name=adapter_name,
+                    config=peft_config,
+                )
+            except Exception as e:
+                logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
+                raise
+            incompatible_keys = None
 
         if incompatible_keys is not None:
             err_msg = ""
@@ -311,6 +398,46 @@ def load_adapter(
                 offload_index=offload_index,
             )
 
+    def enable_peft_hotswap(
+        self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
+    ) -> None:
+        """Enables the possibility to hotswap PEFT adapters with different ranks, or, if the model is compiled, without
+        triggering recompilation.
+
+        Right now, hotswapping is only supported for LoRA.
+
+        Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
+        the loaded adapters differ. If the ranks are all identical and the model is not compiled, hotswapping works
+        without calling this method first.
+
+        Args:
+            target_rank (`int`, *optional*, defaults to `128`):
+                The highest rank among all the adapters that will be loaded.
+            check_compiled (`str`, *optional*, defaults to `"error"`):
+                How to handle the case when the model is already compiled, which should generally be avoided. The
+                options are:
+                - "error" (default): raise an error
+                - "warn": issue a warning
+                - "ignore": do nothing
+        """
+        min_version_hotswap = "0.15.0"
+        if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_hotswap):
+            raise ValueError(f"To hotswap the adapter, you need PEFT >= v{min_version_hotswap}.")
+
+        if getattr(self, "peft_config", {}):
+            if check_compiled == "error":
+                raise RuntimeError("Call `enable_peft_hotswap` before loading the first adapter.")
+            elif check_compiled == "warn":
+                logger.warning(
+                    "It is recommended to call `enable_peft_hotswap` before loading the first adapter to avoid recompilation."
+                )
+            elif check_compiled != "ignore":
+                raise ValueError(
+                    f"check_compiled should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
+                )
+
+        self._prepare_peft_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
+
     def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
         r"""
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
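To round out the diff above, here is a rough usage sketch of the new `enable_peft_hotswap` method with its `check_compiled` argument; the checkpoint and adapter paths are placeholders, not real repositories.

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholders, not real artifacts.
model = AutoModelForCausalLM.from_pretrained("some-base-model")

# Two adapters with ranks 8 and 16 will be hotswapped, so reserve room for the larger
# one. check_compiled="warn" logs a warning instead of raising if this is called at the
# wrong point (e.g. after an adapter has already been loaded).
model.enable_peft_hotswap(target_rank=16, check_compiled="warn")

model.load_adapter("path/to/lora_rank16", adapter_name="default")
model = torch.compile(model)
output_1 = model(...)

# Hotswap the rank-8 adapter into the "default" slot without triggering recompilation.
model.load_adapter("path/to/lora_rank8", adapter_name="default", hotswap=True)
output_2 = model(...)
```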
