|
19 | 19 | tensor_model_parallel_all_gather, |
20 | 20 | tensor_model_parallel_all_reduce) |
21 | 21 | from vllm.distributed.utils import divide |
| 22 | +from vllm.forward_context import get_forward_context |
22 | 23 | # yapf: disable |
23 | 24 | from vllm.model_executor.layers.linear import (ColumnParallelLinear, |
24 | 25 | LinearBase, |
@@ -418,14 +419,44 @@ def apply(self, |
418 | 419 | output = output.flatten(0, 1) |
419 | 420 | x = x.flatten(0, 1) |
420 | 421 |
|
421 | | - lora_output: Optional[ |
422 | | - torch.Tensor] = self.punica_wrapper.add_lora_linear( |
423 | | - output, x, self.lora_a_stacked, self.lora_b_stacked, |
424 | | - self.lora_bias_stacked, 1.0, self.output_slices) |
425 | | - if not current_platform.can_update_inplace(): |
426 | | - output = lora_output |
427 | | - |
428 | | - return output |
| 422 | + # Extract aLoRA batch metadata from forward context |
| 423 | + alora_metadata = get_forward_context().alora_metadata |
| 424 | + k_offsets = alora_metadata.k_offsets |
| 425 | + query_start_locs = alora_metadata.query_start_locs |
| 426 | + |
| 427 | + # Build the 1D "save-prefix" mask: |
| 428 | + T = output.size(0) # total tokens |
| 429 | + starts = query_start_locs[:-1] # start index of each request |
| 430 | + ends = query_start_locs[1:] # end index of each request |
| 431 | + lengths = ends - starts # request lengths |
| 432 | + kept_lens = lengths - k_offsets |
| 433 | + kept_lens = torch.clamp( |
| 434 | + kept_lens, |
| 435 | + min=0) # portion of each request kept as base-model output |
| 436 | + |
| 437 | + device = output.device |
| 438 | + # Create the aLoRA mask via a difference array and a running cumsum |
| 439 | + delta = torch.zeros(T + 1, device=device, dtype=output.dtype) |
| 440 | + ends_for_scatter = starts + kept_lens |
| 441 | + pos_vals = kept_lens.sign().to(output.dtype) |
| 442 | + neg_vals = -pos_vals |
| 443 | + delta.scatter_add_(0, starts, pos_vals) |
| 444 | + delta.scatter_add_(0, ends_for_scatter, neg_vals) |
| 445 | + cums = torch.cumsum(delta[:-1], dim=0) |
| 446 | + mask1d = cums > 0 # shape [T], bool |
| 447 | + mask2d = mask1d.unsqueeze(1).to(output.dtype) |
| 448 | + |
| 449 | + # Clone base layer output before running LoRA |
| 450 | + orig_out = output.clone() |
| 451 | + |
| 452 | + # Apply LoRA in-place on `output`: |
| 453 | + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, |
| 454 | + self.lora_b_stacked, |
| 455 | + self.lora_bias_stacked, 1.0, |
| 456 | + self.output_slices) |
| 457 | + # Apply aLoRA mask: keep base output on the saved prefix, LoRA output elsewhere |
| 458 | + final_output = orig_out.mul(mask2d) + output.mul(1.0 - mask2d) |
| 459 | + return final_output |
429 | 460 |
|
430 | 461 | @property |
431 | 462 | def weight(self) -> torch.Tensor: |
|