
Commit 569ea69

ENH Improve HRA speed and docs (#2160)
1 parent 082e927 commit 569ea69

4 files changed: +56 -4 lines changed


docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -116,6 +116,8 @@
       title: FourierFT
     - local: package_reference/vblora
       title: VB-LoRA
+    - local: package_reference/hra
+      title: HRA
 
     title: Adapters
   - sections:

docs/source/conceptual_guides/adapter.md

Lines changed: 13 additions & 0 deletions

@@ -105,3 +105,16 @@ A set of learnable adaption prompts are prefixed to the input instruction tokens
 <small><a href="https://hf.co/papers/2303.16199">LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention</a></small>
 
 To avoid adding noise to the tokens, the adapter uses zero-initialized attention. On top of this, the adapter adds a learnable gating factor (initialized with zeros) to progressively add information to the model during training. This prevents overwhelming the model's pretrained knowledge with the newly learned instructions.
+
+## Householder Reflection Adaptation (HRA)
+
+[HRA](https://huggingface.co/papers/2405.17484) provides a new perspective connecting LoRA to OFT, which means it can harness the advantages of both strategies, reducing the number of trainable parameters and computation costs while mitigating the loss of pre-trained knowledge.
+
+<div class="flex justify-center">
+    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/hra.png"/>
+</div>
+<small><a href="https://huggingface.co/papers/2405.17484">Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation</a></small>
+
+HRA constructs a chain of `r` trainable Householder reflections (HRs). Because each Householder reflection matrix is orthogonal, and the product of orthogonal matrices is itself orthogonal, HRA satisfies the theoretical guarantee of Orthogonal Finetuning (OFT). At the same time, HRA can be viewed as a low-rank fine-tuning adapter by rewriting its formula.
+
+The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Moreover, due to the chain structure, the orthogonality of the HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between model capacity and regularity, an orthogonality regularizer on the HR planes is added to the loss function; its weight \\(\lambda\\) controls the strength of the regularizer.
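To make the two views above concrete, here is a minimal, illustrative sketch (not part of this commit) that builds the chain \\(W' = W \prod_{i=1}^{r}(I - 2 u_i u_i^\top / \lVert u_i \rVert^2)\\), checks that the resulting matrix is orthogonal, verifies its low-rank structure, and computes one plausible form of the \\(\lambda\\)-weighted orthogonality regularizer. All function names here are assumptions for illustration, not the PEFT implementation.

```python
import torch

def householder_chain(u: torch.Tensor) -> torch.Tensor:
    """Product of r Householder reflections built from the columns of u (shape: d x r)."""
    d, r = u.shape
    q = torch.eye(d)
    for i in range(r):
        ui = u[:, i : i + 1]
        ui = ui / ui.norm()  # unit reflection vector
        q = q @ (torch.eye(d) - 2 * ui @ ui.t())  # each factor is orthogonal
    return q

def orthogonality_regularizer(u: torch.Tensor, lam: float) -> torch.Tensor:
    """One plausible regularizer: lam * ||U^T U - I||_F^2 over the normalized HR planes."""
    un = u / u.norm(dim=0, keepdim=True)
    gram = un.t() @ un
    return lam * (gram - torch.eye(u.shape[1])).pow(2).sum()

d, r = 16, 4
u = torch.randn(d, r, requires_grad=True)
q = householder_chain(u)

# OFT view: Q is orthogonal, so multiplying a frozen weight by Q preserves
# the pairwise angles of its rows, protecting pre-trained knowledge.
assert torch.allclose(q.t() @ q, torch.eye(d), atol=1e-5)

# Low-rank view: Q - I has rank at most r, so W @ Q = W + W @ (Q - I) is the
# frozen weight plus an (at most) rank-r update, as in LoRA.
assert torch.linalg.matrix_rank(q - torch.eye(d)).item() <= r

print(orthogonality_regularizer(u, lam=1e-3))
```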
docs/source/package_reference/hra.md

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+<!--Copyright 2024 The HuggingFace Team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# Bridging The Gap between Low-rank and Orthogonal Adaptation via Householder Reflection Adaptation (HRA)
+
+[HRA](https://huggingface.co/papers/2405.17484) is a simple but effective adapter-based fine-tuning method that leverages Householder reflections. It harnesses the advantages of both low-rank and orthogonal adaptation, reducing parameters and computation costs while mitigating the loss of pre-trained knowledge. It consistently achieves better performance with fewer trainable parameters and outperforms state-of-the-art adapters across different models, including large language models (LLMs) and conditional image generators.
+
+
+The abstract from the paper is:
+
+> While following different technical routes, both low-rank and orthogonal adaptation techniques can efficiently adapt large-scale pre-training models in specific tasks or domains based on a small piece of trainable parameters. In this study, we bridge the gap between these two techniques, proposing a simple but effective adaptation method based on Householder reflections. Given a pre-trained model, our method fine-tunes its layers by multiplying each frozen weight matrix with an orthogonal matrix constructed by a chain of learnable Householder reflections (HRs). This HR-based orthogonal fine-tuning is equivalent to an adaptive low-rank adaptation. Moreover, we show that the orthogonality of the reflection planes corresponding to the HRs impacts the model capacity and regularity. The analysis motivates us to regularize the orthogonality of the HRs, leading to different implementations of the proposed Householder reflection adaptation (HRA) method. Compared with state-of-the-art methods, HRA achieves superior performance with fewer learnable parameters when adapting large language models and conditional image generators. The code is available at [peft](https://github.com/huggingface/peft/tree/main/src/peft/tuners/hra) and [HRA](https://github.com/DaShenZi721/HRA).
+
+## HRAConfig
+
+[[autodoc]] tuners.hra.config.HRAConfig
+
+## HRAModel
+
+[[autodoc]] tuners.hra.model.HRAModel
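For context, a minimal usage sketch (not part of this commit): the base model name and `target_modules` below are illustrative assumptions that happen to match OPT's attention projections; substitute your own model and module names.

```python
from transformers import AutoModelForCausalLM
from peft import HRAConfig, get_peft_model

# "facebook/opt-125m" is an illustrative choice of base model.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# r is the number of Householder reflections per adapted layer.
config = HRAConfig(r=8, target_modules=["q_proj", "v_proj"])

model = get_peft_model(base_model, config)
model.print_trainable_parameters()  # only the reflection vectors are trainable
```

The `HRAModel` is constructed for you by `get_peft_model`; you rarely need to instantiate it directly.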

src/peft/tuners/hra/layer.py

Lines changed: 9 additions & 4 deletions

@@ -220,7 +220,7 @@ def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
 
         for i in indices:
             ui = opt_u[:, i].view(-1, 1)
-            weight = weight @ (torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * ui @ ui.t())
+            weight = weight - 2 * weight @ ui @ ui.t()
 
         return weight

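This one-line change (repeated in the second `get_delta_weight` hunk below) is the speed improvement: since \\(W(I - 2uu^\top) = W - 2(Wu)u^\top\\), the full \\(d \times d\\) reflection matrix never needs to be materialized, replacing an \\(O(d^3)\\) dense product with two \\(O(d^2)\\) thin products per reflection. A small sketch (illustrative, not from the commit) confirming the equivalence:

```python
import torch

d = 512
weight = torch.randn(d, d)
ui = torch.randn(d, 1)
ui = ui / ui.norm()  # unit Householder vector

# Old form: materialize the d x d reflection matrix, then multiply.
old = weight @ (torch.eye(d) - 2 * ui @ ui.t())

# New form: distribute the product; no d x d temporary is created.
new = weight - 2 * weight @ ui @ ui.t()

print(torch.allclose(old, new, atol=1e-4))  # True: algebraically identical
```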
@@ -384,7 +384,7 @@ def get_delta_weight(self, adapter_name: str, reverse: bool = False) -> torch.Tensor:
 
         for i in indices:
             ui = opt_u[:, i].view(-1, 1)
-            weight = weight @ (torch.eye(shape[0], device=opt_u.device, dtype=opt_u.dtype) - 2 * ui @ ui.t())
+            weight = weight - 2 * weight @ ui @ ui.t()
 
         return weight

@@ -399,7 +399,9 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             result = self.base_layer(x, *args, **kwargs)
         else:
             new_weight = torch.eye(
-                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0], device=x.device
+                self.in_features * self.base_layer.kernel_size[0] * self.base_layer.kernel_size[0],
+                device=x.device,
+                dtype=previous_dtype,
             )
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.hra_u.keys():
@@ -416,7 +418,10 @@
                 )
             new_weight = torch.mm(orig_weight, new_weight)
             new_weight = new_weight.view(
-                self.out_features, self.in_features, self.base_layer.kernel_size[0], self.base_layer.kernel_size[0]
+                self.out_features,
+                self.in_features,
+                self.base_layer.kernel_size[0],
+                self.base_layer.kernel_size[0],
             )
 
             result = F.conv2d(
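The `dtype=previous_dtype` addition in the `forward` hunk above guards against a dtype mismatch when the layer runs in reduced precision, since `torch.eye` defaults to float32. A minimal sketch of the failure mode (illustrative, not from the commit):

```python
import torch

half_weight = torch.randn(4, 4, dtype=torch.float16)

try:
    torch.mm(half_weight, torch.eye(4))  # torch.eye defaults to float32
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# Passing the input's dtype, as the fix does, keeps the whole chain in one dtype.
out = torch.mm(half_weight, torch.eye(4, dtype=half_weight.dtype))
print(out.dtype)  # torch.float16
```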
