
Commit 29adf0d

FEAT Add hotswapping functionality (#2120)
The idea of hotswapping an adapter is the following: We can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes, we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the hotswap_adapter function. In general, this should be faster than deleting one adapter and loading the new adapter in its place, which would be the current way to achieve the same final outcome. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled. This can save quite a lot of time.

There are some caveats for hotswapping:

- It only works for the same PEFT method, so no swapping LoRA and LoHa.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules).
- To avoid recompilation, ranks must be identical.

See also huggingface/diffusers#9453
1 parent e1997b8 commit 29adf0d

File tree

7 files changed: +811 lines added, -26 lines removed


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -125,5 +125,7 @@
      title: Model merge
    - local: package_reference/helpers
      title: Helpers
+   - local: package_reference/hotswap
+     title: Hotswapping adapters
    title: Utilities
  title: API reference
docs/source/package_reference/hotswap.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
<!--⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Hotswapping adapters

The idea of hotswapping an adapter is the following: We can already load multiple adapters, e.g. two LoRAs, at the same time. But sometimes, we want to load one LoRA and then replace its weights in-place with the LoRA weights of another adapter. This is now possible with the `hotswap_adapter` function.

In general, this should be faster than deleting one adapter and loading the new adapter in its place, which is how you would achieve the same final outcome without hotswapping. Another advantage of hotswapping is that it prevents re-compilation in case the PEFT model is already compiled using `torch.compile`. This can save quite a lot of time.
```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel
from peft.utils.hotswap import hotswap_adapter

model_id = ...
inputs = ...
device = ...
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# load lora 0
model = PeftModel.from_pretrained(model, <path-adapter-0>)
model = torch.compile(model)  # optionally compile the model
with torch.inference_mode():
    output_adapter_0 = model(inputs)

# replace the "default" lora adapter with the new one
hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
with torch.inference_mode():
    output_adapter_1 = model(inputs).logits
```
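To check the claim about avoided re-compilation empirically, here is a rough sketch (not part of the original document) that counts compiled frames with PyTorch's `torch._dynamo.testing.CompileCounter` backend. `model`, `inputs`, `device` and the adapter path are the placeholders from the snippet above; `CompileCounter` is an internal PyTorch testing utility, so treat this as illustrative only.

```python
import torch
from torch._dynamo.testing import CompileCounter

# compile with a counting backend instead of the default one
compile_counter = CompileCounter()
model = torch.compile(model, backend=compile_counter)

with torch.inference_mode():
    model(inputs)
frames_before = compile_counter.frame_count

# hotswap a compatible adapter (same rank) and run again
hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
with torch.inference_mode():
    model(inputs)

# if no recompilation was triggered, the frame count should not have grown
assert compile_counter.frame_count == frames_before
```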
Hotswapping works with transformers models and diffusers models. However, there are some caveats:

- It only works for the same PEFT method, so no swapping LoRA and LoHa, for example.
- Right now, only LoRA is properly supported.
- The adapters must be compatible (e.g. same LoRA alpha, same target modules); if they are not, `hotswap_adapter` raises an error, as shown in the sketch below.
- If you use `torch.compile` and want to avoid recompilation, the LoRA rank must be the same.
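A minimal fallback sketch (not part of the original document): if the compatibility check fails, revert to the regular load-and-activate route, accepting a possible recompilation. The adapter path placeholder and the name `"other"` are illustrative; `load_adapter` and `set_adapter` are existing `PeftModel` methods.

```python
try:
    hotswap_adapter(model, <path-adapter-1>, adapter_name="default", torch_device=device)
except (ValueError, RuntimeError):
    # incompatible configs raise a ValueError, mismatched adapter keys a RuntimeError;
    # fall back to loading the new adapter under a separate name and activating it
    # (this may trigger recompilation if the model was compiled)
    model.load_adapter(<path-adapter-1>, adapter_name="other")
    model.set_adapter("other")
```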
[[autodoc]] utils.hotswap.hotswap_adapter
    - all

[[autodoc]] utils.hotswap.hotswap_adapter_from_state_dict
    - all

src/peft/utils/hotswap.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from operator import attrgetter

import torch

from peft.config import PeftConfig
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING

from .constants import PEFT_TYPE_TO_PREFIX_MAPPING
from .other import infer_device
from .peft_types import PeftType
from .save_and_load import _insert_adapter_name_into_state_dict, load_peft_weights


# so far only LoRA is supported
CONFIG_KEYS_TO_CHECK = {PeftType.LORA: ["lora_alpha", "use_rslora", "lora_dropout", "alpha_pattern", "use_dora"]}

def hotswap_adapter_from_state_dict(model, state_dict, adapter_name, parameter_prefix="lora_"):
    """
    Swap out the adapter weights from the model with the weights from state_dict.

    As of now, only LoRA is supported.

    This is a low-level function that assumes that the adapters have been checked for compatibility and that the
    state_dict has been correctly mapped to work with PEFT. For a high-level function that performs this work for you,
    use `hotswap_adapter` instead.

    Args:
        model (`nn.Module`):
            The model with the loaded adapter.
        state_dict (`dict[str, torch.Tensor]`):
            The state dict of the new adapter, which needs to be compatible (targeting same modules etc.).
        adapter_name (`str`):
            The name of the adapter that should be hot-swapped, e.g. `"default"`. The name will remain the same after
            swapping.
        parameter_prefix (`str`, *optional*, defaults to `"lora_"`):
            The prefix used to identify the adapter's keys in the state dict. For LoRA, this would be `"lora_"` (the
            default).

    Raises:
        RuntimeError
            If the old and the new adapter are not compatible, a RuntimeError is raised.

    """
    # Ensure that all the keys of the new adapter correspond exactly to the keys of the old adapter, otherwise
    # hot-swapping is not possible.

    is_compiled = hasattr(model, "_orig_mod")
    # TODO: there is probably a more precise way to identify the adapter keys
    missing_keys = {k for k in model.state_dict() if (parameter_prefix in k) and (adapter_name in k)}
    unexpected_keys = set()

    # first: dry run, not swapping anything
    for key, new_val in state_dict.items():
        try:
            old_val = attrgetter(key)(model)
        except AttributeError:
            unexpected_keys.add(key)
            continue

        if is_compiled:
            missing_keys.remove("_orig_mod." + key)
        else:
            missing_keys.remove(key)

    if missing_keys or unexpected_keys:
        msg = "Hot swapping the adapter did not succeed."
        if missing_keys:
            msg += f" Missing keys: {', '.join(sorted(missing_keys))}."
        if unexpected_keys:
            msg += f" Unexpected keys: {', '.join(sorted(unexpected_keys))}."
        raise RuntimeError(msg)

    # actual swapping
    for key, new_val in state_dict.items():
        # no need to account for potential _orig_mod in key here, as torch handles that
        old_val = attrgetter(key)(model)
        if is_compiled:
            # Compiled models don't work with swap_tensors because there are weakrefs for the tensor. It is unclear
            # whether this workaround could cause trouble, but the tests indicate that it works.
            old_val.data = new_val.data
        else:
            torch.utils.swap_tensors(old_val, new_val)
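For context on the non-compiled branch above: `torch.utils.swap_tensors` exchanges the underlying data of two tensor objects in place, so any module that still holds a reference to the old parameter sees the new values without any attribute rebinding. A tiny standalone illustration (not part of this file; requires a recent PyTorch version):

```python
import torch

a = torch.zeros(3)
b = torch.ones(3)

# swap the payloads of the two tensor objects in place
torch.utils.swap_tensors(a, b)

print(a)  # tensor([1., 1., 1.])
print(b)  # tensor([0., 0., 0.])
```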
def _check_hotswap_configs_compatible(config0: PeftConfig, config1: PeftConfig) -> None:
    """
    Check if two configs are compatible for hot-swapping.

    Only LoRA parameters are checked for now.

    To hot-swap two adapters, their configs must be compatible. Otherwise, the results could be incorrect. E.g. if
    they use different alpha values, after hot-swapping, the alphas from the first adapter would still be used with
    the weights from the 2nd adapter, which would result in incorrect behavior. There is probably a way to swap these
    values as well, but that's not implemented yet, and we need to be careful not to trigger re-compilation if the
    model is compiled (so no modification of the dict).

    """
    if config0.peft_type != config1.peft_type:
        msg = f"Incompatible PEFT types found: {config0.peft_type.value} and {config1.peft_type.value}"
        raise ValueError(msg)

    if config0.peft_type not in CONFIG_KEYS_TO_CHECK:
        msg = (
            f"Hotswapping only supports {', '.join(CONFIG_KEYS_TO_CHECK.keys())} but "
            f"{config0.peft_type.value} was passed."
        )
        raise ValueError(msg)
    config_keys_to_check = CONFIG_KEYS_TO_CHECK[config0.peft_type]

    # TODO: This is a very rough check only for LoRA at the moment. Also, there might be some options that don't
    # necessarily require an error.
    config0 = config0.to_dict()
    config1 = config1.to_dict()
    sentinel = object()
    for key in config_keys_to_check:
        val0 = config0.get(key, sentinel)
        val1 = config1.get(key, sentinel)
        if val0 != val1:
            raise ValueError(f"Configs are incompatible: for {key}, {val0} != {val1}")
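To illustrate the compatibility check, a hypothetical example (not part of this file; the config values are made up) in which two LoRA configs differing only in `lora_alpha` are rejected:

```python
from peft import LoraConfig
from peft.utils.hotswap import _check_hotswap_configs_compatible

config0 = LoraConfig(r=8, lora_alpha=16)
config1 = LoraConfig(r=8, lora_alpha=32)

# raises ValueError: Configs are incompatible: for lora_alpha, 16 != 32
_check_hotswap_configs_compatible(config0, config1)
```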
def hotswap_adapter(model, model_name_or_path, adapter_name, torch_device=None, **kwargs):
    """Substitute old adapter data with new adapter data, keeping the rest the same.

    As of now, only LoRA is supported.

    This function is useful when you want to replace the loaded adapter with a new adapter. The adapter name will
    remain the same, but the weights and other parameters will be swapped out.

    If the adapters are incompatible, e.g. targeting different layers or having different alpha values, an error will
    be raised.

    Example:

    ```py
    >>> import torch
    >>> from transformers import AutoModelForCausalLM
    >>> from peft import PeftModel
    >>> from peft.utils.hotswap import hotswap_adapter

    >>> model_id = ...
    >>> inputs = ...
    >>> device = ...
    >>> model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

    >>> # load lora 0
    >>> model = PeftModel.from_pretrained(model, "path-adapter-0")
    >>> model = torch.compile(model)  # optionally compile the model
    >>> with torch.inference_mode():
    ...     output_adapter_0 = model(inputs)

    >>> # replace the "default" lora adapter with the new one
    >>> hotswap_adapter(model, "path-adapter-1", adapter_name="default", torch_device=device)
    >>> with torch.inference_mode():
    ...     output_adapter_1 = model(inputs).logits
    ```

    Args:
        model ([`~PeftModel`]):
            The PEFT model with the loaded adapter.
        model_name_or_path (`str`):
            The name or path of the model to load the new adapter from.
        adapter_name (`str`):
            The name of the adapter to swap, e.g. `"default"`. The name will stay the same after swapping.
        torch_device: (`str`, *optional*, defaults to None):
            The device to load the new adapter onto.
        **kwargs (`optional`):
            Additional keyword arguments used for loading the config and weights.

    """
    if torch_device is None:
        torch_device = infer_device()

    ############################
    # LOAD CONFIG AND VALIDATE #
    ############################

    config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[
        PeftConfig._get_peft_type(
            model_name_or_path,
            subfolder=kwargs.get("subfolder", None),
            revision=kwargs.get("revision", None),
            cache_dir=kwargs.get("cache_dir", None),
            use_auth_token=kwargs.get("use_auth_token", None),
            token=kwargs.get("token", None),
        )
    ]
    config = config_cls.from_pretrained(model_name_or_path, **kwargs)
    # config keys that could affect the model output besides what is determined by the state_dict
    _check_hotswap_configs_compatible(model.active_peft_config, config)

    state_dict = load_peft_weights(model_name_or_path, device=torch_device, **kwargs)

    ###########################
    # LOAD & REMAP STATE_DICT #
    ###########################

    parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
    peft_model_state_dict = _insert_adapter_name_into_state_dict(
        state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
    )

    hotswap_adapter_from_state_dict(
        model=model,
        state_dict=peft_model_state_dict,
        adapter_name=adapter_name,
        parameter_prefix=parameter_prefix,
    )

src/peft/utils/save_and_load.py

Lines changed: 24 additions & 26 deletions
@@ -305,6 +305,25 @@ def _find_mismatched_keys(
     return peft_model_state_dict, mismatched


+def _insert_adapter_name_into_state_dict(
+    state_dict: dict[str, torch.Tensor], adapter_name: str, parameter_prefix: str
+) -> dict[str, torch.Tensor]:
+    """Utility function to remap the state_dict keys to fit the PEFT model by inserting the adapter name."""
+    peft_model_state_dict = {}
+    for key, val in state_dict.items():
+        if parameter_prefix in key:
+            suffix = key.split(parameter_prefix)[1]
+            if "." in suffix:
+                suffix_to_replace = ".".join(suffix.split(".")[1:])
+                key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
+            else:
+                key = f"{key}.{adapter_name}"
+            peft_model_state_dict[key] = val
+        else:
+            peft_model_state_dict[key] = val
+    return peft_model_state_dict
+
+
 def set_peft_model_state_dict(
     model,
     peft_model_state_dict,
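To make the remapping concrete, a hypothetical sketch (the key names and tensor shapes are made up for illustration) of what `_insert_adapter_name_into_state_dict` does for `adapter_name="default"`:

```python
import torch

from peft.utils.save_and_load import _insert_adapter_name_into_state_dict

state_dict = {
    "base_model.model.q_proj.lora_A.weight": torch.zeros(8, 16),
    "base_model.model.q_proj.lora_B.weight": torch.zeros(16, 8),
}
remapped = _insert_adapter_name_into_state_dict(
    state_dict, adapter_name="default", parameter_prefix="lora_"
)
# the adapter name is inserted right after the "lora_A"/"lora_B" component:
# {"base_model.model.q_proj.lora_A.default.weight": ...,
#  "base_model.model.q_proj.lora_B.default.weight": ...}
```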
@@ -342,21 +361,7 @@ def set_peft_model_state_dict(
     else:
         state_dict = peft_model_state_dict

-    if config.peft_type in (
-        PeftType.LORA,
-        PeftType.LOHA,
-        PeftType.LOKR,
-        PeftType.ADALORA,
-        PeftType.IA3,
-        PeftType.OFT,
-        PeftType.POLY,
-        PeftType.LN_TUNING,
-        PeftType.BOFT,
-        PeftType.VERA,
-        PeftType.FOURIERFT,
-        PeftType.HRA,
-        PeftType.VBLORA,
-    ):
+    if config.peft_type in PEFT_TYPE_TO_PREFIX_MAPPING:
         peft_model_state_dict = {}
         parameter_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
         if config.peft_type == PeftType.VBLORA and config.save_only_topk_weights:
@@ -386,17 +391,10 @@ def set_peft_model_state_dict(
                 # delete the topk_indices and topk_weights from the state_dict
                 del state_dict[k]
                 del state_dict[k.replace("_topk_indices", "_topk_weights")]
-        for k, v in state_dict.items():
-            if parameter_prefix in k:
-                suffix = k.split(parameter_prefix)[1]
-                if "." in suffix:
-                    suffix_to_replace = ".".join(suffix.split(".")[1:])
-                    k = k.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
-                else:
-                    k = f"{k}.{adapter_name}"
-                peft_model_state_dict[k] = v
-            else:
-                peft_model_state_dict[k] = v
+
+        peft_model_state_dict = _insert_adapter_name_into_state_dict(
+            state_dict, adapter_name=adapter_name, parameter_prefix=parameter_prefix
+        )

     if config.peft_type == PeftType.ADALORA:
         rank_pattern = config.rank_pattern