Merged
42 commits
2d4180c
Integrate Suffix Decoding
sfc-gh-aqiao Sep 24, 2025
db621b4
update
sfc-gh-aqiao Sep 26, 2025
a15727d
add tests
sfc-gh-aqiao Sep 29, 2025
50d6fd3
fix
sfc-gh-aqiao Sep 29, 2025
92fe31e
docs
sfc-gh-aqiao Sep 30, 2025
6755a31
move import check
sfc-gh-aqiao Sep 30, 2025
6acc529
Merge branch 'main' into HEAD
sfc-gh-aqiao Sep 30, 2025
72c85e6
fix
sfc-gh-aqiao Sep 30, 2025
da82b85
SuffixDecodingProposer
sfc-gh-aqiao Sep 30, 2025
0751844
fix
sfc-gh-aqiao Sep 30, 2025
9692d9d
docstring
sfc-gh-aqiao Sep 30, 2025
4536270
precommit
sfc-gh-aqiao Sep 30, 2025
6a24831
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Sep 30, 2025
6342db2
minor
sfc-gh-aqiao Sep 30, 2025
ecf3efd
Update speculative.py
aurickq Oct 2, 2025
062bafe
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 7, 2025
1e9ac44
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 7, 2025
9774fb3
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 9, 2025
213276f
format
sfc-gh-aqiao Oct 9, 2025
4724b25
format
sfc-gh-aqiao Oct 9, 2025
dba74a6
doc
sfc-gh-aqiao Oct 9, 2025
0f45554
update
sfc-gh-aqiao Oct 9, 2025
af6a237
merge update with propose
sfc-gh-aqiao Oct 19, 2025
e94166b
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 21, 2025
802d291
update
sfc-gh-aqiao Oct 21, 2025
7050bc4
format
sfc-gh-aqiao Oct 21, 2025
1f3d326
version
sfc-gh-aqiao Oct 21, 2025
a79107f
f-string
sfc-gh-aqiao Oct 21, 2025
2d8e05d
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 22, 2025
e1d62bc
type
sfc-gh-aqiao Oct 22, 2025
1d64189
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 24, 2025
149c907
Merge branch 'main' into suffix-decoding
aurickq Oct 24, 2025
d85dc3f
fix import
sfc-gh-aqiao Oct 27, 2025
2f5b451
Merge branch 'suffix-decoding' of https://github.com/aurickq/vllm int…
sfc-gh-aqiao Oct 27, 2025
387fe61
Merge branch 'main' into suffix-decoding
sfc-gh-aqiao Oct 27, 2025
aae2b61
config
sfc-gh-aqiao Oct 27, 2025
7ade5e7
Merge branch 'main' into suffix-decoding
aurickq Oct 28, 2025
ae0beb3
Trigger CI
sfc-gh-aqiao Oct 28, 2025
ba82677
Merge branch 'main' into suffix-decoding
aurickq Oct 28, 2025
cfcfcde
Merge branch 'main' into suffix-decoding
aurickq Oct 29, 2025
78798c2
Merge branch 'main' into suffix-decoding
aurickq Oct 30, 2025
d851834
Merge branch 'main' into suffix-decoding
aurickq Nov 1, 2025
40 changes: 38 additions & 2 deletions vllm/config/speculative.py
@@ -31,7 +31,7 @@

SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp", "suffix"]


@config
@@ -120,6 +120,19 @@ class SpeculativeConfig:
ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

# Suffix decoding configuration
suffix_decoding_max_tree_depth: int = 64
"""The maximum depth of the suffix decoding tree."""

suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree."""

suffix_decoding_max_spec_factor: float = 1.0
"""The maximum speculative factor for suffix decoding."""

suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -215,6 +228,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")
@@ -258,6 +273,27 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
if self.num_speculative_tokens is None:
self.num_speculative_tokens = 32
# Validate values
if self.suffix_decoding_max_tree_depth < 4:
raise ValueError(
f"suffix_decoding_max_tree_depth="
f"{self.suffix_decoding_max_tree_depth} must be >= 4")
if self.suffix_decoding_max_cached_requests < 0:
raise ValueError(
f"suffix_decoding_max_cached_requests="
f"{self.suffix_decoding_max_cached_requests} must be >= 0")
if self.suffix_decoding_max_spec_factor < 0:
raise ValueError(
f"suffix_decoding_max_spec_factor="
f"{self.suffix_decoding_max_spec_factor} must be >= 0")
if (self.suffix_decoding_min_token_prob < 0 or
self.suffix_decoding_min_token_prob > 1):
raise ValueError(
f"suffix_decoding_min_token_prob="
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]")
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -554,6 +590,6 @@ def use_eagle(self) -> bool:

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
model = None if method in ["ngram", "suffix"] else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
89 changes: 88 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -268,9 +268,21 @@ def __init__(
# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
self.suffix_cache = None
if self.speculative_config and get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.method == "suffix":
try:
from arctic_inference.suffix_decoding import (
SuffixDecodingCache)
except ImportError:
raise ImportError(
"Arctic Inference is required for suffix decoding. "
"Please install via `pip install arctic-inference`.")
self.suffix_cache = SuffixDecodingCache(
self.speculative_config.suffix_decoding_max_tree_depth,
self.speculative_config.suffix_decoding_max_cached_requests)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config, self.device,
self) # type: ignore
@@ -1179,7 +1191,7 @@ def _prepare_inputs(

if (self.speculative_config
and spec_decode_common_attn_metadata is None):
if isinstance(self.drafter, EagleProposer):
if isinstance(getattr(self, "drafter", None), EagleProposer):
if (self.drafter.attn_layer_names[0]
in kv_cache_group_spec.layer_names):
spec_decode_common_attn_metadata = common_attn_metadata
@@ -2121,6 +2133,9 @@ def _bookkeeping_sync(
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)

if self.suffix_cache is not None:
self._update_suffix_cache(valid_sampled_token_ids)

return (
num_nans_in_logits,
logprobs_lists,
@@ -2146,6 +2161,32 @@ def synchronize_input_prep(self):
finally:
self.prepare_inputs_event.record()

def _update_suffix_cache(self, sampled_token_ids: list[list[int]]) -> None:
seen_req_ids = set()
for i, sampled_ids in enumerate(sampled_token_ids):
req_id = self.input_batch.req_ids[i]
seen_req_ids.add(req_id)

if not sampled_ids:
continue

index = self.input_batch.req_id_to_index[req_id]
if req_id not in self.suffix_cache.active_requests:
if req_id in self.suffix_cache.cached_requests:
# Reset the suffix cache for this request.
self.suffix_cache.evict_cached_response(req_id)
num_prompt_tokens = self.input_batch.num_prompt_tokens[index]
prompt_token_ids = (
self.input_batch.token_ids_cpu[index, :num_prompt_tokens])
self.suffix_cache.start_request(req_id, prompt_token_ids)

self.suffix_cache.add_active_response(req_id, sampled_ids)

# Stop tracking requests that were not seen in this batch
for req_id in list(self.suffix_cache.active_requests):
if req_id not in seen_req_ids:
self.suffix_cache.stop_request(req_id)

@torch.inference_mode()
def execute_model(
self,
@@ -2377,6 +2418,9 @@ def propose_draft_token_ids(
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "suffix":
draft_token_ids = self.propose_suffix_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
@@ -2521,6 +2565,49 @@ def propose_ngram_draft_token_ids(
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids

def propose_suffix_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
from arctic_inference.suffix_decoding import SuffixDecodingDraft
req_ids = self.input_batch.req_ids
config = self.speculative_config
draft_token_ids = []
for i, sampled_ids in enumerate(sampled_token_ids):
num_sampled_ids = len(sampled_ids)
if not num_sampled_ids:
# Skip speculative decoding.
draft_token_ids.append([])
continue

# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = req_ids[i]
if req_id in self.input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue

num_tokens = self.input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue

start = max(0, num_tokens - config.suffix_decoding_max_tree_depth)
pattern = self.input_batch.token_ids_cpu[i, start:num_tokens]
pattern = pattern.tolist()
draft = self.suffix_cache.speculate(
req_id,
pattern,
max_spec_tokens=min(config.num_speculative_tokens,
self.max_model_len - num_tokens - 1),
max_spec_factor=config.suffix_decoding_max_spec_factor,
min_token_prob=config.suffix_decoding_min_token_prob)

draft_token_ids.append(draft.token_ids)

return draft_token_ids

def update_config(self, overrides: dict[str, Any]) -> None:
allowed_config_names = {"load_config", "model_config"}
for config_name, config_overrides in overrides.items():
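For readers unfamiliar with Arctic Inference: the runner drives the suffix cache through a simple lifecycle — start_request with the prompt, add_active_response with newly sampled tokens, speculate to propose draft tokens, and stop_request when the request leaves the batch. A minimal standalone sketch of that flow, using only the calls that appear in this diff (the request id and all token values are placeholders):

# Hedged sketch of the SuffixDecodingCache lifecycle as exercised by
# gpu_model_runner.py in this PR. Method names mirror the calls visible
# in the diff above; numeric values are placeholders.
from arctic_inference.suffix_decoding import SuffixDecodingCache

cache = SuffixDecodingCache(
    64,      # suffix_decoding_max_tree_depth
    10_000,  # suffix_decoding_max_cached_requests
)

req_id = "req-0"
prompt_token_ids = [1, 2, 3, 4]

# 1. Register the request with its prompt tokens (done lazily in
#    _update_suffix_cache the first time the request is seen).
cache.start_request(req_id, prompt_token_ids)

# 2. After each decoding step, append the newly sampled tokens.
cache.add_active_response(req_id, [5, 6])

# 3. Propose draft tokens from the most recent suffix of the sequence,
#    bounded by the same limits the runner passes to speculate().
pattern = (prompt_token_ids + [5, 6])[-64:]  # last max_tree_depth tokens
draft = cache.speculate(
    req_id,
    pattern,
    max_spec_tokens=32,
    max_spec_factor=1.0,
    min_token_prob=0.1,
)
print(draft.token_ids)  # candidate tokens to verify against the target model

# 4. When the request is no longer in the batch, stop tracking it.
cache.stop_request(req_id)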