diff --git a/examples/offline_inference/lora/README.md b/examples/offline_inference/lora/README.md
new file mode 100644
index 000000000000..afe5bda548c2
--- /dev/null
+++ b/examples/offline_inference/lora/README.md
@@ -0,0 +1,27 @@
+# LoRA Examples
+
+This folder contains examples of offline inference using LoRA.
+
+## Multi-LoRA
+
+This example shows how to use the multi-LoRA functionality:
+
+```bash
+python examples/offline_inference/lora/multilora_inference.py
+```
+
+## LoRA with Quantization
+
+This example shows how to use LoRA with different quantization techniques:
+
+```bash
+python examples/offline_inference/lora/lora_with_quantization_inference.py
+```
+
+## Activated LoRA
+
+This example shows how to use [activated LoRA](https://arxiv.org/abs/2504.12397):
+
+```bash
+python examples/offline_inference/lora/activated_lora.py
+```
diff --git a/examples/offline_inference/lora/activated_lora.py b/examples/offline_inference/lora/activated_lora.py
new file mode 100644
index 000000000000..be469611d4f6
--- /dev/null
+++ b/examples/offline_inference/lora/activated_lora.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+from huggingface_hub import snapshot_download
+
+from vllm import LLM, SamplingParams
+from vllm.lora.request import LoRARequest
+
+BASE_NAME = "ibm-granite/granite-3.2-8b-instruct"
+
+ALORA_NAME = "ibm-granite/granite-3.2-8b-alora-uncertainty"
+invocation_string = "<|start_of_role|>certainty<|end_of_role|>"
+
+# download your LoRA adapter to ~/.cache/huggingface/…
+alora_path = snapshot_download(repo_id=ALORA_NAME)
+
+print(alora_path)
+#######################################
+
+
+llm = LLM(
+    model=BASE_NAME,
+    enable_lora=True,
+    enable_activated_lora=True,
+    dtype=torch.bfloat16,
+    max_lora_rank=64,
+)
+
+prompts = [
+    (
+        "<|start_of_role|>user<|end_of_role|>What is MIT?<|end_of_text|>\n"
+        "<|start_of_role|>assistant<|end_of_role|>"
+    ),
+]
+
+sampling_params = SamplingParams(temperature=0, max_tokens=600)
+
+outputsBase = llm.generate(
+    prompts,
+    sampling_params,
+    use_tqdm=False,
+)
+generated_text = []
+for output in outputsBase:
+    prompt = output.prompt
+    generated_text += [output.outputs[0].text]
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text[-1]!r}")
+
+prompts_alora = [
+    x + y + "<|end_of_text|>\n" + invocation_string
+    for x, y in zip(prompts, generated_text)
+]
+
+sampling_params = SamplingParams(temperature=0, max_tokens=10)
+
+t0 = time.time()
+outputs = llm.generate(
+    prompts_alora,
+    sampling_params,
+    lora_request=LoRARequest("UQ_adapter", 1, alora_path),
+    use_tqdm=False,
+)
+t = time.time() - t0
+print(f"Time: {t}")
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora/lora_with_quantization_inference.py
similarity index 100%
rename from examples/offline_inference/lora_with_quantization_inference.py
rename to examples/offline_inference/lora/lora_with_quantization_inference.py
diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/lora/multilora_inference.py
similarity index 100%
rename from examples/offline_inference/multilora_inference.py
rename to examples/offline_inference/lora/multilora_inference.py
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 587cfab35515..325c8f05b245 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -2440,6 +2440,8 @@ class LoRAConfig:
     bias_enabled: bool = False
     """[DEPRECATED] Enable bias for LoRA adapters. This option will be
     removed in v0.12.0."""
+    enable_activated_lora: bool = False
+    """Enable Activated LoRA."""
 
     def compute_hash(self) -> str:
         """
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d9a29511eb52..c3e5cc8ce6a2 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -373,6 +373,7 @@ class EngineArgs:
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
+    enable_activated_lora: bool = LoRAConfig.enable_activated_lora
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
     default_mm_loras: Optional[Dict[str, str]] = \
@@ -794,6 +795,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 help="If True, enable handling of LoRA adapters.")
         lora_group.add_argument("--enable-lora-bias",
                                 **lora_kwargs["bias_enabled"])
+        lora_group.add_argument("--enable-activated-lora",
+                                **lora_kwargs["enable_activated_lora"])
         lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"])
         lora_group.add_argument("--max-lora-rank",
                                 **lora_kwargs["max_lora_rank"])
@@ -1369,6 +1372,7 @@ def create_engine_config(
 
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
+            enable_activated_lora=self.enable_activated_lora,
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             default_mm_loras=self.default_mm_loras,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index c57c51d289ac..98ea8606fe68 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -26,6 +26,11 @@
 batchsize_forward_time: defaultdict = defaultdict(list)
 
 
+@dataclass
+class ALoRAMetadata:
+    mask1d: torch.Tensor
+
+
 class BatchDescriptor(NamedTuple):
     """
     Batch descriptor for cudagraph dispatching. We should keep the num of
@@ -173,6 +178,7 @@ class ForwardContext:
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
+    alora_metadata: Optional[ALoRAMetadata] = None
     # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
     # by default NONE, no cudagraph is used.
     cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
@@ -203,7 +209,8 @@ def set_forward_context(
         num_tokens: Optional[int] = None,
         num_tokens_across_dp: Optional[torch.Tensor] = None,
         cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
-        batch_descriptor: Optional[BatchDescriptor] = None):
+        batch_descriptor: Optional[BatchDescriptor] = None,
+        alora_metadata: Optional[ALoRAMetadata] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
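For orientation, a minimal sketch of how a layer can consume the new `alora_metadata` field once `set_forward_context(..., alora_metadata=...)` is active. It mirrors the blend performed by `LinearLayerWithActivatedLoRAMixin.apply` later in this patch; the helper name and call site are illustrative, not part of the diff:

```python
import torch

from vllm.forward_context import get_forward_context


def blend_with_alora_mask(base_out: torch.Tensor,
                          lora_out: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep base-model outputs for tokens before the
    invocation sequence, aLoRA outputs for the rest."""
    ctx = get_forward_context()
    if ctx.alora_metadata is None:
        return lora_out
    # mask1d is True for tokens positioned before the invocation sequence.
    mask2d = ctx.alora_metadata.mask1d.unsqueeze(1).to(base_out.dtype)
    return base_out * mask2d + lora_out * (1.0 - mask2d)
```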
@@ -227,6 +234,7 @@ def set_forward_context(
             virtual_engine=virtual_engine,
             attn_metadata=attn_metadata,
             dp_metadata=dp_metadata,
+            alora_metadata=alora_metadata,
             cudagraph_runtime_mode=cudagraph_runtime_mode,
             batch_descriptor=batch_descriptor,
         )
diff --git a/vllm/lora/layers/__init__.py b/vllm/lora/layers/__init__.py
index d3bb145dc7bf..23b6a65881ba 100644
--- a/vllm/lora/layers/__init__.py
+++ b/vllm/lora/layers/__init__.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.lora.layers.activated_linear import LinearLayerWithActivatedLoRAMixin
 from vllm.lora.layers.base import BaseLayerWithLoRA
 from vllm.lora.layers.column_parallel_linear import (
     ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA,
@@ -31,4 +32,5 @@
     "RowParallelLinearWithShardedLoRA",
     "ReplicatedLinearWithLoRA",
     "LoRAMapping",
+    "LinearLayerWithActivatedLoRAMixin",
 ]
diff --git a/vllm/lora/layers/activated_linear.py b/vllm/lora/layers/activated_linear.py
new file mode 100644
index 000000000000..adf2d299d009
--- /dev/null
+++ b/vllm/lora/layers/activated_linear.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from vllm.forward_context import get_forward_context
+from vllm.lora.punica_wrapper import PunicaWrapperBase
+
+from .base_linear import BaseLinearLayerWithLoRA
+
+if TYPE_CHECKING:
+    from .base import BaseLayerWithLoRA
+
+
+class LinearLayerWithActivatedLoRAMixin:
+
+    base_layer: BaseLinearLayerWithLoRA
+    punica_wrapper: PunicaWrapperBase
+    lora_a_stacked: torch.Tensor
+    lora_b_stacked: torch.Tensor
+    lora_bias_stacked: Optional[tuple[torch.Tensor, ...]]
+    output_slices: tuple[int, ...]
+
+    @classmethod
+    def maybe_mixin(cls, lora_cls: "type[BaseLayerWithLoRA]"):
+        if issubclass(lora_cls, BaseLinearLayerWithLoRA):
+            return type(lora_cls.__name__.replace("LoRA", "ActivatedLoRA"),
+                        (cls, lora_cls), {})
+        else:
+            return lora_cls
+
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
+        # Extract aLoRA batch metadata from forward context
+        alora_metadata = get_forward_context().alora_metadata
+
+        mask1d = alora_metadata.mask1d
+        mask2d = mask1d.unsqueeze(1).to(output.dtype)
+
+        # Clone base layer output before running LoRA
+        # TODO(tdoublep): pass in mask1d and only operate on valid entries
+        orig_out = output.clone()
+
+        # Apply LoRA in-place on `output`:
+        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+                                            self.lora_b_stacked,
+                                            self.lora_bias_stacked, 1.0,
+                                            self.output_slices)
+        # Apply alora mask
+        final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d)
+        return final_output
diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py
index 8b8e5cb7d5fa..de11c750d7d1 100644
--- a/vllm/lora/peft_helper.py
+++ b/vllm/lora/peft_helper.py
@@ -35,6 +35,8 @@ class PEFTHelper:
     use_rslora: bool = field(default=False)
     # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
+    # Invocation tokens for Activated LoRA (aLoRA, see: https://arxiv.org/abs/2504.12397)
+    alora_invocation_tokens: Optional[list[int]] = None
     # Extra vllm field, start with 'vllm_' to avoid conflict
     vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index 5bbba7830c1b..64762f15ec2e 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -33,6 +33,7 @@ class LoRARequest(
     long_lora_max_len: Optional[int] = None
     base_model_name: Optional[str] = msgspec.field(default=None)
     tensorizer_config_dict: Optional[dict] = None
+    invocation_start: Optional[int] = None
 
     def __post_init__(self):
         if self.lora_local_path:
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 2b05a2cf4d40..c52ee583bb32 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -18,6 +18,7 @@
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ColumnParallelLinearWithShardedLoRA,
+                              LinearLayerWithActivatedLoRAMixin,
                               LogitsProcessorWithLoRA,
                               MergedColumnParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithShardedLoRA,
@@ -69,6 +70,10 @@ def from_layer(layer: nn.Module,
                               lora_config=lora_config,
                               packed_modules_list=packed_modules_list,
                               model_config=model_config):
+            # inject a-LoRA behaviour
+            if lora_config.enable_activated_lora:
+                lora_cls = LinearLayerWithActivatedLoRAMixin.maybe_mixin(
+                    lora_cls)
             instance_layer = lora_cls(layer)
             instance_layer.create_lora_weights(max_loras, lora_config,
                                                model_config)
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index fd88eac55cb5..23ba302a77b4 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -9,6 +9,7 @@
 import torch.nn as nn
 from torch.nn.parameter import Parameter, UninitializedParameter
 
+from vllm.config import get_current_vllm_config
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -240,6 +241,16 @@ def __init__(
     ):
         super().__init__()
 
+        vllm_config = get_current_vllm_config()
+        if (vllm_config.lora_config
+                and vllm_config.lora_config.enable_activated_lora):
+            # lets torch.compile know that forward_context needs to be
+            # considered as an input to the layer (copied from attention)
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 2c0eac3ddd79..fb2809f63127 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -597,6 +597,12 @@ def request_block_hasher(request: Request) -> list[BlockHash]:
             # MM and LoRA requests need extra keys for block-hash computation.
             extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
                 request, start_token_idx, end_token_idx, curr_mm_idx)
+            # Respect aLoRA behaviour
+            if (request.lora_request is not None
+                    and request.lora_request.invocation_start is not None and
+                    end_token_idx <= request.lora_request.invocation_start):
+                # cache is equivalent to base model cache
+                extra_keys = None
 
             # Compute the hash of the current block
             block_tokens = request.all_token_ids[start_token_idx:end_token_idx]
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 96e89eeac556..46b30f04d9de 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -9,6 +9,7 @@
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
 from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
@@ -423,6 +424,36 @@ def process_inputs(
                     identifier=decoder_mm_hashes[modality][idx],
                     mm_position=decoder_mm_positions[modality][idx]))
 
+        # Handle aLoRA invocation sequence if applicable.
+        if (self.lora_config and self.lora_config.enable_activated_lora
+                and lora_request is not None):
+
+            text_config = self.model_config.hf_config.get_text_config()
+
+            peft_helper = PEFTHelper.from_local_dir(
+                lora_request.lora_path, text_config.max_position_embeddings,
+                lora_request.tensorizer_config_dict)
+
+            if peft_helper.alora_invocation_tokens is not None:
+                invocation_tokens = peft_helper.alora_invocation_tokens
+                invocation_start = -1
+                n = len(invocation_tokens)
+                token_ids = decoder_inputs["prompt_token_ids"]
+                if n > 0 and len(token_ids) >= n:
+                    # scan backward for the last match
+                    # (faster than full forward scan+max)
+                    for idx in range(len(token_ids) - n, -1, -1):
+                        if token_ids[idx:idx + n] == invocation_tokens:
+                            # weights activated after start
+                            invocation_start = idx
+                            break
+                if invocation_start == -1:
+                    raise ValueError(
+                        "Invocation sequence not found in prompt "
+                        f"for request '{request_id}'. aLoRA models require the "
+                        "invocation tokens to be present in the input.")
+                lora_request.invocation_start = invocation_start
+
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1b785af96a9a..939b53752a7c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -32,7 +32,7 @@
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
     prepare_communication_buffer_for_model)
-from vllm.forward_context import (BatchDescriptor, DPMetadata,
+from vllm.forward_context import (ALoRAMetadata, BatchDescriptor, DPMetadata,
                                   set_forward_context)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -329,6 +329,11 @@ def __init__(
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
 
+        if self.lora_config and self.lora_config.enable_activated_lora:
+            self.mask1d = torch.zeros(self.max_num_tokens,
+                                      dtype=torch.int64,
+                                      device=self.device)
+
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
             # NOTE: `mrope_positions` is implemented with one additional dummy
@@ -863,7 +868,8 @@ def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata],
-               np.ndarray, Optional[CommonAttentionMetadata], int]:
+               Optional[ALoRAMetadata], np.ndarray,
+               Optional[CommonAttentionMetadata], int]:
         """
         :return: tuple[
             attn_metadata: layer-to-attention_metadata mapping,
@@ -1093,9 +1099,17 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
+        # Compute aLoRA metadata
+        alora_metadata = None
+        if self.lora_config and self.lora_config.enable_activated_lora:
+            alora_metadata = self.build_alora_metadata(
+                num_reqs, positions_np, req_indices,
+                total_num_scheduled_tokens, self.input_batch, self.requests,
+                self.mask1d)
+
         return (attn_metadata, logits_indices, spec_decode_metadata,
-                num_scheduled_tokens, spec_decode_common_attn_metadata,
-                max_num_scheduled_tokens)
+                alora_metadata, num_scheduled_tokens,
+                spec_decode_common_attn_metadata, max_num_scheduled_tokens)
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -2022,7 +2036,8 @@ def execute_model(
         try:
             # Prepare the decoder inputs.
             (attn_metadata, logits_indices, spec_decode_metadata,
-             num_scheduled_tokens_np, spec_decode_common_attn_metadata,
+             alora_metadata, num_scheduled_tokens_np,
+             spec_decode_common_attn_metadata,
              max_query_len) = self._prepare_inputs(scheduler_output)
         finally:
@@ -2057,6 +2072,7 @@
                 num_tokens=num_input_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                alora_metadata=alora_metadata,
                 batch_descriptor=batch_descriptor,
             ), record_function_or_nullcontext("Forward"),
                 self.maybe_get_kv_connector_output(scheduler_output) as
@@ -2799,11 +2815,17 @@ def _dummy_run(
                     f"Cudagraph runtime mode mismatch at dummy_run. "
                     f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
 
+        alora_metadata = None
+        if self.lora_config and self.lora_config.enable_activated_lora:
+            alora_metadata = self.build_dummy_alora_metadata(
+                num_tokens, self.mask1d)
+
         with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,
                 num_tokens_across_dp=num_tokens_across_dp,
+                alora_metadata=alora_metadata,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
                 batch_descriptor=batch_descriptor):
             outputs = self.model(
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index 4b5f27d27541..c000ca9f28db 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -12,11 +12,13 @@
 import torch.nn as nn
 
 from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig
+from vllm.forward_context import ALoRAMetadata
 from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.v1.worker.gpu_input_batch import CachedRequestState
 from vllm.v1.worker.gpu_input_batch import InputBatch as GPUInputBatch
 from vllm.v1.worker.tpu_input_batch import InputBatch as TPUInputBatch
 
@@ -86,6 +88,37 @@ def set_active_loras(self, input_batch: InputBatch,
         return self._set_active_loras(prompt_lora_mapping, token_lora_mapping,
                                       lora_requests)
 
+    def build_alora_metadata(self, num_reqs: int, positions_np: np.ndarray,
+                             req_indices: np.ndarray,
+                             total_num_scheduled_tokens: int,
+                             input_batch: InputBatch,
+                             requests: dict[str, CachedRequestState],
+                             mask1d: torch.Tensor) -> ALoRAMetadata:
+        invocation_start = np.empty(shape=(num_reqs, ), dtype=int)
+        for req_id in input_batch.req_ids:
+            req_index = input_batch.req_id_to_index[req_id]
+            cached_lora_request = requests[req_id].lora_request
+            if (cached_lora_request is not None
+                    and cached_lora_request.invocation_start is not None):
+                invocation_start[
+                    req_index] = cached_lora_request.invocation_start
+            else:
+                invocation_start[req_index] = len(
+                    requests[req_id].prompt_token_ids)
+        mask1d_cpu = torch.tensor(positions_np < invocation_start[req_indices],
+                                  dtype=torch.bool,
+                                  device="cpu")
+        mask1d = mask1d[:total_num_scheduled_tokens]
+        mask1d.copy_(mask1d_cpu, non_blocking=True)
+        return ALoRAMetadata(mask1d=mask1d)
+
+    def build_dummy_alora_metadata(self, num_tokens: int,
+                                   mask1d: torch.Tensor):
+        alora_metadata = ALoRAMetadata(mask1d=mask1d[:num_tokens])
+        # needed to avoid guard failures
+        torch._dynamo.mark_dynamic(alora_metadata.mask1d, 0)
+        return alora_metadata
+
     @contextmanager
     def maybe_setup_dummy_loras(self, lora_config: Optional[LoRAConfig],