Changes from 23 commits

Commits
8a9610e
Initial working implementation of a-LoRA.
tdoublep Jun 16, 2025
a68e70b
Fix type hint for query_start_locs
tdoublep Jun 17, 2025
b254fb7
vllm/model_executor/layers/linear.py: add comment on torch.compile
tdoublep Jun 18, 2025
3897b1b
vllm/v1/worker/gpu_model_runner.py: remove print statement
tdoublep Jun 18, 2025
24ff376
vllm/v1/core/sched/scheduler.py: remove debug code
tdoublep Jun 18, 2025
412eacd
vllm/envs.py
tdoublep Jun 18, 2025
32098e4
Inject aLoRA behaviour via mixin
tdoublep Jun 18, 2025
fb6d28e
Simpler implementation without mixin
tdoublep Jun 18, 2025
5f62d8b
Scan for invocation tokens in one place
tdoublep Jun 18, 2025
f9396b0
Just use single field in request
tdoublep Jun 18, 2025
6f36f6d
Use peft_helper instead of reading files directly
tdoublep Jun 19, 2025
c6ffe8f
Remove online example for now.
tdoublep Jun 19, 2025
4a4b568
Further simplification; works with chunked prefill; correct output wi…
tdoublep Jun 19, 2025
4cbef84
Add enable_activated_lora engine arg
tdoublep Jun 19, 2025
49a5bdc
Disable tqdm in example
tdoublep Jun 19, 2025
a9ac26d
Resolve merge conflicts
tdoublep Jun 19, 2025
5c2e181
Trigger Build
tdoublep Jun 19, 2025
ceae7c7
vllm/model_executor/layers/linear.py: check lora_config exists before…
tdoublep Jun 19, 2025
99b8b60
arg_utils.py: fix typo
tdoublep Jun 19, 2025
477ab6e
Additional checking of lora_config
tdoublep Jun 19, 2025
91f39d1
Merge branch 'main' into alora
tdoublep Jun 20, 2025
5abbb78
Merge branch 'main' into alora
kgreenewald Aug 26, 2025
0a20f2a
Merge branch 'vllm-project:main' into alora
kgreenewald Aug 26, 2025
438ab6f
Resolve conflicts
tdoublep Sep 8, 2025
51edf96
Fix example
tdoublep Sep 8, 2025
a9d5986
Inject aLoRA behaviour via mixin
tdoublep Sep 8, 2025
6fbc108
Linting
tdoublep Sep 8, 2025
cb373e9
lint
tdoublep Sep 8, 2025
24dfc4a
add todo
tdoublep Sep 8, 2025
6c1b46a
Reorganize LoRA examples
tdoublep Sep 8, 2025
b8444d9
Lint
tdoublep Sep 8, 2025
4e513cc
lint
tdoublep Sep 8, 2025
b9df31f
more lint
tdoublep Sep 8, 2025
643d893
Cleanup example
tdoublep Sep 8, 2025
199ee89
Resolve conflicts
tdoublep Sep 11, 2025
03b6480
Refactor according to new structure
tdoublep Sep 11, 2025
76744da
Apply suggestions from code review
tdoublep Sep 11, 2025
6b83cc4
Fix a few naming issues
tdoublep Sep 11, 2025
18397d7
lint
tdoublep Sep 11, 2025
74 changes: 74 additions & 0 deletions examples/alora/alora_offline_example.py
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time

import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"

ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
invocation_string = "<|start_of_role|>certainty<|end_of_role|>"

os.environ["VLLM_USE_V1"] = "1"

# download your LoRA adapter to ~/.cache/huggingface/…
alora_path = snapshot_download(repo_id=ALORA_NAME)

print(alora_path)
#######################################


llm = LLM(
    model=BASE_NAME,
    enable_lora=True,
    enable_activated_lora=True,
    dtype=torch.bfloat16,
    max_lora_rank=64,
)

prompts = [
    (
        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    ),
]

sampling_params = SamplingParams(temperature=0, max_tokens=600)

outputsBase = llm.generate(
    prompts,
    sampling_params,
    use_tqdm=False,
)
generated_text = []
for output in outputsBase:
    prompt = output.prompt
    generated_text += [output.outputs[0].text]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")

prompts_alora = [
    x + y + "<|end_of_text|>\n" + invocation_string
    for x, y in zip(prompts, generated_text)
]

sampling_params = SamplingParams(temperature=0, max_tokens=10)

t0 = time.time()
outputs = llm.generate(
    prompts_alora,
    sampling_params,
    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
    use_tqdm=False,
)
t = time.time() - t0
print(f"Time: {t}")

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
2 changes: 2 additions & 0 deletions vllm/config/__init__.py
@@ -2471,6 +2471,8 @@ class LoRAConfig:
    in alphabetic order."""
    bias_enabled: bool = False
    """Enable bias for LoRA adapters."""
    activated_lora_enabled: bool = False
    """Enable Activated LoRA."""

    def compute_hash(self) -> str:
        """
4 changes: 4 additions & 0 deletions vllm/engine/arg_utils.py
@@ -357,6 +357,7 @@ class EngineArgs:
    # LoRA fields
    enable_lora: bool = False
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    enable_activated_lora: bool = LoRAConfig.activated_lora_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    default_mm_loras: Optional[Dict[str, str]] = \
@@ -776,6 +777,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                help="If True, enable handling of LoRA adapters.")
        lora_group.add_argument("--enable-lora-bias",
                                **lora_kwargs["bias_enabled"])
        lora_group.add_argument("--enable-activated-lora",
                                **lora_kwargs["activated_lora_enabled"])
        lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
        lora_group.add_argument("--max-lora-rank",
                                **lora_kwargs["max_lora_rank"])
@@ -1349,6 +1352,7 @@ def create_engine_config(

        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            activated_lora_enabled=self.enable_activated_lora,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            default_mm_loras=self.default_mm_loras,
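A quick usage sketch (not part of the diff): the new flag can be passed on the command line or through the Python API. The flag name --enable-activated-lora and the enable_activated_lora keyword come from this change; the model name and rank below are simply the ones used in the offline example above.

# Hedged usage sketch; values are illustrative.
# CLI:
#   vllm serve ibm-granite/granite-3.2-8b-instruct \
#       --enable-lora --enable-activated-lora --max-lora-rank 64
# Python API:
from vllm import LLM

llm = LLM(
    model="ibm-granite/granite-3.2-8b-instruct",
    enable_lora=True,
    enable_activated_lora=True,
    max_lora_rank=64,
)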
10 changes: 9 additions & 1 deletion vllm/forward_context.py
@@ -26,6 +26,11 @@
batchsize_forward_time: defaultdict = defaultdict(list)


@dataclass
class ALoRAMetadata:
    mask1d: torch.Tensor


class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
@@ -173,6 +178,7 @@ class ForwardContext:
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    alora_metadata: Optional[ALoRAMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
@@ -203,7 +209,8 @@ def set_forward_context(
                        num_tokens: Optional[int] = None,
                        num_tokens_across_dp: Optional[torch.Tensor] = None,
                        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                        batch_descriptor: Optional[BatchDescriptor] = None,
                        alora_metadata: Optional[ALoRAMetadata] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.
@@ -227,6 +234,7 @@
        virtual_engine=virtual_engine,
        attn_metadata=attn_metadata,
        dp_metadata=dp_metadata,
        alora_metadata=alora_metadata,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    )
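To illustrate the plumbing (a hedged sketch, not part of the diff): the metadata carries one mask value per token in the flattened batch, is passed into set_forward_context by the model runner, and is read back by layers via get_forward_context(), which is exactly what the new LoRA layer below does.

# Sketch of what ALoRAMetadata carries; mask values are illustrative only.
import torch

from vllm.forward_context import ALoRAMetadata

# One entry per token in the flattened batch:
#   1.0 -> token precedes the invocation sequence (keep the base-model output)
#   0.0 -> token is at/after the invocation start (use the aLoRA output)
meta = ALoRAMetadata(mask1d=torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0]))
# The model runner passes this via set_forward_context(..., alora_metadata=meta);
# layers read it back with get_forward_context().alora_metadata.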
51 changes: 50 additions & 1 deletion vllm/lora/layers.py
@@ -19,6 +19,7 @@
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.utils import divide
from vllm.forward_context import get_forward_context
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase,
@@ -873,7 +874,8 @@ def can_replace_layer(
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3
                and not lora_config.activated_lora_enabled)


#TODO: Implement this
@@ -1190,3 +1192,50 @@ def can_replace_layer(
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False


class MergedQKVParallelLinearWithActivatedLoRA(MergedQKVParallelLinearWithLoRA
                                               ):

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)
Review comment (Member): I'm not super familiar with the LoRA codepath, would this flattening need to be reversed or is it fine because flatten doesn't modify the original tensor?
        # Extract aLoRA batch metadata from forward context
        alora_metadata = get_forward_context().alora_metadata

        mask1d = alora_metadata.mask1d
        mask2d = mask1d.unsqueeze(1).to(output.dtype)

        # Clone base layer output before running LoRA
        orig_out = output.clone()

        # Apply LoRA in-place on `output`:
        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                            self.lora_b_stacked,
                                            self.lora_bias_stacked, 1.0,
                                            self.output_slices)
        # Apply alora mask
        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
        return final_output

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        return (type(source_layer) is QKVParallelLinear
                and len(packed_modules_list) == 3
                and lora_config.activated_lora_enabled)
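The masking at the end of apply() is a per-token blend between the base projection and the LoRA-adapted projection. A small self-contained sketch of that arithmetic (illustrative shapes only, no punica kernels):

# Hedged sketch of the aLoRA output blend used above.
import torch

seq_len, hidden = 5, 8
base_out = torch.randn(seq_len, hidden)                     # base QKV projection output
lora_out = base_out + 0.01 * torch.randn(seq_len, hidden)   # stand-in for the LoRA-updated output

# 1.0 before the invocation point, 0.0 from the invocation point onwards
mask1d = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0])
mask2d = mask1d.unsqueeze(1).to(base_out.dtype)

# Tokens before the invocation keep the base output (so their KV cache matches
# the base model); tokens at/after it take the LoRA-adapted output.
final = base_out.mul(mask2d) + lora_out.mul(1.0 - mask2d)
assert torch.allclose(final[:3], base_out[:3])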
2 changes: 2 additions & 0 deletions vllm/lora/peft_helper.py
@@ -35,6 +35,8 @@ class PEFTHelper:
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # Invocation string for Activated LoRA (aLoRA, see: https://arxiv.org/abs/2504.12397)
    invocation_string: Optional[str] = field(default=None)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    vllm_lora_scaling_factor: float = field(default=1.0)
    vllm_max_position_embeddings: Optional[int] = field(default=False)
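For context, a hedged illustration (not taken from any adapter repo) of the kind of entry this field would pick up from an adapter's config; the value shown is just the invocation string used in the offline example above, and newer adapters may instead ship pre-tokenized invocation tokens, as handled in the processor change below.

# Illustrative only: extra adapter-config entry that PEFTHelper would parse.
alora_adapter_config_extra = {
    "invocation_string": "<|start_of_role|>certainty<|end_of_role|>",
}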
1 change: 1 addition & 0 deletions vllm/lora/request.py
@@ -33,6 +33,7 @@ class LoRARequest(
    long_lora_max_len: Optional[int] = None
    base_model_name: Optional[str] = msgspec.field(default=None)
    tensorizer_config_dict: Optional[dict] = None
    invocation_start: Optional[int] = None

    def __post_init__(self):
        if self.lora_local_path:
2 changes: 2 additions & 0 deletions vllm/lora/utils.py
@@ -24,6 +24,7 @@
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithActivatedLoRA,
                              MergedQKVParallelLinearWithLoRA,
                              QKVParallelLinearWithLoRA,
                              ReplicatedLinearWithLoRA,
@@ -47,6 +48,7 @@
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithActivatedLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
11 changes: 11 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -10,6 +10,7 @@
from torch.nn.parameter import Parameter, UninitializedParameter

from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              split_tensor_along_last_dim,
@@ -253,6 +254,16 @@ def __init__(
    ):
        super().__init__()

        vllm_config = get_current_vllm_config()
        if (vllm_config.lora_config
                and vllm_config.lora_config.activated_lora_enabled):
            # lets torch.compile know that forward_context needs to be
            # considered as an input to the layer (copied from attention)
            compilation_config = vllm_config.compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
6 changes: 6 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
@@ -588,6 +588,12 @@ def request_block_hasher(request: Request) -> list[BlockHash]:
        # MM and LoRA requests need extra keys for block-hash computation.
        extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
            request, start_token_idx, end_token_idx, curr_mm_idx)
        # Respect aLoRA behaviour
        if (request.lora_request is not None
                and request.lora_request.invocation_start is not None
                and end_token_idx <= request.lora_request.invocation_start):
            # cache is equivalent to base model cache
            extra_keys = None

        # Compute the hash of the current block
        block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
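The intent of this check, sketched standalone (hedged, not vLLM internals): any full block of prompt tokens that ends before the invocation point is computed with pure base-model weights under aLoRA, so dropping the LoRA extra keys makes its hash identical to a base-model block and lets the prefix cache be shared.

# Hedged sketch of the block-hash decision for aLoRA requests.
from typing import Optional


def needs_lora_extra_keys(end_token_idx: int,
                          invocation_start: Optional[int]) -> bool:
    """Blocks ending before the invocation point hash like base-model blocks."""
    if invocation_start is None:  # not an aLoRA request
        return True
    return end_token_idx > invocation_start


assert needs_lora_extra_keys(16, invocation_start=32) is False  # shareable with base
assert needs_lora_extra_keys(48, invocation_start=32) is True   # needs a LoRA-specific key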
41 changes: 41 additions & 0 deletions vllm/v1/engine/processor.py
@@ -9,6 +9,7 @@
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
@@ -328,6 +329,46 @@
            sorted_mm_hashes,
        )

        # Handle aLoRA invocation sequence if applicable.
        if (self.lora_config and self.lora_config.activated_lora_enabled
                and lora_request is not None):
Review comment (Contributor): maybe also check if it is an aLoRA request here?

Reply (Member Author): We can't actually know that until we've called the PEFTHelper below to look at the adapter config.
            text_config = self.model_config.hf_config.get_text_config()

            peft_helper = PEFTHelper.from_local_dir(
                lora_request.lora_path, text_config.max_position_embeddings,
                lora_request.tensorizer_config_dict)
            if peft_helper.alora_invocation_tokens is not None or peft_helper.invocation_string is not None:
[GitHub Actions / pre-commit, Ruff (E501): vllm/v1/engine/processor.py:341:81: Line too long (108 > 80)]
                if peft_helper.alora_invocation_tokens is not None:
                    invocation_tokens = peft_helper.alora_invocation_tokens
                    delta = 0
                elif peft_helper.invocation_string is not None: #backwards compatibility
[GitHub Actions / pre-commit, Ruff (E501): vllm/v1/engine/processor.py:345:81: Line too long (89 > 80)]
                    invocation_tokens = self.input_preprocessor._tokenize_prompt(
[GitHub Actions / pre-commit, Ruff (E501): vllm/v1/engine/processor.py:346:81: Line too long (81 > 80)]
                        peft_helper.invocation_string,
                        lora_request=lora_request,
                        tokenization_kwargs=tokenization_kwargs)
                    delta = 1
                invocation_start = -1
                n = len(invocation_tokens)
                token_ids = decoder_inputs["prompt_token_ids"]

                if n > 0 and len(token_ids) >= n:
                    # scan backward for the last match
                    # (faster than full forward scan+max)
                    for idx in range(len(token_ids) - n, -1, -1):
                        if token_ids[idx:idx + n] == invocation_tokens:
                            # weights activated after start
                            invocation_start = idx + delta
                            break

                if invocation_start == -1:
                    raise ValueError(
                        "Invocation sequence not found in prompt "
                        f"for request '{request_id}'. aLoRA models require the "
                        "invocation tokens to be present in the input.")

                lora_request.invocation_start = invocation_start

        return decoder_inputs.get("prompt"), EngineCoreRequest(
            request_id=request_id,
            prompt_token_ids=decoder_inputs["prompt_token_ids"],
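The scan above simply finds the last occurrence of the invocation token sequence in the prompt. A standalone sketch of the same search in plain Python (the helper name is ours, for illustration only):

# Hedged sketch of the backward scan used above.
def find_invocation_start(token_ids: list[int],
                          invocation_tokens: list[int],
                          delta: int = 0) -> int:
    """Return the index where aLoRA weights activate, or -1 if not found."""
    n = len(invocation_tokens)
    if n == 0 or len(token_ids) < n:
        return -1
    for idx in range(len(token_ids) - n, -1, -1):  # last match wins
        if token_ids[idx:idx + n] == invocation_tokens:
            return idx + delta
    return -1


# e.g. invocation tokens [7, 8] appear at position 3 of the prompt:
assert find_invocation_start([1, 2, 3, 7, 8, 9], [7, 8]) == 3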