Merged

Commits (42)
- 2d4180c Integrate Suffix Decoding (sfc-gh-aqiao, Sep 24, 2025)
- db621b4 update (sfc-gh-aqiao, Sep 26, 2025)
- a15727d add tests (sfc-gh-aqiao, Sep 29, 2025)
- 50d6fd3 fix (sfc-gh-aqiao, Sep 29, 2025)
- 92fe31e docs (sfc-gh-aqiao, Sep 30, 2025)
- 6755a31 move import check (sfc-gh-aqiao, Sep 30, 2025)
- 6acc529 Merge branch 'main' into HEAD (sfc-gh-aqiao, Sep 30, 2025)
- 72c85e6 fix (sfc-gh-aqiao, Sep 30, 2025)
- da82b85 SuffixDecodingProposer (sfc-gh-aqiao, Sep 30, 2025)
- 0751844 fix (sfc-gh-aqiao, Sep 30, 2025)
- 9692d9d docstring (sfc-gh-aqiao, Sep 30, 2025)
- 4536270 precommit (sfc-gh-aqiao, Sep 30, 2025)
- 6a24831 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Sep 30, 2025)
- 6342db2 minor (sfc-gh-aqiao, Sep 30, 2025)
- ecf3efd Update speculative.py (aurickq, Oct 2, 2025)
- 062bafe Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 7, 2025)
- 1e9ac44 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 7, 2025)
- 9774fb3 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 9, 2025)
- 213276f format (sfc-gh-aqiao, Oct 9, 2025)
- 4724b25 format (sfc-gh-aqiao, Oct 9, 2025)
- dba74a6 doc (sfc-gh-aqiao, Oct 9, 2025)
- 0f45554 update (sfc-gh-aqiao, Oct 9, 2025)
- af6a237 merge update with propose (sfc-gh-aqiao, Oct 19, 2025)
- e94166b Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 21, 2025)
- 802d291 update (sfc-gh-aqiao, Oct 21, 2025)
- 7050bc4 format (sfc-gh-aqiao, Oct 21, 2025)
- 1f3d326 version (sfc-gh-aqiao, Oct 21, 2025)
- a79107f f-string (sfc-gh-aqiao, Oct 21, 2025)
- 2d8e05d Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 22, 2025)
- e1d62bc type (sfc-gh-aqiao, Oct 22, 2025)
- 1d64189 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 24, 2025)
- 149c907 Merge branch 'main' into suffix-decoding (aurickq, Oct 24, 2025)
- d85dc3f fix import (sfc-gh-aqiao, Oct 27, 2025)
- 2f5b451 Merge branch 'suffix-decoding' of https://github.com/aurickq/vllm int… (sfc-gh-aqiao, Oct 27, 2025)
- 387fe61 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 27, 2025)
- aae2b61 config (sfc-gh-aqiao, Oct 27, 2025)
- 7ade5e7 Merge branch 'main' into suffix-decoding (aurickq, Oct 28, 2025)
- ae0beb3 Trigger CI (sfc-gh-aqiao, Oct 28, 2025)
- ba82677 Merge branch 'main' into suffix-decoding (aurickq, Oct 28, 2025)
- cfcfcde Merge branch 'main' into suffix-decoding (aurickq, Oct 29, 2025)
- 78798c2 Merge branch 'main' into suffix-decoding (aurickq, Oct 30, 2025)
- d851834 Merge branch 'main' into suffix-decoding (aurickq, Nov 1, 2025)
40 changes: 40 additions & 0 deletions docs/features/spec_decode.md
@@ -130,6 +130,46 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

## Speculating using Suffix Decoding

The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).

Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates.

Suffix Decoding can achieve better performance for tasks with high repetition, such as code-editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.

!!! tip "Install Arctic Inference"
Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.

!!! tip "Suffix Decoding Speculative Tokens"
Suffix Decoding will speculate a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` configuration specifies the *maximum* number of speculative tokens. It is recommended to use a high value such as `16` or `32` (the default).

??? code

```python
from vllm import LLM, SamplingParams

prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 32,
},
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
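
Suffix Decoding also exposes a few tuning knobs through `speculative_config`. The example below is an illustrative sketch that simply spells out the defaults defined in `SpeculativeConfig`; it is not a tuning recommendation.

??? code

    ```python
    from vllm import LLM

    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_config={
            "method": "suffix",
            "num_speculative_tokens": 32,
            # Maximum depth of the global and per-prompt suffix trees.
            "suffix_decoding_max_tree_depth": 24,
            # Responses cached in the global tree (0 disables the global cache).
            "suffix_decoding_max_cached_requests": 10000,
            # max_spec_tokens = max_spec_factor * prefix_match_length.
            "suffix_decoding_max_spec_factor": 1.0,
            # Only speculate tokens whose estimated probability is >= this value.
            "suffix_decoding_min_token_prob": 0.1,
        },
    )
    ```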

## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
1 change: 1 addition & 0 deletions requirements/test.in
@@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9
genai_perf==0.0.8
tritonclient==2.51.0

arctic-inference == 0.1.0 # Required for suffix decoding test
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.14.0
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -40,6 +40,8 @@ anyio==4.6.2.post1
# via
# httpx
# starlette
arctic-inference==0.1.0
# via -r requirements/test.in
argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
85 changes: 78 additions & 7 deletions tests/v1/e2e/test_spec_decode.py
@@ -75,7 +75,23 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def test_ngram_correctness(
@pytest.mark.parametrize(
"speculative_config",
[
{
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
{
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
},
],
)
def test_ngram_and_suffix_correctness(
speculative_config: dict,
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
@@ -94,12 +110,7 @@ def test_ngram_correctness(

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
speculative_config=speculative_config,
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -121,6 +132,66 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()


def test_suffix_decoding_acceptance(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
):
"""
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
"""
test_prompts = get_test_prompts(mm_enabled=False)

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
},
max_model_len=1024,
disable_log_stats=False,
)

# Run several times and check that the accepted tokens increase.
spec_llm.chat(test_prompts, sampling_config)
num_draft = []
num_accept = []
for i in range(10): # Run multiple times to warm up the cache.
spec_llm.chat(test_prompts, sampling_config)
# Collect draft and acceptance stats.
metrics = spec_llm.get_metrics()
for metric in metrics:
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)

# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens

# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens

# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens

# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate

# Heuristic: expect at least 85% acceptance rate at the end.
assert last_accept_rate > 0.85

del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
62 changes: 61 additions & 1 deletion vllm/config/speculative.py
@@ -13,6 +13,7 @@
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import has_arctic_inference
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
@@ -43,6 +44,7 @@
"mimo_mtp",
"longcat_flash_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
"deepseek_mtp",
@@ -140,6 +142,27 @@ class SpeculativeConfig:
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

# Suffix decoding configuration
suffix_decoding_max_tree_depth: int = 24
"""The maximum depth of the suffix decoding global and prompt trees. The
tree depth limits the sum of the prefix match and speculation lengths."""

suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree. If
exceeded, will trigger eviction in FIFO order. If set to 0, the global
suffix tree is disabled and past responses are not cached (prompt trees
are still used)."""
Comment on lines +138 to +142

Collaborator: Do you mind exploring the memory usage of the suffix decoding cache? Just curious whether there are other mechanisms to limit the cache memory usage in the case that all requests have long context; in that case, the suffix tree corresponding to 10000 requests could be quite large.

Contributor Author: Just a small clarification that only the responses are stored in the global cache, not prompts. Suffix decoding builds one tree for each prompt, which is only used while that request is alive and is deleted after the request finishes.

For other ways to limit the global cache size, perhaps a limit based on the number of tokens might be better? Open to suggestions.

Collaborator:

> Just a small clarification that only the responses are stored in the global cache, not prompts. Suffix decoding builds one tree for each prompt, which is only used while that request is alive and is deleted after the request finishes.

Yeah, makes sense. However, there would still be cases with long outputs (e.g. reasoning models), so I think adding some token-level guarding might still be helpful, and we could do it in a relatively simple way:

  1. max_cached_recent_requests
  2. max_tokens or max_tree_nodes to cache (whenever the value is exceeded, evict the earliest cached requests to cut down the tree)

WDYT? This could be a follow-up enhancement, but it might be worth flagging that the memory usage could be linear in max(total tokens across <max_cached_recent_requests> requests), IIUC.

Contributor Author: Agreed, having an additional config to limit the number of tokens makes sense. Created snowflakedb/ArcticInference#215 to track.

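For illustration, a minimal sketch of the token-count guard floated in this thread. The class below and `max_cached_tokens` are hypothetical and not part of this PR or the Arctic Inference API; only `evict_cached_response` corresponds to a method actually used in this PR.

```python
from collections import deque


class TokenBudgetTracker:
    """Hypothetical FIFO guard: evict the oldest cached responses once the
    total number of cached tokens exceeds a budget."""

    def __init__(self, max_cached_tokens: int):
        self.max_cached_tokens = max_cached_tokens
        self.total_tokens = 0
        self.order: deque[tuple[str, int]] = deque()  # (req_id, num_tokens)

    def add(self, req_id: str, num_tokens: int) -> list[str]:
        """Record a newly cached response; return req_ids that should be evicted."""
        self.order.append((req_id, num_tokens))
        self.total_tokens += num_tokens
        evicted = []
        while self.total_tokens > self.max_cached_tokens and self.order:
            old_id, old_tokens = self.order.popleft()
            self.total_tokens -= old_tokens
            evicted.append(old_id)
        return evicted


# The caller would pass each evicted req_id to suffix_cache.evict_cached_response(req_id).
```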

suffix_decoding_max_spec_factor: float = 1.0
"""The maximum spec factor for suffix decoding. The spec factor controls
speculation lengths based on the prefix match length: max_spec_tokens =
max_spec_factor * prefix_match_length."""

suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding. Will only speculate
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -246,6 +269,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -293,6 +318,41 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
if not has_arctic_inference():
raise ImportError(
"Arctic Inference is required for suffix decoding. "
"Install via `pip install arctic-inference==0.1.0`."
)
if self.num_speculative_tokens is None:
# Suffix decoding decides the actual number of speculative tokens
# dynamically and treats num_speculative_tokens as a maximum limit.
self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
logger.warning(
"Defaulted num_speculative_tokens to %s for suffix decoding.",
self.num_speculative_tokens,
)
# Validate values
if self.suffix_decoding_max_tree_depth < 1:
raise ValueError(
f"suffix_decoding_max_tree_depth="
f"{self.suffix_decoding_max_tree_depth} must be >= 1"
)
if self.suffix_decoding_max_cached_requests < 0:
raise ValueError(
f"suffix_decoding_max_cached_requests="
f"{self.suffix_decoding_max_cached_requests} must be >= 0"
)
if self.suffix_decoding_max_spec_factor < 0:
raise ValueError(
f"suffix_decoding_max_spec_factor="
f"{self.suffix_decoding_max_spec_factor} must be >= 0"
)
if not 0 <= self.suffix_decoding_min_token_prob <= 1:
raise ValueError(
f"suffix_decoding_min_token_prob="
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
)
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -599,6 +659,6 @@ def use_eagle(self) -> bool:

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
101 changes: 101 additions & 0 deletions vllm/v1/spec_decode/suffix_decoding.py
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch


class SuffixDecodingProposer:
"""
Speculative decoding proposer for Suffix Decoding (https://arxiv.org/pdf/2411.04975).
This class imports and uses the official implementation from Arctic Inference
(https://github.com/snowflakedb/ArcticInference).
"""

def __init__(self, vllm_config: VllmConfig):
config = vllm_config.speculative_config
self.num_speculative_tokens = config.num_speculative_tokens
self.max_tree_depth = config.suffix_decoding_max_tree_depth
self.max_spec_factor = config.suffix_decoding_max_spec_factor
self.min_token_prob = config.suffix_decoding_min_token_prob
self.max_model_len = vllm_config.model_config.max_model_len

# Lazy import to avoid error when Suffix Decoding is not used.
from arctic_inference.suffix_decoding import SuffixDecodingCache

# Initialize an empty cache. This object takes care of caching request
# outputs, evicting old requests, and managing the per-prompt suffix trees.
self.suffix_cache = SuffixDecodingCache(
max_tree_depth=config.suffix_decoding_max_tree_depth,
max_cached_requests=config.suffix_decoding_max_cached_requests,
)

def propose(
self,
input_batch: InputBatch,
sampled_token_ids: list[list[int]],
) -> list[list[int]]:
"""
Propose speculative tokens for each request in the input batch. Suffix Decoding
will speculate a dynamic number of tokens for each request every decoding step,
so each entry in the returned list may have different lengths.
"""
draft_token_ids: list[list[int]] = []
for i, sampled_ids in enumerate(sampled_token_ids):
if not sampled_ids:
# Skip speculative decoding for partial prefills.
draft_token_ids.append([])
continue

# Skip requests that require sampling parameters that are not
# supported with speculative decoding.
req_id = input_batch.req_ids[i]
if req_id in input_batch.spec_decode_unsupported_reqs:
draft_token_ids.append([])
continue

num_tokens = input_batch.num_tokens_no_spec[i]
if num_tokens >= self.max_model_len:
# Skip requests that have already reached the max model length.
draft_token_ids.append([])
continue
Comment on lines +42 to +60

Collaborator: Hmm, I might have forgotten to flush one of my previous comments. It seems there's quite some code duplicated here from NgramProposer. I'm wondering if we should come up with some ModelFreeProposer class and put the common logic there.

Ideally, that would make future extensions easier as well.

Contributor Author: It seems that currently it's just the three continue statements that overlap, which is a pretty small part of the NgramProposer.

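For illustration, a rough sketch of the suggested ModelFreeProposer, assuming it only factors out the three skip checks discussed above; the class and method names are hypothetical and not part of this PR.

```python
from abc import ABC, abstractmethod


class ModelFreeProposer(ABC):
    """Hypothetical base class sharing the skip logic of n-gram and suffix proposers."""

    def __init__(self, max_model_len: int):
        self.max_model_len = max_model_len

    def propose(self, input_batch, sampled_token_ids: list[list[int]]) -> list[list[int]]:
        drafts: list[list[int]] = []
        for i, sampled_ids in enumerate(sampled_token_ids):
            req_id = input_batch.req_ids[i]
            if (
                not sampled_ids  # partial prefill
                or req_id in input_batch.spec_decode_unsupported_reqs
                or input_batch.num_tokens_no_spec[i] >= self.max_model_len
            ):
                drafts.append([])
                continue
            drafts.append(self._propose_for_request(i, req_id, input_batch, sampled_ids))
        return drafts

    @abstractmethod
    def _propose_for_request(self, index, req_id, input_batch, sampled_ids) -> list[int]:
        """Subclasses (n-gram, suffix) implement the actual speculation."""
```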

index = input_batch.req_id_to_index[req_id]
if req_id not in self.suffix_cache.active_requests:
if req_id in self.suffix_cache.cached_requests:
# Reset the suffix cache for this request.
self.suffix_cache.evict_cached_response(req_id)
num_prompt_tokens = input_batch.num_prompt_tokens[index]
prompt_token_ids = input_batch.token_ids_cpu[index, :num_prompt_tokens]
# Start a new request; this will build the suffix tree for that prompt.
self.suffix_cache.start_request(req_id, prompt_token_ids)

# Append the newly sampled ids to the suffix cache for this request.
self.suffix_cache.add_active_response(req_id, sampled_ids)

# Suffix decoding only uses the most recent tokens up to max_tree_depth, so
# we extract the pattern from the end of the input.
start = max(0, num_tokens - self.max_tree_depth)
pattern = input_batch.token_ids_cpu[i, start:num_tokens]
Collaborator: n00b question: IIUC, self.max_tree_depth serves a role similar to max_ngram in ngram spec decoding, right? And we will use all suffixes of the pattern from length=1 to length=max_tree_depth to try to find the most frequently matched subtree in suffix_cache, right?

Contributor Author: In suffix decoding, max_tree_depth is the upper limit for match_length + spec_length, and yeah, it will try all suffixes in that range.

Reviewer:

> In suffix decoding, max_tree_depth is the upper limit for match_length + spec_length, and yeah, it will try all suffixes in that range.

When the pattern string is very short (1-2 tokens), is the probability of getting the correct predicted token lower? Is it necessary to set a parameter similar to prompt_lookup_min?

Contributor Author: Yeah, the probability will be lower, but suffix decoding will avoid making that speculation if it's too low, controlled by the suffix_decoding_min_token_prob parameter. Basically, it will look at the frequency of all the next tokens in the suffix tree and will stop speculating if the most frequent token occurs less than suffix_decoding_min_token_prob of the time.

There is an underlying suffix_decoding_max_spec_offset parameter that performs a similar role to prompt_lookup_min, but we haven't found any cases where setting it actually helped beyond that existing mechanism, so we did not expose it.

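To make the interplay of these limits concrete, a small illustrative calculation (all numbers are made up; it only restates the config docstrings and the speculate() call below):

```python
# Illustrative values only.
num_speculative_tokens = 32   # configured maximum
max_model_len = 1024
num_tokens = 900              # tokens already in this request (prompt + output)
max_spec_factor = 2.0
prefix_match_length = 8       # trailing tokens matched in the suffix tree

# Hard cap passed to speculate(): never exceed the configured maximum or the
# remaining room in the context window.
max_spec_tokens = min(num_speculative_tokens, max_model_len - num_tokens - 1)  # min(32, 123) = 32

# Adaptive cap from the spec factor: longer matches earn longer drafts.
adaptive_cap = max_spec_factor * prefix_match_length  # 16.0

# So at most min(32, 16) = 16 tokens are drafted here, and speculation stops
# earlier if the estimated next-token probability drops below
# suffix_decoding_min_token_prob.
```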
draft = self.suffix_cache.speculate(
req_id,
pattern,
max_spec_tokens=min(
self.num_speculative_tokens, self.max_model_len - num_tokens - 1
),
max_spec_factor=self.max_spec_factor,
min_token_prob=self.min_token_prob,
)

draft_token_ids.append(draft.token_ids)

# Stop requests that were not seen in the input batch.
for req_id in (
self.suffix_cache.active_requests - input_batch.req_id_to_index.keys()
):
self.suffix_cache.stop_request(req_id)

return draft_token_ids

def load_model(self, *args, **kwargs):
# No model to load.
pass