85 changes: 85 additions & 0 deletions tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -145,3 +145,88 @@ def test_eagle_correctness(
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))

def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that the outputs of the original LLM and the speculative LLM
match when using suffix speculative decoding.
'''
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
max_model_len=1024,
enforce_eager=False) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")

# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))



def test_suffix_acceptance(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
num_draft = []
num_accept = []
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
disable_log_stats=False,
enforce_eager=False) as runner:
for i in range(10):
runner.model.chat(test_prompts[i], sampling_config)
metrics = runner.model.get_metrics()
for metric in metrics:
print(metric)
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)
# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens

# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens

# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens

# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate

# Heuristic: expect at least 60% acceptance rate at the end.
assert last_accept_rate > 0.60
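For reference, a minimal sketch of enabling suffix decoding through the offline API, mirroring the speculative_config used by these tests. The model name is a placeholder, and passing speculative_config as a dict to LLM is assumed to follow the same convention as VllmRunner above.

# Minimal sketch (placeholder model name); mirrors the test configuration.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder; any supported chat model
    max_model_len=1024,
    speculative_config={
        "method": "suffix",
        "num_speculative_tokens": 8,
    },
)
prompts = [[{"role": "user", "content": "Explain suffix decoding in one sentence."}]]
outputs = llm.chat(prompts, SamplingParams(temperature=0, max_tokens=64))
print(outputs[0].outputs[0].text)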
6 changes: 6 additions & 0 deletions vllm_ascend/patch/platform/patch_config.py
@@ -28,6 +28,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError("num_speculative_tokens was provided but without "
"speculative model.")
@@ -70,6 +72,10 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
self._validate_suffix_decoding()
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
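After this change, "suffix" is accepted as a speculative method and routed through _validate_suffix_decoding. A sketch of the corresponding speculative_config dict, with the suffix-specific knobs taken from test_suffix_acceptance above (values are illustrative):

speculative_config = {
    "method": "suffix",
    "num_speculative_tokens": 10,
    # Suffix-decoding tuning knobs exercised in test_suffix_acceptance:
    "suffix_decoding_max_spec_factor": 2.0,
    "suffix_decoding_max_cached_requests": 1000,
}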
8 changes: 7 additions & 1 deletion vllm_ascend/spec_decode/__init__.py
@@ -21,7 +21,6 @@
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer


def get_spec_decode_method(method,
vllm_config,
device,
@@ -35,6 +34,13 @@ def get_spec_decode_method(method,
if is_torchair_graph:
return TorchairMtpProposer(vllm_config, device, runner)
return MtpProposer(vllm_config, device, runner)
elif method == 'suffix':
from vllm_ascend.utils import vllm_version_is
if not vllm_version_is("0.11.0"):
from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
return SuffixDecodingProposer(vllm_config, device, runner)
else:
raise ValueError("suffix deocding is unsupported on vllm 0.11.0")
else:
raise ValueError("Unknown speculative decoding method: "
f"{method}")
3 changes: 2 additions & 1 deletion vllm_ascend/spec_decode/interface.py
@@ -13,6 +13,7 @@ class SpecDcodeType(enum.Enum):
EAGLE = 1
EAGLE3 = 2
MTP = 4
SUFFIX = 5


class Proposer:
@@ -50,4 +51,4 @@ def generate_token_ids(self,
attn_metadata=None,
aux_hidden_states: torch.Tensor = None):
"""Called by execute_model in model_runner"""
raise NotImplementedError
raise NotImplementedError
43 changes: 43 additions & 0 deletions vllm_ascend/spec_decode/suffix_proposer.py
@@ -0,0 +1,43 @@
import torch
from vllm.config import CUDAGraphMode
from vllm.v1.spec_decode.suffix_decoding import \
    SuffixDecodingProposer as VllmSuffixDecodingProposer

Check failure on line 3 in vllm_ascend/spec_decode/suffix_proposer.py (GitHub Actions / lint / pre-commit): Cannot find implementation or library stub for module named "vllm.v1.spec_decode.suffix_decoding" [import-not-found]

from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType


class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):

def __init__(self, vllm_config, device, runner):
super().__init__(vllm_config)
self.name = SpecDcodeType.SUFFIX
self.device = device
self.runner = runner

def load_model(self, *args, **kwargs):
# No model to load.
pass

@torch.inference_mode()
def dummy_run(self,
num_tokens,
with_prefill=None,
skip_attn=None,
num_reqs=None,
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
pass

def generate_token_ids(self,
valid_sampled_token_ids,
sampling_metadata=None,
scheduler_output=None,
spec_decode_metadata=None,
positions=None,
num_scheduled_tokens=None,
hidden_states=None,
attn_metadata=None,
aux_hidden_states=None) -> list[list[int]]:
draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
return draft_token_ids
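For context, a rough wiring sketch using only names from this diff; vllm_config, device, runner, and valid_sampled_token_ids are assumed to be supplied by the model runner (see model_runner_v1.py below), so this is illustrative rather than an actual call site.

# Hypothetical wiring sketch, not the real runner code.
drafter = SuffixDecodingProposer(vllm_config, device, runner)
drafter.load_model()  # no-op: suffix decoding needs no draft model weights

# Once per decode step, after sampling:
draft_token_ids = drafter.generate_token_ids(valid_sampled_token_ids)
# Returns a list[list[int]] of proposed draft tokens per request, built by
# self.propose() from runner.input_batch and the sampled token ids.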
5 changes: 4 additions & 1 deletion vllm_ascend/worker/model_runner_v1.py
@@ -177,6 +177,9 @@
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

if not vllm_version_is("0.11.0"):
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer

Check failure on line 181 in vllm_ascend/worker/model_runner_v1.py (GitHub Actions / lint / pre-commit): Cannot find implementation or library stub for module named "vllm.v1.spec_decode.suffix_decoding" [import-not-found]

import torch_npu

# if true, allow tensor initialization and casting with internal format (e.g., NZ)
@@ -597,7 +600,7 @@
# Set up speculative decoding.
self.spec_attn_mask = None
self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
TorchairMtpProposer]] = None
TorchairMtpProposer, SuffixDecodingProposer]] = None
self.actual_seq_lengths_q: list[int] = []
self.decode_token_per_req = 1
if self.speculative_config: