Merged

42 commits
2d4180c Integrate Suffix Decoding (sfc-gh-aqiao, Sep 24, 2025)
db621b4 update (sfc-gh-aqiao, Sep 26, 2025)
a15727d add tests (sfc-gh-aqiao, Sep 29, 2025)
50d6fd3 fix (sfc-gh-aqiao, Sep 29, 2025)
92fe31e docs (sfc-gh-aqiao, Sep 30, 2025)
6755a31 move import check (sfc-gh-aqiao, Sep 30, 2025)
6acc529 Merge branch 'main' into HEAD (sfc-gh-aqiao, Sep 30, 2025)
72c85e6 fix (sfc-gh-aqiao, Sep 30, 2025)
da82b85 SuffixDecodingProposer (sfc-gh-aqiao, Sep 30, 2025)
0751844 fix (sfc-gh-aqiao, Sep 30, 2025)
9692d9d docstring (sfc-gh-aqiao, Sep 30, 2025)
4536270 precommit (sfc-gh-aqiao, Sep 30, 2025)
6a24831 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Sep 30, 2025)
6342db2 minor (sfc-gh-aqiao, Sep 30, 2025)
ecf3efd Update speculative.py (aurickq, Oct 2, 2025)
062bafe Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 7, 2025)
1e9ac44 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 7, 2025)
9774fb3 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 9, 2025)
213276f format (sfc-gh-aqiao, Oct 9, 2025)
4724b25 format (sfc-gh-aqiao, Oct 9, 2025)
dba74a6 doc (sfc-gh-aqiao, Oct 9, 2025)
0f45554 update (sfc-gh-aqiao, Oct 9, 2025)
af6a237 merge update with propose (sfc-gh-aqiao, Oct 19, 2025)
e94166b Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 21, 2025)
802d291 update (sfc-gh-aqiao, Oct 21, 2025)
7050bc4 format (sfc-gh-aqiao, Oct 21, 2025)
1f3d326 version (sfc-gh-aqiao, Oct 21, 2025)
a79107f f-string (sfc-gh-aqiao, Oct 21, 2025)
2d8e05d Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 22, 2025)
e1d62bc type (sfc-gh-aqiao, Oct 22, 2025)
1d64189 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 24, 2025)
149c907 Merge branch 'main' into suffix-decoding (aurickq, Oct 24, 2025)
d85dc3f fix import (sfc-gh-aqiao, Oct 27, 2025)
2f5b451 Merge branch 'suffix-decoding' of https://github.com/aurickq/vllm int… (sfc-gh-aqiao, Oct 27, 2025)
387fe61 Merge branch 'main' into suffix-decoding (sfc-gh-aqiao, Oct 27, 2025)
aae2b61 config (sfc-gh-aqiao, Oct 27, 2025)
7ade5e7 Merge branch 'main' into suffix-decoding (aurickq, Oct 28, 2025)
ae0beb3 Trigger CI (sfc-gh-aqiao, Oct 28, 2025)
ba82677 Merge branch 'main' into suffix-decoding (aurickq, Oct 28, 2025)
cfcfcde Merge branch 'main' into suffix-decoding (aurickq, Oct 29, 2025)
78798c2 Merge branch 'main' into suffix-decoding (aurickq, Oct 30, 2025)
d851834 Merge branch 'main' into suffix-decoding (aurickq, Nov 1, 2025)
40 changes: 40 additions & 0 deletions docs/features/spec_decode.md
@@ -130,6 +130,46 @@ matching n-grams in the prompt. For more information read [this thread.](https:/
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

## Speculating using Suffix Decoding

The following code configures vLLM to use speculative decoding where proposals are generated using Suffix Decoding ([technical report](https://arxiv.org/abs/2411.04975)).

Like n-gram, Suffix Decoding can generate draft tokens by pattern-matching using the last `n` generated tokens. Unlike n-gram, Suffix Decoding (1) can pattern-match against both the prompt and previous generations, (2) uses frequency counts to propose the most likely continuations, and (3) speculates an adaptive number of tokens for each request at each iteration to get better acceptance rates.

Suffix Decoding can achieve better performance on tasks with high repetition, such as code editing, agentic loops (e.g. self-reflection, self-consistency), and RL rollouts.
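
To make this concrete, here is a toy sketch of frequency-based continuation matching. It is illustrative only; Arctic Inference's actual implementation uses per-prompt and global suffix trees and adapts the speculation length per request.

```python
from collections import Counter


def propose_draft(history: list[int], last_n: int = 3, max_spec: int = 4) -> list[int]:
    """Greedily extend the most recent n-token pattern with its most frequent
    observed continuation in the history (prompt + generated tokens)."""
    draft: list[int] = []
    context = list(history)
    for _ in range(max_spec):
        pattern = context[-last_n:]
        # Count every token observed right after this pattern in the history.
        followers = Counter(
            context[i + len(pattern)]
            for i in range(len(context) - len(pattern))
            if context[i : i + len(pattern)] == pattern
        )
        if not followers:
            break
        token = followers.most_common(1)[0][0]
        draft.append(token)
        context.append(token)
    return draft


# The repeated "2 3 4 5" pattern makes [5, 2, 3, 4] a likely continuation:
print(propose_draft([1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4]))
```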

!!! tip "Install Arctic Inference"
Suffix Decoding requires [Arctic Inference](https://github.com/snowflakedb/ArcticInference). You can install it with `pip install arctic-inference`.

!!! tip "Suffix Decoding Speculative Tokens"
Suffix Decoding speculates a dynamic number of tokens for each request at each decoding step, so the `num_speculative_tokens` setting specifies the *maximum* number of speculative tokens. A relatively high value such as `16` or `32` (the default) is recommended.

??? code

```python
from vllm import LLM, SamplingParams

prompts = [
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
model="facebook/opt-6.7b",
tensor_parallel_size=1,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 32,
},
)
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
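
The suffix-specific options added by this PR can also be tuned through `speculative_config`. The sketch below shows them with the default values from `vllm/config/speculative.py` in this change; treat it as illustrative rather than a recommended configuration.

```python
from vllm import LLM

# A sketch of the suffix-specific options added in this PR, shown with their
# default values; any of them can be omitted.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_config={
        "method": "suffix",
        # Maximum speculative tokens per step (suffix decoding adapts the
        # actual number per request, up to this limit).
        "num_speculative_tokens": 32,
        # Tree depth caps the sum of prefix match length and speculation length.
        "suffix_decoding_max_tree_depth": 64,
        # Global response cache size (FIFO eviction); 0 disables the global
        # suffix tree so only per-request prompt trees are used.
        "suffix_decoding_max_cached_requests": 10000,
        # max_spec_tokens = max_spec_factor * prefix_match_length
        "suffix_decoding_max_spec_factor": 1.0,
        # Only speculate tokens whose frequency-estimated probability is at
        # least this value.
        "suffix_decoding_min_token_prob": 0.1,
    },
)
```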

## Speculating using MLP speculators

The following code configures vLLM to use speculative decoding where proposals are generated by
1 change: 1 addition & 0 deletions requirements/test.in
@@ -48,6 +48,7 @@ buildkite-test-collector==0.1.9
genai_perf==0.0.8
tritonclient==2.51.0

arctic-inference == 0.0.9 # Required for suffix decoding test
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.14.0
2 changes: 2 additions & 0 deletions requirements/test.txt
@@ -40,6 +40,8 @@ anyio==4.6.2.post1
# via
# httpx
# starlette
arctic-inference==0.0.9
# via -r requirements/test.in
argcomplete==3.5.1
# via datamodel-code-generator
arrow==1.3.0
85 changes: 78 additions & 7 deletions tests/v1/e2e/test_spec_decode.py
@@ -77,7 +77,23 @@ def model_name():
return "meta-llama/Llama-3.1-8B-Instruct"


def test_ngram_correctness(
@pytest.mark.parametrize(
"speculative_config",
[
{
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
{
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
},
],
)
def test_ngram_and_suffix_correctness(
speculative_config: dict,
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
@@ -96,12 +96,7 @@ def test_ngram_correctness(

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
speculative_config=speculative_config,
max_model_len=1024,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
@@ -123,6 +134,66 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()


def test_suffix_decoding_acceptance(
monkeypatch: pytest.MonkeyPatch,
sampling_config: SamplingParams,
model_name: str,
):
"""
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
"""
test_prompts = get_test_prompts(mm_enabled=False)

spec_llm = LLM(
model=model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
},
max_model_len=1024,
disable_log_stats=False,
)

# Run several times and check that the accepted tokens increase.
spec_llm.chat(test_prompts, sampling_config)
num_draft = []
num_accept = []
for i in range(10): # Run multiple times to warm up the cache.
spec_llm.chat(test_prompts, sampling_config)
# Collect draft and acceptance stats.
metrics = spec_llm.get_metrics()
for metric in metrics:
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)

# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens

# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens

# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens

# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate

# Heuristic: expect at least 85% acceptance rate at the end.
assert last_accept_rate > 0.85

del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()


@pytest.mark.parametrize(
["model_setup", "mm_enabled"],
[
60 changes: 58 additions & 2 deletions vllm/config/speculative.py
@@ -13,7 +13,7 @@
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils import LazyLoader
from vllm.utils import LazyLoader, has_arctic_inference

if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -43,6 +43,7 @@
"mimo_mtp",
"longcat_flash_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
"deepseek_mtp",
@@ -140,6 +141,27 @@ class SpeculativeConfig:
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""

# Suffix decoding configuration
suffix_decoding_max_tree_depth: int = 64
"""The maximum depth of the suffix decoding global and prompt trees. The
tree depth limits the sum of the prefix match and speculation lengths."""

suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree. If
exceeded, will trigger eviction in FIFO order. If set to 0, the global
suffix tree is disabled and past responses are not cached (prompt trees
are still used)."""
Comment on lines +138 to +142

Collaborator: Do you mind exploring the memory usage of the suffix decoding cache? Just curious whether there is another mechanism to limit the cache memory usage in the case that all requests are long-context; the suffix tree corresponding to 10000 requests could be quite large.

Contributor (author): Just a small clarification that only the responses are stored in the global cache, not prompts. Suffix decoding builds one tree per prompt, which is only used while that request is alive and is deleted after the request finishes. For other ways to limit the global cache size, perhaps a limit based on the number of tokens might be better? Open to suggestions.

Collaborator: Yeah, makes sense. However, there would still be cases with long outputs (e.g. reasoning models), so I think adding some token-level guarding might still be helpful, and we could do that in a relatively simple way:

1. max_cached_recent_requests
2. max_tokens or max_tree_nodes to cache (whenever the value is exceeded, evict the earliest cached requests to cut down the tree)

WDYT? This could be a follow-up enhancement, but it might be worth flagging that the memory usage could be linear in the total tokens of the last <max_cached_recent_requests> requests, IIUC.

Contributor (author): Agreed, having an additional config to limit the number of tokens makes sense. Created snowflakedb/ArcticInference#215 to track.
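
For illustration, here is a minimal sketch of the token-budgeted FIFO eviction the reviewers discuss above. None of these names exist in this PR or in Arctic Inference; the real follow-up is tracked in snowflakedb/ArcticInference#215.

```python
from collections import OrderedDict


class GlobalSuffixCache:
    """Hypothetical sketch: FIFO response cache bounded by request and token budgets."""

    def __init__(self, max_cached_requests: int = 10000, max_cached_tokens: int = 1_000_000):
        self.max_cached_requests = max_cached_requests
        self.max_cached_tokens = max_cached_tokens
        # request_id -> response token ids (insertion order = arrival order)
        self._responses: OrderedDict[str, list[int]] = OrderedDict()
        self._total_tokens = 0

    def add_response(self, request_id: str, token_ids: list[int]) -> None:
        # Only responses are cached; prompts live in per-request trees.
        self._responses[request_id] = token_ids
        self._total_tokens += len(token_ids)
        self._evict()

    def _evict(self) -> None:
        # Drop the oldest cached responses until both budgets are satisfied.
        while self._responses and (
            len(self._responses) > self.max_cached_requests
            or self._total_tokens > self.max_cached_tokens
        ):
            _, evicted = self._responses.popitem(last=False)
            self._total_tokens -= len(evicted)
```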


suffix_decoding_max_spec_factor: float = 1.0
"""The maximum spec factor for suffix decoding. The spec factor controls
speculation lengths based on the prefix match length: max_spec_tokens =
max_spec_factor * prefix_match_length."""

suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding. Will only speculate
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -247,6 +269,8 @@ def __post_init__(self):
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -294,6 +318,38 @@ def __post_init__(self):
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
if not has_arctic_inference():
raise ImportError(
"Arctic Inference is required for suffix decoding. "
"Install via `pip install arctic-inference==0.0.9`."
)
if self.num_speculative_tokens is None:
self.num_speculative_tokens = 32
# Validate values
if self.suffix_decoding_max_tree_depth < 1:
raise ValueError(
f"suffix_decoding_max_tree_depth="
f"{self.suffix_decoding_max_tree_depth} must be >= 1"
)
if self.suffix_decoding_max_cached_requests < 0:
raise ValueError(
f"suffix_decoding_max_cached_requests="
f"{self.suffix_decoding_max_cached_requests} must be >= 0"
)
if self.suffix_decoding_max_spec_factor < 0:
raise ValueError(
f"suffix_decoding_max_spec_factor="
f"{self.suffix_decoding_max_spec_factor} must be >= 0"
)
if (
self.suffix_decoding_min_token_prob < 0
or self.suffix_decoding_min_token_prob > 1
):
raise ValueError(
f"suffix_decoding_min_token_prob="
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
)
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -599,6 +655,6 @@ def use_eagle(self) -> bool:

def __repr__(self) -> str:
method = self.method
model = None if method == "ngram" else self.draft_model_config.model
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
6 changes: 6 additions & 0 deletions vllm/utils/__init__.py
@@ -3303,6 +3303,12 @@ def has_tilelang() -> bool:
return _has_module("tilelang")


def has_arctic_inference() -> bool:
    """Whether the optional `arctic_inference` package is available."""

    return _has_module("arctic_inference")

Collaborator: nit: add a functools.cache decorator.

Contributor (author): I think `_has_module` is already cached.


def set_process_title(
name: str, suffix: str = "", prefix: str = envs.VLLM_PROCESS_NAME_PREFIX
) -> None: