[Model] Activated LoRA #19710
Changes from 23 commits
File 1 (new file, @@ -0,0 +1,74 @@):

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time

import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"

ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

os.environ["VLLM_USE_V1"] = "1"

# download your LoRA adapter to ~/.cache/huggingface/…
alora_path = snapshot_download(repo_id=ALORA_NAME)

print(alora_path)
#######################################

llm = LLM(
    model=BASE_NAME,
    enable_lora=True,
    enable_activated_lora=True,
    dtype=torch.bfloat16,
    max_lora_rank=64,
)

prompts = [
    (
        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    ),
]

sampling_params = SamplingParams(temperature=0, max_tokens=600)

outputsBase = llm.generate(
    prompts,
    sampling_params,
    use_tqdm=False,
)
generated_text = []
for output in outputsBase:
    prompt = output.prompt
    generated_text += [output.outputs[0].text]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")

prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, generated_text)
]

sampling_params = SamplingParams(temperature=0, max_tokens=10)

t0 = time.time()
outputs = llm.generate(
    prompts_alora,
    sampling_params,
    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
    use_tqdm=False,
)
t = time.time() - t0
print(f"Time: {t}")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
File 2 (@@ -19,6 +19,7 @@):

```diff
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
 from vllm.distributed.utils import divide
+from vllm.forward_context import get_forward_context
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
```

@@ -873,7 +874,8 @@ def can_replace_layer(

```diff
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is QKVParallelLinear
-                and len(packed_modules_list) == 3)
+                and len(packed_modules_list) == 3
+                and not lora_config.activated_lora_enabled)


 #TODO: Implement this
```

@@ -1190,3 +1192,50 @@ def can_replace_layer(

```diff
     ) -> bool:
         # Special handling for the LogitsProcessor.
         return False
+
+
+class MergedQKVParallelLinearWithActivatedLoRA(MergedQKVParallelLinearWithLoRA
+                                               ):
+
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
+        # Extract aLoRA batch metadata from forward context
+        alora_metadata = get_forward_context().alora_metadata
+        mask1d = alora_metadata.mask1d
+        mask2d = mask1d.unsqueeze(1).to(output.dtype)
+
+        # Clone base layer output before running LoRA
+        orig_out = output.clone()
+
+        # Apply LoRA in-place on `output`:
+        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+                                            self.lora_b_stacked,
+                                            self.lora_bias_stacked, 1.0,
+                                            self.output_slices)
+        # Apply alora mask
+        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
+        return final_output
+
+    @classmethod
+    def can_replace_layer(
+        cls,
+        source_layer: nn.Module,
+        lora_config: LoRAConfig,
+        packed_modules_list: list,
+        model_config: Optional[PretrainedConfig],
+    ) -> bool:
+        """Returns True if the layer can be replaced by this LoRA layer."""
+        return (type(source_layer) is QKVParallelLinear
+                and len(packed_modules_list) == 3
+                and lora_config.activated_lora_enabled)
```
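For readers unfamiliar with the masking trick in `apply()`, here is a minimal standalone sketch of the same arithmetic (not the PR's code); it assumes `mask1d` holds 1.0 for token positions before the invocation start, which keep the base-layer output, and 0.0 afterwards, where the LoRA-adapted output is used.

```python
import torch

seq_len, hidden = 6, 4
orig_out = torch.zeros(seq_len, hidden)           # stand-in for the base-layer output
lora_out = torch.ones(seq_len, hidden)            # stand-in for output after add_lora_linear
mask1d = torch.tensor([1., 1., 1., 1., 0., 0.])   # invocation assumed to start at position 4
mask2d = mask1d.unsqueeze(1).to(lora_out.dtype)

# Positions 0-3 keep the base output, positions 4-5 take the LoRA output.
final = orig_out.mul(mask2d) + lora_out.mul(1.0 - mask2d)
print(final)
```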
File 3 (@@ -9,6 +9,7 @@):

```diff
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
```

@@ -328,6 +329,46 @@

```diff
             sorted_mm_hashes,
         )
+
+        # Handle aLoRA invocation sequence if applicable.
+        if (self.lora_config and self.lora_config.activated_lora_enabled
+                and lora_request is not None):
```

Review thread on the check above:

Contributor: maybe also check if it is a …
Author: We can't actually know that until we've called the PEFTHelper below to look at the adapter config.
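As a rough illustration of the author's point, something like the following (hypothetical, not the PR's code) would be needed just to tell whether an adapter is an aLoRA adapter; it assumes the standard PEFT layout with an `adapter_config.json` and that aLoRA adapters record `alora_invocation_tokens` there, which is why the check cannot happen before `PEFTHelper` reads the config.

```python
import json
import os


def is_alora_adapter(lora_path: str) -> bool:
    """Hypothetical check: does the adapter config declare aLoRA invocation tokens?"""
    config_file = os.path.join(lora_path, "adapter_config.json")
    if not os.path.isfile(config_file):
        return False
    with open(config_file) as f:
        config = json.load(f)
    return config.get("alora_invocation_tokens") is not None
```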
The hunk continues:

```diff
+
+            text_config = self.model_config.hf_config.get_text_config()
+
+            peft_helper = PEFTHelper.from_local_dir(
+                lora_request.lora_path, text_config.max_position_embeddings,
+                lora_request.tensorizer_config_dict)
+            if (peft_helper.alora_invocation_tokens is not None
+                    or peft_helper.invocation_string is not None):
+                if peft_helper.alora_invocation_tokens is not None:
+                    invocation_tokens = peft_helper.alora_invocation_tokens
+                    delta = 0
+                elif peft_helper.invocation_string is not None:  # backwards compatibility
+                    invocation_tokens = self.input_preprocessor._tokenize_prompt(
+                        peft_helper.invocation_string,
+                        lora_request=lora_request,
+                        tokenization_kwargs=tokenization_kwargs)
+                    delta = 1
+                invocation_start = -1
+                n = len(invocation_tokens)
+
+                token_ids = decoder_inputs["prompt_token_ids"]
+
+                if n > 0 and len(token_ids) >= n:
+                    # scan backward for the last match
+                    # (faster than full forward scan+max)
+                    for idx in range(len(token_ids) - n, -1, -1):
+                        if token_ids[idx:idx + n] == invocation_tokens:
+                            # weights activated after start
+                            invocation_start = idx + delta
+                            break
+
+                if invocation_start == -1:
+                    raise ValueError(
+                        "Invocation sequence not found in prompt "
+                        f"for request '{request_id}'. aLoRA models require the "
+                        "invocation tokens to be present in the input.")
+
+                lora_request.invocation_start = invocation_start
+
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
```
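The backward scan above is just a last-occurrence subsequence search over the prompt's token ids; here is a standalone sketch (not the PR's code) of the same logic.

```python
def find_last_subsequence(token_ids: list[int], invocation_tokens: list[int]) -> int:
    """Return the start index of the last occurrence of invocation_tokens, or -1."""
    n = len(invocation_tokens)
    if n == 0 or len(token_ids) < n:
        return -1
    # Walk right-to-left so the first hit is the last occurrence in the prompt.
    for idx in range(len(token_ids) - n, -1, -1):
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx
    return -1


assert find_last_subsequence([5, 7, 9, 7, 9, 2], [7, 9]) == 3  # last match wins
assert find_last_subsequence([1, 2, 3], [4]) == -1
```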