From 2d4180ce1b7ad84db4712f071ceda626fb8359e7 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 24 Sep 2025 00:45:28 +0000 Subject: [PATCH 01/26] Integrate Suffix Decoding --- vllm/config/speculative.py | 43 +++++++++++++- vllm/v1/worker/gpu_model_runner.py | 93 ++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index d533930e1c7a..fa69f7eb740d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -31,7 +31,7 @@ SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", "mlp_speculator", "draft_model", "deepseek_mtp", - "ernie_mtp", "qwen3_next_mtp", "mimo_mtp"] + "ernie_mtp", "qwen3_next_mtp", "mimo_mtp", "suffix"] @config @@ -120,6 +120,22 @@ class SpeculativeConfig: ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # Suffix decoding configuration + suffix_decoding_max_tree_depth: int = 64 + """The maximum depth of the suffix decoding tree.""" + + suffix_decoding_max_cached_requests: int = 1000 + """The maximum number of requests to cache in the global suffix tree.""" + + suffix_decoding_max_spec_factor: float = 1.0 + """The maximum speculative factor for suffix decoding.""" + + suffix_decoding_max_spec_offset: float = 0.0 + """The maximum speculative offset for suffix decoding.""" + + suffix_decoding_min_token_prob: float = 0.1 + """The minimum token probability for suffix decoding.""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -215,6 +231,8 @@ def __post_init__(self): self.quantization = self.target_model_config.quantization elif self.method in ("ngram", "[ngram]"): self.model = "ngram" + elif self.method == "suffix": + self.model = "suffix" else: raise ValueError("num_speculative_tokens was provided without " "speculative model.") @@ -258,6 +276,27 @@ def __post_init__(self): # draft related config as None here. 
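For reference, the new fields above are consumed through the regular `speculative_config` dictionary. A minimal sketch of enabling the method added here (the model name is illustrative and not part of this patch; the numeric values mirror the defaults defined above):

```python
from vllm import LLM, SamplingParams

# Illustrative only: any model served by vLLM works here.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "suffix",
        "num_speculative_tokens": 16,
        "suffix_decoding_max_tree_depth": 64,
        "suffix_decoding_max_spec_factor": 1.0,
        "suffix_decoding_min_token_prob": 0.1,
    },
)
outputs = llm.generate(["The future of AI is"], SamplingParams(temperature=0.8))
print(outputs[0].outputs[0].text)
```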
self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config + elif self.method == "suffix": + if self.num_speculative_tokens is None: + self.num_speculative_tokens = 16 + # Validate values + if self.suffix_decoding_max_tree_depth < 4: + raise ValueError( + f"suffix_decoding_max_tree_depth=" + f"{self.suffix_decoding_max_tree_depth} must be >= 4") + if self.suffix_decoding_max_cached_requests < 0: + raise ValueError( + f"suffix_decoding_max_cached_requests=" + f"{self.suffix_decoding_max_cached_requests} must be >= 0") + if self.suffix_decoding_max_spec_factor < 0: + raise ValueError( + f"suffix_decoding_max_spec_factor=" + f"{self.suffix_decoding_max_spec_factor} must be >= 0") + if (self.suffix_decoding_min_token_prob < 0 or + self.suffix_decoding_min_token_prob > 1): + raise ValueError( + f"suffix_decoding_min_token_prob=" + f"{self.suffix_decoding_min_token_prob} must be in [0, 1]") else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -554,6 +593,6 @@ def use_eagle(self) -> bool: def __repr__(self) -> str: method = self.method - model = None if method == "ngram" else self.draft_model_config.model + model = None if method in ["ngram", "suffix"] else self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebdbcc621c6..340100816c95 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -268,9 +268,21 @@ def __init__( # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many # layers in the draft model. + self.suffix_cache = None if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) + elif self.speculative_config.method == "suffix": + try: + from arctic_inference.suffix_decoding import ( + SuffixDecodingCache) + except ImportError: + raise ImportError( + "Arctic Inference is required for suffix decoding. " + "Please install via `pip install arctic-inference`.") + self.suffix_cache = SuffixDecodingCache( + self.speculative_config.suffix_decoding_max_tree_depth, + self.speculative_config.suffix_decoding_max_cached_requests) elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore @@ -2121,6 +2133,9 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) + if self.suffix_cache is not None: + self._update_suffix_cache(valid_sampled_token_ids) + return ( num_nans_in_logits, logprobs_lists, @@ -2146,6 +2161,32 @@ def synchronize_input_prep(self): finally: self.prepare_inputs_event.record() + def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None: + seen_req_ids = set() + for i, sampled_ids in enumerate(sampled_token_ids): + req_id = self.input_batch.req_ids[i] + seen_req_ids.add(req_id) + + if not sampled_ids: + continue + + index = self.input_batch.req_id_to_index[req_id] + if req_id not in self.suffix_cache.active_requests: + if req_id in self.suffix_cache.cached_requests: + # Reset the suffix cache for this request. 
+ self.suffix_cache.evict_cached_response(req_id) + num_prompt_tokens = self.input_batch.num_prompt_tokens[index] + prompt_token_ids = ( + self.input_batch.token_ids_cpu[index, :num_prompt_tokens]) + self.suffix_cache.start_request(req_id, prompt_token_ids) + + self.suffix_cache.add_active_response(req_id, sampled_ids) + + # Stop requests that are not seen + for req_id in list(self.suffix_cache.active_requests): + if req_id not in seen_req_ids: + self.suffix_cache.stop_request(req_id) + @torch.inference_mode() def execute_model( self, @@ -2377,6 +2418,9 @@ def propose_draft_token_ids( assert isinstance(self.drafter, NgramProposer) draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) + elif self.speculative_config.method == "suffix": + results = self.propose_suffix_draft_token_ids(sampled_token_ids) + draft_token_ids = [result.token_ids for result in results] elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2521,6 +2565,55 @@ def propose_ngram_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids + def propose_suffix_draft_token_ids( + self, + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: + from arctic_inference.suffix_decoding import SuffixDecodingDraft + config = self.speculative_config + results = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + results.append(SuffixDecodingDraft()) + continue + + req_id = self.input_batch.req_ids[i] + + # Add sampled_token_ids to token_ids_cpu. + start_idx = self.input_batch.num_tokens_no_spec[i] + end_idx = start_idx + len(sampled_ids) + + if end_idx >= self.max_model_len: + results.append(SuffixDecodingDraft()) + self.input_batch.token_ids_cpu[ + i, start_idx:self. 
+ max_model_len] = sampled_ids[:self.max_model_len - + start_idx] + continue + + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + + size = min(end_idx, config.suffix_decoding_max_tree_depth) + pattern = self.input_batch.token_ids_cpu[i, end_idx - size:end_idx] + pattern = pattern.tolist() + if len(pattern) > config.suffix_decoding_max_tree_depth: + pattern = pattern[-config.suffix_decoding_max_tree_depth:] + + result = self.suffix_cache.speculate( + req_id, + pattern, + max_spec_tokens=min(config.num_speculative_tokens, + self.max_model_len - end_idx - 1), + max_spec_factor=config.suffix_decoding_max_spec_factor, + max_spec_offset=config.suffix_decoding_max_spec_offset, + min_token_prob=config.suffix_decoding_min_token_prob) + + results.append(result) + + return results + def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): From db621b461ccb20bb98f6e30f65bb6385ce03a843 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Fri, 26 Sep 2025 01:33:53 +0000 Subject: [PATCH 02/26] update --- vllm/config/speculative.py | 7 ++--- vllm/v1/worker/gpu_model_runner.py | 50 +++++++++++++----------------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index fa69f7eb740d..548b09ed52f6 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -124,15 +124,12 @@ class SpeculativeConfig: suffix_decoding_max_tree_depth: int = 64 """The maximum depth of the suffix decoding tree.""" - suffix_decoding_max_cached_requests: int = 1000 + suffix_decoding_max_cached_requests: int = 10000 """The maximum number of requests to cache in the global suffix tree.""" suffix_decoding_max_spec_factor: float = 1.0 """The maximum speculative factor for suffix decoding.""" - suffix_decoding_max_spec_offset: float = 0.0 - """The maximum speculative offset for suffix decoding.""" - suffix_decoding_min_token_prob: float = 0.1 """The minimum token probability for suffix decoding.""" @@ -278,7 +275,7 @@ def __post_init__(self): self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": if self.num_speculative_tokens is None: - self.num_speculative_tokens = 16 + self.num_speculative_tokens = 32 # Validate values if self.suffix_decoding_max_tree_depth < 4: raise ValueError( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 340100816c95..42b063031b0f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1191,7 +1191,7 @@ def _prepare_inputs( if (self.speculative_config and spec_decode_common_attn_metadata is None): - if isinstance(self.drafter, EagleProposer): + if isinstance(getattr(self, "drafter", None), EagleProposer): if (self.drafter.attn_layer_names[0] in kv_cache_group_spec.layer_names): spec_decode_common_attn_metadata = common_attn_metadata @@ -2419,8 +2419,8 @@ def propose_draft_token_ids( draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "suffix": - results = self.propose_suffix_draft_token_ids(sampled_token_ids) - draft_token_ids = [result.token_ids for result in results] + draft_token_ids = self.propose_suffix_draft_token_ids( + sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2570,49 +2570,43 @@ def 
propose_suffix_draft_token_ids( sampled_token_ids: list[list[int]], ) -> list[list[int]]: from arctic_inference.suffix_decoding import SuffixDecodingDraft + req_ids = self.input_batch.req_ids config = self.speculative_config - results = [] + draft_token_ids = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: # Skip speculative decoding. - results.append(SuffixDecodingDraft()) + draft_token_ids.append([]) continue - req_id = self.input_batch.req_ids[i] - - # Add sampled_token_ids to token_ids_cpu. - start_idx = self.input_batch.num_tokens_no_spec[i] - end_idx = start_idx + len(sampled_ids) - - if end_idx >= self.max_model_len: - results.append(SuffixDecodingDraft()) - self.input_batch.token_ids_cpu[ - i, start_idx:self. - max_model_len] = sampled_ids[:self.max_model_len - - start_idx] + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in self.input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) continue - self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids + num_tokens = self.input_batch.num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue - size = min(end_idx, config.suffix_decoding_max_tree_depth) - pattern = self.input_batch.token_ids_cpu[i, end_idx - size:end_idx] + start = max(0, num_tokens - config.suffix_decoding_max_tree_depth) + pattern = self.input_batch.token_ids_cpu[i, start:num_tokens] pattern = pattern.tolist() - if len(pattern) > config.suffix_decoding_max_tree_depth: - pattern = pattern[-config.suffix_decoding_max_tree_depth:] - - result = self.suffix_cache.speculate( + draft = self.suffix_cache.speculate( req_id, pattern, max_spec_tokens=min(config.num_speculative_tokens, - self.max_model_len - end_idx - 1), + self.max_model_len - num_tokens - 1), max_spec_factor=config.suffix_decoding_max_spec_factor, - max_spec_offset=config.suffix_decoding_max_spec_offset, min_token_prob=config.suffix_decoding_min_token_prob) - results.append(result) + draft_token_ids.append(draft.token_ids) - return results + return draft_token_ids def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} From a15727d7e7f99b396b8f797da5eeb188bc75c8e8 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 29 Sep 2025 23:45:21 +0000 Subject: [PATCH 03/26] add tests --- requirements/test.in | 1 + requirements/test.txt | 2 + tests/v1/e2e/test_spec_decode.py | 84 +++++++++++++++++++++++++++++--- 3 files changed, 80 insertions(+), 7 deletions(-) diff --git a/requirements/test.in b/requirements/test.in index 451bd7387910..06799178afcf 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 +arctic-inference == 0.0.9 # Required for suffix decoding test numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. 
Required for N-gram speculative decoding numba == 0.61.2; python_version > '3.9' numpy diff --git a/requirements/test.txt b/requirements/test.txt index 3519aa524f41..d48b7a202b45 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -39,6 +39,8 @@ anyio==4.6.2.post1 # via # httpx # starlette +arctic-inference==0.0.9 + # via -r requirements/test.in argcomplete==3.5.1 # via datamodel-code-generator arrow==1.3.0 diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index bf90f50b1082..e39e5797ab19 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -77,7 +77,20 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -def test_ngram_correctness( +@pytest.mark.parametrize("speculative_config", [ + { + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + { + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + } +]) +def test_ngram_and_suffix_correctness( + speculative_config: dict, monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, model_name: str, @@ -98,12 +111,7 @@ def test_ngram_correctness( spec_llm = LLM( model=model_name, - speculative_config={ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, + speculative_config=speculative_config, max_model_len=1024, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) @@ -125,6 +133,68 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() +def test_suffix_decoding_acceptance( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + model_name: str, +): + ''' + Check that suffix decoding caching takes effect and improves acceptance + lengths and acceptance rates over multiple runs of the same prompts. + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + test_prompts = get_test_prompts(mm_enabled=False) + + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + }, + max_model_len=1024, + disable_log_stats=False, + ) + + # Run several times and check that the accepted tokens increase. + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + num_draft = [] + num_accept = [] + for i in range(10): # Run multiple times to warm up the cache. + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + # Collect draft and acceptance stats. + metrics = spec_llm.get_metrics() + for metric in metrics: + if metric.name == "vllm:spec_decode_num_draft_tokens": + num_draft.append(metric.value) + if metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accept.append(metric.value) + + # Calculate the acceptance rates for the first and last runs. + first_accept_tokens = num_accept[0] + first_draft_tokens = num_draft[0] + first_accept_rate = first_accept_tokens / first_draft_tokens + + # Take the diff since the stats are cumulative. + last_accept_tokens = num_accept[-1] - num_accept[-2] + last_draft_tokens = num_draft[-1] - num_draft[-2] + last_accept_rate = last_accept_tokens / last_draft_tokens + + # Expect the acceptance length to improve. + assert first_accept_tokens < last_accept_tokens + + # Expect the acceptance rate to improve. + assert first_accept_rate < last_accept_rate + + # Heuristic: expect at least 85% acceptance rate at the end. 
+ assert last_accept_rate > 0.85 + + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + @pytest.mark.parametrize(["model_setup", "mm_enabled"], [ (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), (("eagle", "meta-llama/Llama-3.1-8B-Instruct", From 50d6fd3c74ff15829bd9f1b5519a9395f35f382f Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 29 Sep 2025 23:49:04 +0000 Subject: [PATCH 04/26] fix --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 42b063031b0f..367238afcc94 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2178,6 +2178,7 @@ def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None: num_prompt_tokens = self.input_batch.num_prompt_tokens[index] prompt_token_ids = ( self.input_batch.token_ids_cpu[index, :num_prompt_tokens]) + prompt_token_ids = prompt_token_ids.tolist() self.suffix_cache.start_request(req_id, prompt_token_ids) self.suffix_cache.add_active_response(req_id, sampled_ids) From 92fe31ea6b77135b7ab5cb6eef5be1e7fcb83359 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 02:01:39 +0000 Subject: [PATCH 05/26] docs --- docs/features/spec_decode.md | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 597a8e864427..2dcf0403e6c1 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -131,6 +131,43 @@ matching n-grams in the prompt. For more information read [this thread.](https:/ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` +## Speculating using Suffix Decoding + +The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)). + +Like n-gram, Suffix Decoding can generate draft tokens from pattern-matching the prompt. Unlike n-gram, Suffix Decoding (1) can also pattern-match against previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates. + +Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts. + +!!! tip "Install Arctic Inference" + Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`. + +??? 
code + + ```python + from vllm import LLM, SamplingParams + + prompts = [ + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model="facebook/opt-6.7b", + tensor_parallel_size=1, + speculative_config={ + "method": "suffix", + "num_speculative_tokens": 16, + }, + ) + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + ``` + ## Speculating using MLP speculators The following code configures vLLM to use speculative decoding where proposals are generated by From 6755a31d29f70ef02565340d4a097362d271bef8 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 02:51:29 +0000 Subject: [PATCH 06/26] move import check --- vllm/config/speculative.py | 7 +++++++ vllm/v1/worker/gpu_model_runner.py | 9 ++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 548b09ed52f6..b361c05c3ab5 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -274,6 +274,13 @@ def __post_init__(self): self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": + try: + from arctic_inference.suffix_decoding import ( + SuffixDecodingCache) + except ImportError: + raise ImportError( + "Arctic Inference is required for suffix decoding. " + "Please install via `pip install arctic-inference`.") if self.num_speculative_tokens is None: self.num_speculative_tokens = 32 # Validate values diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 367238afcc94..b49506422f08 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -273,13 +273,8 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "suffix": - try: - from arctic_inference.suffix_decoding import ( - SuffixDecodingCache) - except ImportError: - raise ImportError( - "Arctic Inference is required for suffix decoding. 
" - "Please install via `pip install arctic-inference`.") + from arctic_inference.suffix_decoding import ( + SuffixDecodingCache) self.suffix_cache = SuffixDecodingCache( self.speculative_config.suffix_decoding_max_tree_depth, self.speculative_config.suffix_decoding_max_cached_requests) From 72c85e632ade023791104e7b7d01bf389e223f47 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 03:08:39 +0000 Subject: [PATCH 07/26] fix --- tests/v1/e2e/test_spec_decode.py | 124 ++++++++++++++--------------- vllm/v1/worker/gpu_model_runner.py | 43 ---------- 2 files changed, 61 insertions(+), 106 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 8e05666be098..af3f5d8a4691 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -109,21 +109,21 @@ def test_ngram_and_suffix_correctness( torch.cuda.empty_cache() cleanup_dist_env_and_memory() - spec_llm = LLM( - model=model_name, - speculative_config=speculative_config, - max_model_len=1024, - ) - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - matches = 0 - misses = 0 - for ref_output, spec_output in zip(ref_outputs, spec_outputs): - if ref_output.outputs[0].text == spec_output.outputs[0].text: - matches += 1 - else: - misses += 1 - print(f"ref_output: {ref_output.outputs[0].text}") - print(f"spec_output: {spec_output.outputs[0].text}") + spec_llm = LLM( + model=model_name, + speculative_config=speculative_config, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") # Heuristic: expect at least 66% of the prompts to match exactly # Upon failure, inspect the outputs to check for inaccuracy. @@ -142,57 +142,55 @@ def test_suffix_decoding_acceptance( Check that suffix decoding caching takes effect and improves acceptance lengths and acceptance rates over multiple runs of the same prompts. ''' - with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "1") - test_prompts = get_test_prompts(mm_enabled=False) - - spec_llm = LLM( - model=model_name, - speculative_config={ - "method": "suffix", - "suffix_decoding_max_spec_factor": 2.0, - "suffix_decoding_max_cached_requests": 1000, - }, - max_model_len=1024, - disable_log_stats=False, - ) + test_prompts = get_test_prompts(mm_enabled=False) - # Run several times and check that the accepted tokens increase. + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + "suffix_decoding_max_cached_requests": 1000, + }, + max_model_len=1024, + disable_log_stats=False, + ) + + # Run several times and check that the accepted tokens increase. + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + num_draft = [] + num_accept = [] + for i in range(10): # Run multiple times to warm up the cache. spec_outputs = spec_llm.chat(test_prompts, sampling_config) - num_draft = [] - num_accept = [] - for i in range(10): # Run multiple times to warm up the cache. - spec_outputs = spec_llm.chat(test_prompts, sampling_config) - # Collect draft and acceptance stats. 
- metrics = spec_llm.get_metrics() - for metric in metrics: - if metric.name == "vllm:spec_decode_num_draft_tokens": - num_draft.append(metric.value) - if metric.name == "vllm:spec_decode_num_accepted_tokens": - num_accept.append(metric.value) - - # Calculate the acceptance rates for the first and last runs. - first_accept_tokens = num_accept[0] - first_draft_tokens = num_draft[0] - first_accept_rate = first_accept_tokens / first_draft_tokens - - # Take the diff since the stats are cumulative. - last_accept_tokens = num_accept[-1] - num_accept[-2] - last_draft_tokens = num_draft[-1] - num_draft[-2] - last_accept_rate = last_accept_tokens / last_draft_tokens - - # Expect the acceptance length to improve. - assert first_accept_tokens < last_accept_tokens - - # Expect the acceptance rate to improve. - assert first_accept_rate < last_accept_rate - - # Heuristic: expect at least 85% acceptance rate at the end. - assert last_accept_rate > 0.85 + # Collect draft and acceptance stats. + metrics = spec_llm.get_metrics() + for metric in metrics: + if metric.name == "vllm:spec_decode_num_draft_tokens": + num_draft.append(metric.value) + if metric.name == "vllm:spec_decode_num_accepted_tokens": + num_accept.append(metric.value) - del spec_llm - torch.cuda.empty_cache() - cleanup_dist_env_and_memory() + # Calculate the acceptance rates for the first and last runs. + first_accept_tokens = num_accept[0] + first_draft_tokens = num_draft[0] + first_accept_rate = first_accept_tokens / first_draft_tokens + + # Take the diff since the stats are cumulative. + last_accept_tokens = num_accept[-1] - num_accept[-2] + last_draft_tokens = num_draft[-1] - num_draft[-2] + last_accept_rate = last_accept_tokens / last_draft_tokens + + # Expect the acceptance length to improve. + assert first_accept_tokens < last_accept_tokens + + # Expect the acceptance rate to improve. + assert first_accept_rate < last_accept_rate + + # Heuristic: expect at least 85% acceptance rate at the end. + assert last_accept_rate > 0.85 + + del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() @pytest.mark.parametrize(["model_setup", "mm_enabled"], [ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index dfff3f4bb1ed..289032eddd44 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2753,49 +2753,6 @@ def propose_suffix_draft_token_ids( return draft_token_ids - def propose_suffix_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - from arctic_inference.suffix_decoding import SuffixDecodingDraft - req_ids = self.input_batch.req_ids - config = self.speculative_config - draft_token_ids = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. 
- draft_token_ids.append([]) - continue - - start = max(0, num_tokens - config.suffix_decoding_max_tree_depth) - pattern = self.input_batch.token_ids_cpu[i, start:num_tokens] - pattern = pattern.tolist() - draft = self.suffix_cache.speculate( - req_id, - pattern, - max_spec_tokens=min(config.num_speculative_tokens, - self.max_model_len - num_tokens - 1), - max_spec_factor=config.suffix_decoding_max_spec_factor, - min_token_prob=config.suffix_decoding_min_token_prob) - - draft_token_ids.append(draft.token_ids) - - return draft_token_ids - def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): From da82b859d4e16a27de97894ee50656c474f61155 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 04:14:22 +0000 Subject: [PATCH 08/26] SuffixDecodingProposer --- vllm/v1/spec_decode/suffix_decoding.py | 99 ++++++++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 88 +++-------------------- 2 files changed, 107 insertions(+), 80 deletions(-) create mode 100644 vllm/v1/spec_decode/suffix_decoding.py diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py new file mode 100644 index 000000000000..c9090ff6f7e0 --- /dev/null +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.config import VllmConfig +from vllm.v1.worker.gpu_input_batch import InputBatch + + +class SuffixDecodingProposer: + + def __init__(self, vllm_config: VllmConfig): + config = vllm_config.speculative_config + self.num_speculative_tokens = config.num_speculative_tokens + self.max_tree_depth = config.suffix_decoding_max_tree_depth + self.max_spec_factor = config.suffix_decoding_max_spec_factor + self.min_token_prob = config.suffix_decoding_min_token_prob + self.max_model_len = vllm_config.model_config.max_model_len + + # Lazy import to avoid error when Suffix Decoding is not used. + from arctic_inference.suffix_decoding import SuffixDecodingCache + + self.suffix_cache = SuffixDecodingCache( + max_tree_depth=config.suffix_decoding_max_tree_depth, + max_cached_requests=config.suffix_decoding_max_cached_requests) + + def update( + self, + input_batch: InputBatch, + sampled_token_ids: list[list[int]], + ): + seen_req_ids = set() + for i, sampled_ids in enumerate(sampled_token_ids): + req_id = input_batch.req_ids[i] + seen_req_ids.add(req_id) + + if not sampled_ids: + continue + + index = input_batch.req_id_to_index[req_id] + if req_id not in self.suffix_cache.active_requests: + if req_id in self.suffix_cache.cached_requests: + # Reset the suffix cache for this request. 
+ self.suffix_cache.evict_cached_response(req_id) + num_prompt_tokens = input_batch.num_prompt_tokens[index] + prompt_token_ids = ( + input_batch.token_ids_cpu[index, :num_prompt_tokens]) + prompt_token_ids = prompt_token_ids.tolist() + self.suffix_cache.start_request(req_id, prompt_token_ids) + + self.suffix_cache.add_active_response(req_id, sampled_ids) + + # Stop requests that are not seen + for req_id in list(self.suffix_cache.active_requests): + if req_id not in seen_req_ids: + self.suffix_cache.stop_request(req_id) + + def propose( + self, + input_batch: InputBatch, + sampled_token_ids: list[list[int]], + ) -> list[list[int]]: + req_ids = input_batch.req_ids + draft_token_ids = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + draft_token_ids.append([]) + continue + + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in input_batch.spec_decode_unsupported_reqs: + draft_token_ids.append([]) + continue + + num_tokens = input_batch.num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + + start = max(0, num_tokens - self.max_tree_depth) + pattern = input_batch.token_ids_cpu[i, start:num_tokens] + pattern = pattern.tolist() + draft = self.suffix_cache.speculate( + req_id, + pattern, + max_spec_tokens=min(self.num_speculative_tokens, + self.max_model_len - num_tokens - 1), + max_spec_factor=self.max_spec_factor, + min_token_prob=self.min_token_prob) + + draft_token_ids.append(draft.token_ids) + + return draft_token_ids + + def load_model(self, *args, **kwargs): + # No model to load. + pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 289032eddd44..8ebe17f96617 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -95,6 +95,7 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch @@ -278,16 +279,11 @@ def __init__( # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many # layers in the draft model. 
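The hunks that follow wire this proposer into the GPU model runner, which drives it once per decoding step: newly accepted tokens are fed back into the cache before new drafts are requested. Roughly, the call order looks like the sketch below (`run_spec_decode_step` and `runner` are stand-in names for illustration, not identifiers from this patch):

```python
# Approximate per-step wiring, mirroring how gpu_model_runner.py uses the
# proposer in this series; not an exact excerpt.
def run_spec_decode_step(runner, valid_sampled_token_ids):
    drafter = runner.drafter  # a SuffixDecodingProposer
    # 1) Record newly accepted tokens; starts/stops per-request suffix trees.
    drafter.update(runner.input_batch, valid_sampled_token_ids)
    # 2) Ask for a variable-length draft per request for the next step.
    return drafter.propose(runner.input_batch, valid_sampled_token_ids)
```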
- self.suffix_cache = None if self.speculative_config and get_pp_group().is_last_rank: if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "suffix": - from arctic_inference.suffix_decoding import ( - SuffixDecodingCache) - self.suffix_cache = SuffixDecodingCache( - self.speculative_config.suffix_decoding_max_tree_depth, - self.speculative_config.suffix_decoding_max_cached_requests) + self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore @@ -2250,8 +2246,8 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if self.suffix_cache is not None: - self._update_suffix_cache(valid_sampled_token_ids) + if isinstance(getattr(self, "drafter", None), SuffixDecodingProposer): + self.drafter.update(self.input_batch, valid_sampled_token_ids) return ( num_nans_in_logits, @@ -2278,33 +2274,6 @@ def synchronize_input_prep(self): finally: self.prepare_inputs_event.record() - def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None: - seen_req_ids = set() - for i, sampled_ids in enumerate(sampled_token_ids): - req_id = self.input_batch.req_ids[i] - seen_req_ids.add(req_id) - - if not sampled_ids: - continue - - index = self.input_batch.req_id_to_index[req_id] - if req_id not in self.suffix_cache.active_requests: - if req_id in self.suffix_cache.cached_requests: - # Reset the suffix cache for this request. - self.suffix_cache.evict_cached_response(req_id) - num_prompt_tokens = self.input_batch.num_prompt_tokens[index] - prompt_token_ids = ( - self.input_batch.token_ids_cpu[index, :num_prompt_tokens]) - prompt_token_ids = prompt_token_ids.tolist() - self.suffix_cache.start_request(req_id, prompt_token_ids) - - self.suffix_cache.add_active_response(req_id, sampled_ids) - - # Stop requests that are not seen - for req_id in list(self.suffix_cache.active_requests): - if req_id not in seen_req_ids: - self.suffix_cache.stop_request(req_id) - def _model_forward( self, input_ids: Optional[torch.Tensor] = None, @@ -2596,8 +2565,10 @@ def propose_draft_token_ids( self.input_batch.token_ids_cpu, self.input_batch.spec_decode_unsupported_reqs) elif self.speculative_config.method == "suffix": - draft_token_ids = self.propose_suffix_draft_token_ids( - sampled_token_ids) + assert isinstance(sampled_token_ids, list) + assert isinstance(self.drafter, SuffixDecodingProposer) + draft_token_ids = self.drafter.propose( + self.input_batch, sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) @@ -2710,49 +2681,6 @@ def propose_draft_token_ids( return draft_token_ids - def propose_suffix_draft_token_ids( - self, - sampled_token_ids: list[list[int]], - ) -> list[list[int]]: - from arctic_inference.suffix_decoding import SuffixDecodingDraft - req_ids = self.input_batch.req_ids - config = self.speculative_config - draft_token_ids = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - draft_token_ids.append([]) - continue - - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. 
- req_id = req_ids[i] - if req_id in self.input_batch.spec_decode_unsupported_reqs: - draft_token_ids.append([]) - continue - - num_tokens = self.input_batch.num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. - draft_token_ids.append([]) - continue - - start = max(0, num_tokens - config.suffix_decoding_max_tree_depth) - pattern = self.input_batch.token_ids_cpu[i, start:num_tokens] - pattern = pattern.tolist() - draft = self.suffix_cache.speculate( - req_id, - pattern, - max_spec_tokens=min(config.num_speculative_tokens, - self.max_model_len - num_tokens - 1), - max_spec_factor=config.suffix_decoding_max_spec_factor, - min_token_prob=config.suffix_decoding_min_token_prob) - - draft_token_ids.append(draft.token_ids) - - return draft_token_ids - def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): From 075184428da6f4079da35762313a8cf13631a158 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 04:23:56 +0000 Subject: [PATCH 09/26] fix --- tests/v1/e2e/test_spec_decode.py | 52 ++++++++++++++++-------------- vllm/config/speculative.py | 5 ++- vllm/v1/worker/gpu_model_runner.py | 5 +-- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index af3f5d8a4691..bda1bd2c6a89 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -193,30 +193,34 @@ def test_suffix_decoding_acceptance( cleanup_dist_env_and_memory() -@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ - (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - (("eagle", "eagle618/deepseek-v3-random", - "eagle618/eagle-deepseek-v3-random", 1), False), -], - ids=[ - "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", - "llama4_eagle", "llama4_eagle_mm", - "deepseek_eagle" - ]) +@pytest.mark.parametrize( + ["model_setup", "mm_enabled"], + [ + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + pytest.param(("eagle3", "Qwen/Qwen2.5-VL-7B-Instruct", + "Rayzl/qwen2.5-vl-7b-eagle3-sgl", 1), + False, + marks=pytest.mark.skip(reason="Skipping due to its " \ + "head_dim not being a a multiple of 32")), + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=large_gpu_mark(min_gb=80)), # works on 4x H100 + pytest.param(("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=large_gpu_mark(min_gb=80)), # works on 4x H100 + (("eagle", 
"eagle618/deepseek-v3-random", + "eagle618/eagle-deepseek-v3-random", 1), False), + ], + ids=[ + "qwen3_eagle3", "qwen2_5_vl_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm", "deepseek_eagle" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 624b99b4c2e4..7a37d9e9f64d 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -600,6 +600,9 @@ def use_eagle(self) -> bool: def __repr__(self) -> str: method = self.method - model = None if method in ["ngram", "suffix"] else self.draft_model_config.model + if method in ("ngram", "suffix"): + model = None + else: + model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8ebe17f96617..57042d1e0836 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1228,7 +1228,7 @@ def _prepare_inputs( if (self.speculative_config and spec_decode_common_attn_metadata is None): - if isinstance(getattr(self, "drafter", None), EagleProposer): + if isinstance(self.drafter, EagleProposer): if (self.drafter.attn_layer_names[0] in kv_cache_group_spec.layer_names): spec_decode_common_attn_metadata = common_attn_metadata @@ -2246,7 +2246,8 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if isinstance(getattr(self, "drafter", None), SuffixDecodingProposer): + if (self.speculative_config and + isinstance(self.drafter, SuffixDecodingProposer)): self.drafter.update(self.input_batch, valid_sampled_token_ids) return ( From 9692d9d46f5299a0f0ede1aa90a4e0a2a225a8b1 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 04:32:08 +0000 Subject: [PATCH 10/26] docstring --- vllm/config/speculative.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 7a37d9e9f64d..5d2e6c0df023 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -125,16 +125,24 @@ class SpeculativeConfig: # Suffix decoding configuration suffix_decoding_max_tree_depth: int = 64 - """The maximum depth of the suffix decoding tree.""" + """The maximum depth of the suffix decoding global and prompt trees. The + tree depth limits the sum of the prefix match and speculation lengths.""" suffix_decoding_max_cached_requests: int = 10000 - """The maximum number of requests to cache in the global suffix tree.""" + """The maximum number of requests to cache in the global suffix tree. If + exceeded, will trigger eviction in FIFO order. If set to 0, the global + suffix tree is disabled and past responses are not cached (prompt trees + are still used).""" suffix_decoding_max_spec_factor: float = 1.0 - """The maximum speculative factor for suffix decoding.""" + """The maximum spec factor for suffix decoding. The spec factor controls + speculation lengths based on the prefix match length: max_spec_tokens = + max_spec_factor * prefix_match_length.""" suffix_decoding_min_token_prob: float = 0.1 - """The minimum token probability for suffix decoding.""" + """The minimum token probability for suffix decoding. 
Will only speculate + tokens with estimated probability (based on frequency counts) greater than + or equal to this value.""" def compute_hash(self) -> str: """ From 4536270571f1f92720f8e5aea9433edbdcc855f2 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 17:04:11 +0000 Subject: [PATCH 11/26] precommit --- tests/v1/e2e/test_spec_decode.py | 26 ++++++++++++-------------- vllm/config/speculative.py | 11 ++++------- vllm/utils/__init__.py | 6 ++++++ vllm/v1/spec_decode/suffix_decoding.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 11 ++++++----- 5 files changed, 29 insertions(+), 27 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index bda1bd2c6a89..be386327008a 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -79,18 +79,16 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.mark.parametrize("speculative_config", [ - { - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, - { - "method": "suffix", - "suffix_decoding_max_spec_factor": 2.0, - } -]) +@pytest.mark.parametrize("speculative_config", + [{ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, { + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + }]) def test_ngram_and_suffix_correctness( speculative_config: dict, monkeypatch: pytest.MonkeyPatch, @@ -156,11 +154,11 @@ def test_suffix_decoding_acceptance( ) # Run several times and check that the accepted tokens increase. - spec_outputs = spec_llm.chat(test_prompts, sampling_config) + spec_llm.chat(test_prompts, sampling_config) num_draft = [] num_accept = [] for i in range(10): # Run multiple times to warm up the cache. - spec_outputs = spec_llm.chat(test_prompts, sampling_config) + spec_llm.chat(test_prompts, sampling_config) # Collect draft and acceptance stats. metrics = spec_llm.get_metrics() for metric in metrics: diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 5d2e6c0df023..49ff519ebba2 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -13,7 +13,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import LazyLoader +from vllm.utils import LazyLoader, has_arctic_inference if TYPE_CHECKING: from transformers import PretrainedConfig @@ -295,10 +295,7 @@ def __post_init__(self): self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": - try: - from arctic_inference.suffix_decoding import ( - SuffixDecodingCache) - except ImportError: + if not has_arctic_inference(): raise ImportError( "Arctic Inference is required for suffix decoding. 
" "Please install via `pip install arctic-inference`.") @@ -317,8 +314,8 @@ def __post_init__(self): raise ValueError( f"suffix_decoding_max_spec_factor=" f"{self.suffix_decoding_max_spec_factor} must be >= 0") - if (self.suffix_decoding_min_token_prob < 0 or - self.suffix_decoding_min_token_prob > 1): + if (self.suffix_decoding_min_token_prob < 0 + or self.suffix_decoding_min_token_prob > 1): raise ValueError( f"suffix_decoding_min_token_prob=" f"{self.suffix_decoding_min_token_prob} must be in [0, 1]") diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index ba280d6dbe4a..5f6835e139ba 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3433,6 +3433,12 @@ def has_triton_kernels() -> bool: return _has_module("triton_kernels") +def has_arctic_inference() -> bool: + """Whether the optional `arctic_inference` package is available.""" + + return _has_module("arctic_inference") + + def set_process_title(name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index c9090ff6f7e0..37136e596a56 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -58,7 +58,7 @@ def propose( sampled_token_ids: list[list[int]], ) -> list[list[int]]: req_ids = input_batch.req_ids - draft_token_ids = [] + draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 57042d1e0836..a10fbd00cf3a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -283,7 +283,8 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "suffix": - self.drafter = SuffixDecodingProposer(self.vllm_config) + self.drafter = SuffixDecodingProposer( + self.vllm_config) # type: ignore elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore @@ -2246,8 +2247,8 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if (self.speculative_config and - isinstance(self.drafter, SuffixDecodingProposer)): + if (self.speculative_config + and isinstance(self.drafter, SuffixDecodingProposer)): self.drafter.update(self.input_batch, valid_sampled_token_ids) return ( @@ -2568,8 +2569,8 @@ def propose_draft_token_ids( elif self.speculative_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) - draft_token_ids = self.drafter.propose( - self.input_batch, sampled_token_ids) + draft_token_ids = self.drafter.propose(self.input_batch, + sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) From 6342db2d45049bdeee834144247882d032509d0f Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 30 Sep 2025 17:25:02 +0000 Subject: [PATCH 12/26] minor --- docs/features/spec_decode.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 2dcf0403e6c1..2ca24181dfaf 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -135,7 +135,7 @@ matching n-grams in the prompt. 
For more information read [this thread.](https:/ The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)). -Like n-gram, Suffix Decoding can generate draft tokens from pattern-matching the prompt. Unlike n-gram, Suffix Decoding (1) can also pattern-match against previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates. +Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates. Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts. From ecf3efdefc80da6369a8c8fd40a8594474bfa3b1 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 2 Oct 2025 09:40:44 -0700 Subject: [PATCH 13/26] Update speculative.py --- vllm/config/speculative.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 1d9411f41506..ba7ea6a68566 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -302,10 +302,10 @@ def __post_init__(self): if self.num_speculative_tokens is None: self.num_speculative_tokens = 32 # Validate values - if self.suffix_decoding_max_tree_depth < 4: + if self.suffix_decoding_max_tree_depth < 1: raise ValueError( f"suffix_decoding_max_tree_depth=" - f"{self.suffix_decoding_max_tree_depth} must be >= 4") + f"{self.suffix_decoding_max_tree_depth} must be >= 1") if self.suffix_decoding_max_cached_requests < 0: raise ValueError( f"suffix_decoding_max_cached_requests=" From 213276faea2469b5ed1c4b52b0f9daf22fd88885 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 Oct 2025 18:34:36 +0000 Subject: [PATCH 14/26] format --- vllm/config/speculative.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index bbc6e55c8729..875284cca6ae 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -322,7 +322,7 @@ def __post_init__(self): if not has_arctic_inference(): raise ImportError( "Arctic Inference is required for suffix decoding. 
" - "Please install via `pip install arctic-inference`.") + "Install via `pip install arctic-inference==0.0.9`.") if self.num_speculative_tokens is None: self.num_speculative_tokens = 32 # Validate values @@ -648,9 +648,7 @@ def use_eagle(self) -> bool: def __repr__(self) -> str: method = self.method - if method in ("ngram", "suffix"): - model = None - else: - model = self.draft_model_config.model + model = (None if method in ("ngram", "suffix") + else self.draft_model_config.model) num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" From 4724b2578e1420a741bb589f08fec00cbca691b0 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 Oct 2025 19:37:18 +0000 Subject: [PATCH 15/26] format --- tests/v1/e2e/test_spec_decode.py | 29 +++++++++++++++----------- vllm/config/speculative.py | 24 +++++++++++++-------- vllm/utils/__init__.py | 6 +++--- vllm/v1/spec_decode/suffix_decoding.py | 15 ++++++------- vllm/v1/worker/gpu_model_runner.py | 9 +++----- 5 files changed, 46 insertions(+), 37 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 457447a9b567..9187adc84a17 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -77,16 +77,21 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -@pytest.mark.parametrize("speculative_config", - [{ - "method": "ngram", - "prompt_lookup_max": 5, - "prompt_lookup_min": 3, - "num_speculative_tokens": 3, - }, { - "method": "suffix", - "suffix_decoding_max_spec_factor": 2.0, - }]) +@pytest.mark.parametrize( + "speculative_config", + [ + { + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": 3, + }, + { + "method": "suffix", + "suffix_decoding_max_spec_factor": 2.0, + }, + ], +) def test_ngram_and_suffix_correctness( speculative_config: dict, monkeypatch: pytest.MonkeyPatch, @@ -134,10 +139,10 @@ def test_suffix_decoding_acceptance( sampling_config: SamplingParams, model_name: str, ): - ''' + """ Check that suffix decoding caching takes effect and improves acceptance lengths and acceptance rates over multiple runs of the same prompts. - ''' + """ test_prompts = get_test_prompts(mm_enabled=False) spec_llm = LLM( diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 875284cca6ae..c88f4cb5075e 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -322,27 +322,34 @@ def __post_init__(self): if not has_arctic_inference(): raise ImportError( "Arctic Inference is required for suffix decoding. " - "Install via `pip install arctic-inference==0.0.9`.") + "Install via `pip install arctic-inference==0.0.9`." 
+ ) if self.num_speculative_tokens is None: self.num_speculative_tokens = 32 # Validate values if self.suffix_decoding_max_tree_depth < 1: raise ValueError( f"suffix_decoding_max_tree_depth=" - f"{self.suffix_decoding_max_tree_depth} must be >= 1") + f"{self.suffix_decoding_max_tree_depth} must be >= 1" + ) if self.suffix_decoding_max_cached_requests < 0: raise ValueError( f"suffix_decoding_max_cached_requests=" - f"{self.suffix_decoding_max_cached_requests} must be >= 0") + f"{self.suffix_decoding_max_cached_requests} must be >= 0" + ) if self.suffix_decoding_max_spec_factor < 0: raise ValueError( f"suffix_decoding_max_spec_factor=" - f"{self.suffix_decoding_max_spec_factor} must be >= 0") - if (self.suffix_decoding_min_token_prob < 0 - or self.suffix_decoding_min_token_prob > 1): + f"{self.suffix_decoding_max_spec_factor} must be >= 0" + ) + if ( + self.suffix_decoding_min_token_prob < 0 + or self.suffix_decoding_min_token_prob > 1 + ): raise ValueError( f"suffix_decoding_min_token_prob=" - f"{self.suffix_decoding_min_token_prob} must be in [0, 1]") + f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" + ) else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -648,7 +655,6 @@ def use_eagle(self) -> bool: def __repr__(self) -> str: method = self.method - model = (None if method in ("ngram", "suffix") - else self.draft_model_config.model) + model = None if method in ("ngram", "suffix") else self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 5f9b32a99b66..a410f58721a6 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -3309,9 +3309,9 @@ def has_arctic_inference() -> bool: return _has_module("arctic_inference") -def set_process_title(name: str, - suffix: str = "", - prefix: str = envs.VLLM_PROCESS_NAME_PREFIX) -> None: +def set_process_title( + name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX +) -> None: """ Set the current process title to a specific name with an optional suffix. diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index 37136e596a56..d3fab77b9be4 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -5,7 +5,6 @@ class SuffixDecodingProposer: - def __init__(self, vllm_config: VllmConfig): config = vllm_config.speculative_config self.num_speculative_tokens = config.num_speculative_tokens @@ -19,7 +18,8 @@ def __init__(self, vllm_config: VllmConfig): self.suffix_cache = SuffixDecodingCache( max_tree_depth=config.suffix_decoding_max_tree_depth, - max_cached_requests=config.suffix_decoding_max_cached_requests) + max_cached_requests=config.suffix_decoding_max_cached_requests, + ) def update( self, @@ -40,8 +40,7 @@ def update( # Reset the suffix cache for this request. 
self.suffix_cache.evict_cached_response(req_id) num_prompt_tokens = input_batch.num_prompt_tokens[index] - prompt_token_ids = ( - input_batch.token_ids_cpu[index, :num_prompt_tokens]) + prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] prompt_token_ids = prompt_token_ids.tolist() self.suffix_cache.start_request(req_id, prompt_token_ids) @@ -85,10 +84,12 @@ def propose( draft = self.suffix_cache.speculate( req_id, pattern, - max_spec_tokens=min(self.num_speculative_tokens, - self.max_model_len - num_tokens - 1), + max_spec_tokens=min( + self.num_speculative_tokens, self.max_model_len - num_tokens - 1 + ), max_spec_factor=self.max_spec_factor, - min_token_prob=self.min_token_prob) + min_token_prob=self.min_token_prob, + ) draft_token_ids.append(draft.token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 99d8d518358d..fc2f00c59d10 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -323,8 +323,7 @@ def __init__( if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "suffix": - self.drafter = SuffixDecodingProposer( - self.vllm_config) # type: ignore + self.drafter = SuffixDecodingProposer(self.vllm_config) # type: ignore elif self.speculative_config.use_eagle(): self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore if self.speculative_config.method == "eagle3": @@ -2349,8 +2348,7 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if (self.speculative_config - and isinstance(self.drafter, SuffixDecodingProposer)): + if self.speculative_config and isinstance(self.drafter, SuffixDecodingProposer): self.drafter.update(self.input_batch, valid_sampled_token_ids) return ( @@ -2707,8 +2705,7 @@ def propose_draft_token_ids( elif self.speculative_config.method == "suffix": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, SuffixDecodingProposer) - draft_token_ids = self.drafter.propose(self.input_batch, - sampled_token_ids) + draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(sampled_token_ids, list) assert isinstance(self.drafter, MedusaProposer) From dba74a65170013f46685cc79d4deedfc0429056a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 Oct 2025 21:08:01 +0000 Subject: [PATCH 16/26] doc --- docs/features/spec_decode.md | 5 ++++- vllm/v1/spec_decode/suffix_decoding.py | 25 +++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index 885b457ea620..efd35ddc7cc6 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -141,6 +141,9 @@ Suffix Decoding can achieve better performance for tasks with high repetition, s !!! tip "Install Arctic Inference" Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`. +!!! tip "Suffix Decoding Speculative Tokens" + Suffix Decoding will speculate a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` configuration specifies the *maximum* number of speculative tokens. It is suggested to use a high number such as `16` or `32` (default). + ??? 
code ```python @@ -156,7 +159,7 @@ Suffix Decoding can achieve better performance for tasks with high repetition, s tensor_parallel_size=1, speculative_config={ "method": "suffix", - "num_speculative_tokens": 16, + "num_speculative_tokens": 32, }, ) outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index d3fab77b9be4..d5eaedd55f1b 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -5,6 +5,12 @@ class SuffixDecodingProposer: + """ + Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975). + This class imports and uses the official implementation from Arctic Inference + (https://github.com/snowflakedb/ArcticInference). + """ + def __init__(self, vllm_config: VllmConfig): config = vllm_config.speculative_config self.num_speculative_tokens = config.num_speculative_tokens @@ -16,6 +22,8 @@ def __init__(self, vllm_config: VllmConfig): # Lazy import to avoid error when Suffix Decoding is not used. from arctic_inference.suffix_decoding import SuffixDecodingCache + # Initialize and empty cache. This object will take care of caching request + # outputs, evicting old requests, and manages the per-prompt suffix trees. self.suffix_cache = SuffixDecodingCache( max_tree_depth=config.suffix_decoding_max_tree_depth, max_cached_requests=config.suffix_decoding_max_cached_requests, @@ -26,12 +34,18 @@ def update( input_batch: InputBatch, sampled_token_ids: list[list[int]], ): + """ + Update the suffix cache with the newly sampled token ids for each active request. + Assumes that any request id not in `input_batch.req_ids` is no longer active and + should be stopped (i.e. deletes the per-prompt tree for that request id). + """ seen_req_ids = set() for i, sampled_ids in enumerate(sampled_token_ids): req_id = input_batch.req_ids[i] seen_req_ids.add(req_id) if not sampled_ids: + # No sampled ids for partial prefills. continue index = input_batch.req_id_to_index[req_id] @@ -42,11 +56,13 @@ def update( num_prompt_tokens = input_batch.num_prompt_tokens[index] prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] prompt_token_ids = prompt_token_ids.tolist() + # Start a new request, this will build the suffix tree for that prompt. self.suffix_cache.start_request(req_id, prompt_token_ids) + # Append the newly sampled ids to the suffix cache for this request. self.suffix_cache.add_active_response(req_id, sampled_ids) - # Stop requests that are not seen + # Stop requests that are not seen in the input batch. for req_id in list(self.suffix_cache.active_requests): if req_id not in seen_req_ids: self.suffix_cache.stop_request(req_id) @@ -56,12 +72,17 @@ def propose( input_batch: InputBatch, sampled_token_ids: list[list[int]], ) -> list[list[int]]: + """ + Propose speculative tokens for each request in the input batch. Suffix Decoding + will speculate a dynamic number of tokens for each request at each decoding step, + so each entry in the returned list may have different lengths. + """ req_ids = input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): num_sampled_ids = len(sampled_ids) if not num_sampled_ids: - # Skip speculative decoding. + # Skip speculative decoding for partial prefills. 
draft_token_ids.append([]) continue From 0f4555482da899f2f3db9e40098bff6cbe4a0ce0 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Thu, 9 Oct 2025 21:58:53 +0000 Subject: [PATCH 17/26] update --- vllm/v1/spec_decode/suffix_decoding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index d5eaedd55f1b..f3eff7bdf5a2 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -35,8 +35,8 @@ def update( sampled_token_ids: list[list[int]], ): """ - Update the suffix cache with the newly sampled token ids for each active request. - Assumes that any request id not in `input_batch.req_ids` is no longer active and + Update suffix cache with the newly sampled token ids for each active request. + Assumes that request ids not in `input_batch.req_ids` is no longer active and should be stopped (i.e. deletes the per-prompt tree for that request id). """ seen_req_ids = set() @@ -74,7 +74,7 @@ def propose( ) -> list[list[int]]: """ Propose speculative tokens for each request in the input batch. Suffix Decoding - will speculate a dynamic number of tokens for each request at each decoding step, + will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ req_ids = input_batch.req_ids From af6a2378b03fa14872bedceef4a865c9ecc29cf8 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Sun, 19 Oct 2025 03:31:21 +0000 Subject: [PATCH 18/26] merge update with propose --- vllm/v1/spec_decode/suffix_decoding.py | 63 ++++++++------------------ vllm/v1/worker/gpu_model_runner.py | 3 -- 2 files changed, 20 insertions(+), 46 deletions(-) diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index f3eff7bdf5a2..b3f3d1624b86 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -29,44 +29,6 @@ def __init__(self, vllm_config: VllmConfig): max_cached_requests=config.suffix_decoding_max_cached_requests, ) - def update( - self, - input_batch: InputBatch, - sampled_token_ids: list[list[int]], - ): - """ - Update suffix cache with the newly sampled token ids for each active request. - Assumes that request ids not in `input_batch.req_ids` is no longer active and - should be stopped (i.e. deletes the per-prompt tree for that request id). - """ - seen_req_ids = set() - for i, sampled_ids in enumerate(sampled_token_ids): - req_id = input_batch.req_ids[i] - seen_req_ids.add(req_id) - - if not sampled_ids: - # No sampled ids for partial prefills. - continue - - index = input_batch.req_id_to_index[req_id] - if req_id not in self.suffix_cache.active_requests: - if req_id in self.suffix_cache.cached_requests: - # Reset the suffix cache for this request. - self.suffix_cache.evict_cached_response(req_id) - num_prompt_tokens = input_batch.num_prompt_tokens[index] - prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] - prompt_token_ids = prompt_token_ids.tolist() - # Start a new request, this will build the suffix tree for that prompt. - self.suffix_cache.start_request(req_id, prompt_token_ids) - - # Append the newly sampled ids to the suffix cache for this request. - self.suffix_cache.add_active_response(req_id, sampled_ids) - - # Stop requests that are not seen in the input batch. 
- for req_id in list(self.suffix_cache.active_requests): - if req_id not in seen_req_ids: - self.suffix_cache.stop_request(req_id) - def propose( self, input_batch: InputBatch, @@ -77,18 +39,16 @@ def propose( will speculate a dynamic number of tokens for each request every decoding step, so each entry in the returned list may have different lengths. """ - req_ids = input_batch.req_ids draft_token_ids: list[list[int]] = [] for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: + if not sampled_ids: # Skip speculative decoding for partial prefills. draft_token_ids.append([]) continue # Skip requests that require sampling parameters that are not # supported with speculative decoding. - req_id = req_ids[i] + req_id = input_batch.req_ids[i] if req_id in input_batch.spec_decode_unsupported_reqs: draft_token_ids.append([]) continue @@ -99,9 +59,21 @@ def propose( draft_token_ids.append([]) continue + index = input_batch.req_id_to_index[req_id] + if req_id not in self.suffix_cache.active_requests: + if req_id in self.suffix_cache.cached_requests: + # Reset the suffix cache for this request. + self.suffix_cache.evict_cached_response(req_id) + num_prompt_tokens = input_batch.num_prompt_tokens[index] + prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens] + # Start a new request, this will build the suffix tree for that prompt. + self.suffix_cache.start_request(req_id, prompt_token_ids) + + # Append the newly sampled ids to the suffix cache for this request. + self.suffix_cache.add_active_response(req_id, sampled_ids) + start = max(0, num_tokens - self.max_tree_depth) pattern = input_batch.token_ids_cpu[i, start:num_tokens] - pattern = pattern.tolist() draft = self.suffix_cache.speculate( req_id, pattern, @@ -114,6 +86,11 @@ def propose( draft_token_ids.append(draft.token_ids) + # Stop requests that were not seen in the input batch. + for req_id in (self.suffix_cache.active_requests - + input_batch.req_id_to_index.keys()): + self.suffix_cache.stop_request(req_id) + return draft_token_ids def load_model(self, *args, **kwargs): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fc2f00c59d10..6db1436b2969 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2348,9 +2348,6 @@ def _bookkeeping_sync( req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if self.speculative_config and isinstance(self.drafter, SuffixDecodingProposer): - self.drafter.update(self.input_batch, valid_sampled_token_ids) - return ( num_nans_in_logits, logprobs_lists, From 802d291169346306e5d4ca52ae24ec75a8bb7f2d Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 21 Oct 2025 00:49:21 +0000 Subject: [PATCH 19/26] update --- vllm/config/speculative.py | 14 ++++++++------ vllm/v1/spec_decode/suffix_decoding.py | 7 +++++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8c6ff7d12ecf..53ebd3073e4e 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -143,7 +143,7 @@ class SpeculativeConfig: """The parallel configuration for the draft model initialized internal.""" # Suffix decoding configuration - suffix_decoding_max_tree_depth: int = 64 + suffix_decoding_max_tree_depth: int = 24 """The maximum depth of the suffix decoding global and prompt trees. 
The tree depth limits the sum of the prefix match and speculation lengths.""" @@ -326,7 +326,12 @@ def __post_init__(self): "Install via `pip install arctic-inference==0.0.9`." ) if self.num_speculative_tokens is None: - self.num_speculative_tokens = 32 + # Suffix decoding decides the actual number of speculative tokens + # dynamically and treats num_speculative_tokens as a maximum limit. + max_spec_tokens = self.suffix_decoding_max_tree_depth + self.num_speculative_tokens = max_spec_tokens + logger.warning(f"Defaulted num_speculative_tokens to " + f"{max_spec_tokens} for suffix decoding.") # Validate values if self.suffix_decoding_max_tree_depth < 1: raise ValueError( @@ -343,10 +348,7 @@ def __post_init__(self): f"suffix_decoding_max_spec_factor=" f"{self.suffix_decoding_max_spec_factor} must be >= 0" ) - if ( - self.suffix_decoding_min_token_prob < 0 - or self.suffix_decoding_min_token_prob > 1 - ): + if not 0 <= self.suffix_decoding_min_token_prob <= 1: raise ValueError( f"suffix_decoding_min_token_prob=" f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" diff --git a/vllm/v1/spec_decode/suffix_decoding.py b/vllm/v1/spec_decode/suffix_decoding.py index b3f3d1624b86..049e335db325 100644 --- a/vllm/v1/spec_decode/suffix_decoding.py +++ b/vllm/v1/spec_decode/suffix_decoding.py @@ -72,6 +72,8 @@ def propose( # Append the newly sampled ids to the suffix cache for this request. self.suffix_cache.add_active_response(req_id, sampled_ids) + # Suffix decoding only uses the most recent tokens up to max_tree_depth, so + # we extract the pattern from the end of the input. start = max(0, num_tokens - self.max_tree_depth) pattern = input_batch.token_ids_cpu[i, start:num_tokens] draft = self.suffix_cache.speculate( @@ -87,8 +89,9 @@ def propose( draft_token_ids.append(draft.token_ids) # Stop requests that were not seen in the input batch. - for req_id in (self.suffix_cache.active_requests - - input_batch.req_id_to_index.keys()): + for req_id in ( + self.suffix_cache.active_requests - input_batch.req_id_to_index.keys() + ): self.suffix_cache.stop_request(req_id) return draft_token_ids From 7050bc4eecf6d9ef45c5cd65cbcbf4ea3c6ed31f Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 21 Oct 2025 00:56:09 +0000 Subject: [PATCH 20/26] format --- vllm/config/speculative.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 53ebd3073e4e..2810f85d1752 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -330,8 +330,10 @@ def __post_init__(self): # dynamically and treats num_speculative_tokens as a maximum limit. max_spec_tokens = self.suffix_decoding_max_tree_depth self.num_speculative_tokens = max_spec_tokens - logger.warning(f"Defaulted num_speculative_tokens to " - f"{max_spec_tokens} for suffix decoding.") + logger.warning( + f"Defaulted num_speculative_tokens to " + f"{max_spec_tokens} for suffix decoding." 
+ ) # Validate values if self.suffix_decoding_max_tree_depth < 1: raise ValueError( From 1f3d3268138397400fe39ec37f161d61c729c3d5 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 21 Oct 2025 01:01:13 +0000 Subject: [PATCH 21/26] version --- requirements/test.in | 2 +- requirements/test.txt | 2 +- vllm/config/speculative.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements/test.in b/requirements/test.in index 2ccb7d88159f..41fb55c7034d 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -48,7 +48,7 @@ buildkite-test-collector==0.1.9 genai_perf==0.0.8 tritonclient==2.51.0 -arctic-inference == 0.0.9 # Required for suffix decoding test +arctic-inference == 0.1.0 # Required for suffix decoding test numba == 0.61.2 # Required for N-gram speculative decoding numpy runai-model-streamer[s3,gcs]==0.14.0 diff --git a/requirements/test.txt b/requirements/test.txt index bc92d10cbf82..4f749d17cf5f 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -40,7 +40,7 @@ anyio==4.6.2.post1 # via # httpx # starlette -arctic-inference==0.0.9 +arctic-inference==0.1.0 # via -r requirements/test.in argcomplete==3.5.1 # via datamodel-code-generator diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2810f85d1752..0f71d077b573 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -323,7 +323,7 @@ def __post_init__(self): if not has_arctic_inference(): raise ImportError( "Arctic Inference is required for suffix decoding. " - "Install via `pip install arctic-inference==0.0.9`." + "Install via `pip install arctic-inference==0.1.0`." ) if self.num_speculative_tokens is None: # Suffix decoding decides the actual number of speculative tokens From a79107fec9cf80b766d05bf02b52f119289cb17a Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 21 Oct 2025 01:03:05 +0000 Subject: [PATCH 22/26] f-string --- vllm/config/speculative.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 0f71d077b573..4880e1452a1f 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -328,11 +328,10 @@ def __post_init__(self): if self.num_speculative_tokens is None: # Suffix decoding decides the actual number of speculative tokens # dynamically and treats num_speculative_tokens as a maximum limit. - max_spec_tokens = self.suffix_decoding_max_tree_depth - self.num_speculative_tokens = max_spec_tokens + self.num_speculative_tokens = self.suffix_decoding_max_tree_depth logger.warning( - f"Defaulted num_speculative_tokens to " - f"{max_spec_tokens} for suffix decoding." + "Defaulted num_speculative_tokens to %s for suffix decoding.", + self.num_speculative_tokens, ) # Validate values if self.suffix_decoding_max_tree_depth < 1: From e1d62bc5e8e73f88f058fadb8ec7dbe93e833e2e Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Wed, 22 Oct 2025 04:13:20 +0000 Subject: [PATCH 23/26] type --- vllm/v1/worker/gpu_model_runner.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29559dacdeb6..c250351606e8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -313,18 +313,21 @@ def __init__( # the last PP rank. This is not ideal if there are many # layers in the draft model. 
if self.speculative_config and get_pp_group().is_last_rank: + self.drafter: ( + NgramProposer | SuffixDecodingProposer | EagleProposer | MedusaProposer + ) if self.speculative_config.method == "ngram": self.drafter = NgramProposer(self.vllm_config) elif self.speculative_config.method == "suffix": - self.drafter = SuffixDecodingProposer(self.vllm_config) # type: ignore + self.drafter = SuffixDecodingProposer(self.vllm_config) elif self.speculative_config.use_eagle(): - self.drafter = EagleProposer(self.vllm_config, self.device, self) # type: ignore + self.drafter = EagleProposer(self.vllm_config, self.device, self) if self.speculative_config.method == "eagle3": self.use_aux_hidden_state_outputs = True elif self.speculative_config.method == "medusa": self.drafter = MedusaProposer( vllm_config=self.vllm_config, device=self.device - ) # type: ignore + ) else: raise ValueError( "Unknown speculative decoding method: " From d85dc3f0edda5c8e5e7a3fef6f77671ea1f377ef Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 27 Oct 2025 18:56:24 +0000 Subject: [PATCH 24/26] fix import --- vllm/config/speculative.py | 3 +-- vllm/utils/import_utils.py | 6 ++++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f2d20e0eb229..e3dd3e717b4e 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -13,8 +13,7 @@ from vllm.config.parallel import ParallelConfig from vllm.config.utils import config from vllm.logger import init_logger -from vllm.utils import has_arctic_inference -from vllm.utils.import_utils import LazyLoader +from vllm.utils.import_utils import LazyLoader, has_arctic_inference if TYPE_CHECKING: from transformers import PretrainedConfig diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 65f588b52e5e..5b0c5015faee 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -360,3 +360,9 @@ def has_triton_kernels() -> bool: def has_tilelang() -> bool: """Whether the optional `tilelang` package is available.""" return _has_module("tilelang") + + +def has_arctic_inference() -> bool: + """Whether the optional `arctic_inference` package is available.""" + + return _has_module("arctic_inference") From aae2b616e0a0414a688876c7663f3a91c0e3eb82 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Mon, 27 Oct 2025 19:21:36 +0000 Subject: [PATCH 25/26] config --- vllm/config/speculative.py | 71 ++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 34 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index e3dd3e717b4e..e1d98ce0bebe 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -318,40 +318,7 @@ def __post_init__(self): self.draft_model_config = self.target_model_config self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": - if not has_arctic_inference(): - raise ImportError( - "Arctic Inference is required for suffix decoding. " - "Install via `pip install arctic-inference==0.1.0`." - ) - if self.num_speculative_tokens is None: - # Suffix decoding decides the actual number of speculative tokens - # dynamically and treats num_speculative_tokens as a maximum limit. 
- self.num_speculative_tokens = self.suffix_decoding_max_tree_depth - logger.warning( - "Defaulted num_speculative_tokens to %s for suffix decoding.", - self.num_speculative_tokens, - ) - # Validate values - if self.suffix_decoding_max_tree_depth < 1: - raise ValueError( - f"suffix_decoding_max_tree_depth=" - f"{self.suffix_decoding_max_tree_depth} must be >= 1" - ) - if self.suffix_decoding_max_cached_requests < 0: - raise ValueError( - f"suffix_decoding_max_cached_requests=" - f"{self.suffix_decoding_max_cached_requests} must be >= 0" - ) - if self.suffix_decoding_max_spec_factor < 0: - raise ValueError( - f"suffix_decoding_max_spec_factor=" - f"{self.suffix_decoding_max_spec_factor} must be >= 0" - ) - if not 0 <= self.suffix_decoding_min_token_prob <= 1: - raise ValueError( - f"suffix_decoding_min_token_prob=" - f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" - ) + self._validate_suffix_decoding() else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -506,6 +473,42 @@ def __post_init__(self): ) return self + def _validate_suffix_decoding(self): + if not has_arctic_inference(): + raise ImportError( + "Arctic Inference is required for suffix decoding. " + "Install via `pip install arctic-inference==0.1.0`." + ) + if self.num_speculative_tokens is None: + # Suffix decoding decides the actual number of speculative tokens + # dynamically and treats num_speculative_tokens as a maximum limit. + self.num_speculative_tokens = self.suffix_decoding_max_tree_depth + logger.warning( + "Defaulted num_speculative_tokens to %s for suffix decoding.", + self.num_speculative_tokens, + ) + # Validate values + if self.suffix_decoding_max_tree_depth < 1: + raise ValueError( + f"suffix_decoding_max_tree_depth=" + f"{self.suffix_decoding_max_tree_depth} must be >= 1" + ) + if self.suffix_decoding_max_cached_requests < 0: + raise ValueError( + f"suffix_decoding_max_cached_requests=" + f"{self.suffix_decoding_max_cached_requests} must be >= 0" + ) + if self.suffix_decoding_max_spec_factor < 0: + raise ValueError( + f"suffix_decoding_max_spec_factor=" + f"{self.suffix_decoding_max_spec_factor} must be >= 0" + ) + if not 0 <= self.suffix_decoding_min_token_prob <= 1: + raise ValueError( + f"suffix_decoding_min_token_prob=" + f"{self.suffix_decoding_min_token_prob} must be in [0, 1]" + ) + @staticmethod def _maybe_override_draft_max_model_len( speculative_max_model_len: int | None, From ae0beb3ee7bd743493796d8aa99512752fd6ab32 Mon Sep 17 00:00:00 2001 From: Aurick Qiao Date: Tue, 28 Oct 2025 03:41:23 +0000 Subject: [PATCH 26/26] Trigger CI
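
For readers following the series, the sketch below illustrates the heuristic that the `suffix_decoding_*` options control: previously generated sequences are cached, the longest suffix of the current context that reappears in the cache is located, and the cached continuation is proposed as draft tokens. It is only a simplified, self-contained illustration, not the Arctic Inference `SuffixDecodingCache` that these patches call into (which maintains per-prompt suffix trees and a global cache with eviction). The function name, its parameters, the budget scaling by `max_spec_factor` times the matched length, and the frequency-based `min_token_prob` cutoff are assumptions made for the example, loosely following the Suffix Decoding paper rather than the library's API.

```python
from collections import defaultdict


def propose_suffix_draft(
    context: list[int],
    cached_sequences: list[list[int]],
    max_tree_depth: int = 24,
    max_spec_tokens: int = 24,
    max_spec_factor: float = 1.0,
    min_token_prob: float = 0.1,
) -> list[int]:
    """Toy suffix-decoding proposal: match the longest suffix of `context`
    against previously seen sequences and greedily propose the most frequent
    continuation, one token at a time. Illustrative only."""
    # Only the most recent tokens participate in matching, mirroring the
    # suffix_decoding_max_tree_depth setting in the patches above.
    pattern = context[-max_tree_depth:]

    # Find the longest suffix of `pattern` that occurs in any cached sequence,
    # remembering where each match ends so we can read off continuations.
    best_len, positions = 0, []
    for plen in range(len(pattern), 0, -1):
        suffix = pattern[-plen:]
        hits = [
            (seq, i + plen)
            for seq in cached_sequences
            for i in range(len(seq) - plen + 1)
            if seq[i : i + plen] == suffix
        ]
        if hits:
            best_len, positions = plen, hits
            break
    if best_len == 0:
        return []

    # Assumed budget rule: speculate more when the match is longer, capped by
    # max_spec_tokens (the role num_speculative_tokens plays in the patches).
    budget = min(max_spec_tokens, int(max_spec_factor * best_len))
    draft: list[int] = []
    while len(draft) < budget:
        counts = defaultdict(int)
        for seq, pos in positions:
            if pos < len(seq):
                counts[seq[pos]] += 1
        if not counts:
            break
        token, count = max(counts.items(), key=lambda kv: kv[1])
        # Stop once the empirical probability of the best next token drops
        # below the (assumed, frequency-based) min_token_prob threshold.
        if count / sum(counts.values()) < min_token_prob:
            break
        draft.append(token)
        # Keep only the match positions that continue with the chosen token.
        positions = [
            (seq, pos + 1)
            for seq, pos in positions
            if pos < len(seq) and seq[pos] == token
        ]
    return draft


if __name__ == "__main__":
    cache = [[1, 2, 3, 4, 5, 6], [9, 2, 3, 4, 7, 8]]
    print(propose_suffix_draft([0, 1, 2, 3], cache))  # [4, 5, 6]
```

Conceptually, this mirrors the flow of `SuffixDecodingProposer.propose` in the patches: earlier responses play the role of the cache, the tail of `token_ids_cpu` (up to `max_tree_depth` tokens) plays the role of the pattern, and the variable-length draft returned per request is handed to the target model for verification.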