Skip to content

Commit 282687a

Browse files
committed
Integrate suffix decoding
1 parent 46ef280 commit 282687a

File tree

5 files changed

+55
-2
lines changed

5 files changed

+55
-2
lines changed

vllm_ascend/patch/platform/patch_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def __post_init__(self):
2929
self.quantization = self.target_model_config.quantization
3030
elif self.method in ("ngram", "[ngram]"):
3131
self.model = "ngram"
32+
elif self.method == "suffix":
33+
self.model = "suffix"
3234
else:
3335
raise ValueError("num_speculative_tokens was provided but without "
3436
"speculative model.")
@@ -71,6 +73,9 @@ def __post_init__(self):
7173
# draft related config as None here.
7274
self.draft_model_config = self.target_model_config
7375
self.draft_parallel_config = self.target_parallel_config
76+
elif self.method == "suffix":
77+
self.draft_model_config = self.target_model_config
78+
self.draft_parallel_config = self.target_parallel_config
7479
else:
7580
self.prompt_lookup_max = 0
7681
self.prompt_lookup_min = 0

vllm_ascend/spec_decode/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
2020
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
2121
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
22+
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
2223
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
2324

2425

@@ -35,6 +36,8 @@ def get_spec_decode_method(method,
3536
if is_torchair_graph:
3637
return TorchairMtpProposer(vllm_config, device, runner)
3738
return MtpProposer(vllm_config, device, runner)
39+
elif method == 'suffix':
40+
return SuffixDecodingProposer(vllm_config, device, runner)
3841
else:
3942
raise ValueError("Unknown speculative decoding method: "
4043
f"{method}")

vllm_ascend/spec_decode/interface.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ class SpecDcodeType(enum.Enum):
1313
EAGLE = 1
1414
EAGLE3 = 2
1515
MTP = 4
16+
SUFFIX = 5
1617

1718

1819
class Proposer:
@@ -50,4 +51,4 @@ def generate_token_ids(self,
5051
attn_metadata=None,
5152
aux_hidden_states: torch.Tensor = None):
5253
"""Called by execute_model in model_runner"""
53-
raise NotImplementedError
54+
raise NotImplementedError
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch
2+
from vllm.config import CUDAGraphMode
3+
from vllm.v1.spec_decode.suffix_decoding import \
4+
SuffixDecodingProposer as VllmSuffixDecodingProposer
5+
6+
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
7+
8+
9+
class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
    """Ascend adapter around vLLM's suffix-decoding draft proposer.

    Bridges the upstream ``SuffixDecodingProposer`` into the vllm-ascend
    ``Proposer`` interface used by the model runner.  Suffix decoding is
    model-free, so the load/dummy-run hooks are deliberate no-ops and all
    drafting is delegated to the inherited ``propose`` implementation.
    """

    def __init__(self, vllm_config, device, runner):
        # Let the upstream proposer consume the speculative config first.
        super().__init__(vllm_config)
        self.runner = runner
        self.device = device
        self.name = SpecDcodeType.SUFFIX

    def load_model(self, *args, **kwargs):
        """No-op: suffix decoding has no draft model weights to load."""

    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens,
                  with_prefill=None,
                  skip_attn=None,
                  num_reqs=None,
                  num_tokens_across_dp=None,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None):
        """No-op: there is no draft network to warm up or capture."""

    def generate_token_ids(self,
                           valid_sampled_token_ids,
                           sampling_metadata=None,
                           scheduler_output=None,
                           spec_decode_metadata=None,
                           positions=None,
                           num_scheduled_tokens=None,
                           hidden_states=None,
                           attn_metadata=None,
                           aux_hidden_states=None) -> list[list[int]]:
        """Produce per-request draft token ids via suffix matching.

        Delegates to the inherited ``propose`` using the runner's current
        input batch plus the tokens just sampled; every other argument is
        accepted only to satisfy the ``Proposer`` call signature.
        """
        return self.propose(self.runner.input_batch, valid_sampled_token_ids)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
from vllm.v1.sample.metadata import SamplingMetadata
9696
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
9797
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
98+
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
9899
from vllm.v1.utils import CpuGpuBuffer
99100
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
100101
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -591,7 +592,7 @@ def _set_up_drafter(self):
591592
# Set up speculative decoding.
592593
self.spec_attn_mask = None
593594
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
594-
TorchairMtpProposer]] = None
595+
TorchairMtpProposer, SuffixDecodingProposer]] = None
595596
self.actual_seq_lengths_q: list[int] = []
596597
self.decode_token_per_req = 1
597598
if self.speculative_config:

0 commit comments

Comments
 (0)