diff --git a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
index b35de243977..db915ab1f7a 100644
--- a/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
+++ b/tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
@@ -145,3 +145,88 @@ def test_eagle_correctness(
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
     assert matches > int(0.66 * len(ref_outputs))
+
+def test_suffix_correctness(
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Compare the outputs of an original LLM and a speculative LLM;
+    they should be the same when using suffix speculative decoding.
+    '''
+    ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=False)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "num_speculative_tokens": 8,
+                    },
+                    max_model_len=1024,
+                    enforce_eager=False) as runner:
+        spec_outputs = runner.model.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+
+
+
+def test_suffix_acceptance(
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    model_name: str,
+):
+    '''
+    Check that suffix decoding caching takes effect and improves acceptance
+    lengths and acceptance rates over multiple runs of the same prompts.
+    '''
+    num_draft = []
+    num_accept = []
+    with VllmRunner(model_name,
+                    speculative_config={
+                        "method": "suffix",
+                        "suffix_decoding_max_spec_factor": 2.0,
+                        "suffix_decoding_max_cached_requests": 1000,
+                        "num_speculative_tokens": 10,
+                    },
+                    max_model_len=1024,
+                    disable_log_stats=False,
+                    enforce_eager=False) as runner:
+        for i in range(10):
+            runner.model.chat(test_prompts[i], sampling_config)
+            metrics = runner.model.get_metrics()
+            for metric in metrics:
+                print(metric)
+                if metric.name == "vllm:spec_decode_num_draft_tokens":
+                    num_draft.append(metric.value)
+                if metric.name == "vllm:spec_decode_num_accepted_tokens":
+                    num_accept.append(metric.value)
+    # Calculate the acceptance rates for the first and last runs.
+    first_accept_tokens = num_accept[0]
+    first_draft_tokens = num_draft[0]
+    first_accept_rate = first_accept_tokens / first_draft_tokens
+
+    # Take the diff since the stats are cumulative.
+    last_accept_tokens = num_accept[-1] - num_accept[-2]
+    last_draft_tokens = num_draft[-1] - num_draft[-2]
+    last_accept_rate = last_accept_tokens / last_draft_tokens
+
+    # Expect the acceptance length to improve.
+    assert first_accept_tokens < last_accept_tokens
+
+    # Expect the acceptance rate to improve.
+    assert first_accept_rate < last_accept_rate
+
+    # Heuristic: expect at least 60% acceptance rate at the end.
+    assert last_accept_rate > 0.60
diff --git a/vllm_ascend/patch/platform/patch_config.py b/vllm_ascend/patch/platform/patch_config.py
index 0e8642d1cea..b798fda3bc7 100644
--- a/vllm_ascend/patch/platform/patch_config.py
+++ b/vllm_ascend/patch/platform/patch_config.py
@@ -28,6 +28,8 @@ def __post_init__(self):
                 self.quantization = self.target_model_config.quantization
         elif self.method in ("ngram", "[ngram]"):
             self.model = "ngram"
+        elif self.method == "suffix":
+            self.model = "suffix"
         else:
             raise ValueError("num_speculative_tokens was provided but without "
                              "speculative model.")
@@ -70,6 +72,10 @@ def __post_init__(self):
             # draft related config as None here.
             self.draft_model_config = self.target_model_config
             self.draft_parallel_config = self.target_parallel_config
+        elif self.method == "suffix":
+            self.draft_model_config = self.target_model_config
+            self.draft_parallel_config = self.target_parallel_config
+            self._validate_suffix_decoding()
         else:
             self.prompt_lookup_max = 0
             self.prompt_lookup_min = 0
diff --git a/vllm_ascend/spec_decode/__init__.py b/vllm_ascend/spec_decode/__init__.py
index 6abe8777cd3..03c577a944f 100644
--- a/vllm_ascend/spec_decode/__init__.py
+++ b/vllm_ascend/spec_decode/__init__.py
@@ -21,7 +21,6 @@
 from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
 from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
 
-
 def get_spec_decode_method(method,
                            vllm_config,
                            device,
@@ -35,6 +34,13 @@ def get_spec_decode_method(method,
         if is_torchair_graph:
             return TorchairMtpProposer(vllm_config, device, runner)
         return MtpProposer(vllm_config, device, runner)
+    elif method == 'suffix':
+        from vllm_ascend.utils import vllm_version_is
+        if not vllm_version_is("0.11.0"):
+            from vllm_ascend.spec_decode.suffix_proposer import SuffixDecodingProposer
+            return SuffixDecodingProposer(vllm_config, device, runner)
+        else:
+            raise ValueError("suffix decoding is unsupported on vllm 0.11.0")
     else:
         raise ValueError("Unknown speculative decoding method: "
                          f"{method}")
diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py
index 3f0a36b13cd..ae2d92294c8 100644
--- a/vllm_ascend/spec_decode/interface.py
+++ b/vllm_ascend/spec_decode/interface.py
@@ -13,6 +13,7 @@ class SpecDcodeType(enum.Enum):
     EAGLE = 1
     EAGLE3 = 2
     MTP = 4
+    SUFFIX = 5
 
 
 class Proposer:
@@ -50,4 +51,4 @@ def generate_token_ids(self,
                            attn_metadata=None,
                            aux_hidden_states: torch.Tensor = None):
         """Called by execute_model in model_runner"""
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/vllm_ascend/spec_decode/suffix_proposer.py b/vllm_ascend/spec_decode/suffix_proposer.py
new file mode 100644
index 00000000000..de87ebab2a2
--- /dev/null
+++ b/vllm_ascend/spec_decode/suffix_proposer.py
@@ -0,0 +1,43 @@
+import torch
+from vllm.config import CUDAGraphMode
+from vllm.v1.spec_decode.suffix_decoding import \
+    SuffixDecodingProposer as VllmSuffixDecodingProposer
+
+from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
+
+
+class SuffixDecodingProposer(VllmSuffixDecodingProposer, Proposer):
+
+    def __init__(self, vllm_config, device, runner):
+        super().__init__(vllm_config)
+        self.name = SpecDcodeType.SUFFIX
+        self.device = device
+        self.runner = runner
+
+    def load_model(self, *args, **kwargs):
+        # No model to load.
+        pass
+
+    @torch.inference_mode()
+    def dummy_run(self,
+                  num_tokens,
+                  with_prefill=None,
+                  skip_attn=None,
+                  num_reqs=None,
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
+        pass
+
+    def generate_token_ids(self,
+                           valid_sampled_token_ids,
+                           sampling_metadata=None,
+                           scheduler_output=None,
+                           spec_decode_metadata=None,
+                           positions=None,
+                           num_scheduled_tokens=None,
+                           hidden_states=None,
+                           attn_metadata=None,
+                           aux_hidden_states=None) -> list[list[int]]:
+        draft_token_ids = self.propose(self.runner.input_batch, valid_sampled_token_ids)
+        return draft_token_ids
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 167345bf1d8..66f8d785ef2 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -177,6 +177,9 @@
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
+if not vllm_version_is("0.11.0"):
+    from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
+
 import torch_npu
 
 # if true, allow tensor initialization and casting with internal format (e.g., NZ)
@@ -597,7 +600,7 @@ def _set_up_drafter(self):
         # Set up speculative decoding.
         self.spec_attn_mask = None
         self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
-                                     TorchairMtpProposer]] = None
+                                     TorchairMtpProposer, SuffixDecodingProposer]] = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
         if self.speculative_config:
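
Usage sketch (illustrative, not part of the patch): with these changes applied, suffix decoding is selected entirely through speculative_config, mirroring the keys exercised by the new e2e tests above. The model name below is a placeholder; the configuration values are taken from the tests.

    # Minimal sketch, assuming a vLLM / vllm-ascend build newer than 0.11.0
    # that accepts method="suffix"; the model name is a placeholder.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder target model
        max_model_len=1024,
        speculative_config={
            "method": "suffix",
            "num_speculative_tokens": 8,
            "suffix_decoding_max_spec_factor": 2.0,
            "suffix_decoding_max_cached_requests": 1000,
        },
    )
    # Repeated or templated prompts benefit most, since suffix decoding drafts
    # tokens from previously seen sequences cached across requests.
    outputs = llm.generate(["Summarize the benefits of speculative decoding."],
                           SamplingParams(temperature=0, max_tokens=64))
    print(outputs[0].outputs[0].text)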