[Model Runner V2] Spec decode rejection sampler greedy support #37238

Merged
WoosukKwon merged 1 commit into vllm-project:main from TheEpicDolphin:gdelfin/mrv2-spec-decode-rejection-sample-greedy
Mar 18, 2026

Conversation

@TheEpicDolphin (Collaborator) commented Mar 16, 2026

Purpose

Following up on #35461, this PR adds support for greedy sampling (temperature = 0).

To support this efficiently, the per-position argmax/max of the target logits is computed for greedy requests inside _gather_draft_logits_and_target_argmax_kernel. Then, during _probabilistic_rejection_kernel, the target argmax token is sampled only for the greedy requests. This limits the performance impact of greedy requests on the rest of the batch.
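A minimal NumPy sketch of the greedy path described above, for a single request. This is illustrative only: the function and variable names are mine, not the PR's, and the real implementation performs this batched inside fused Triton kernels rather than in a Python loop.

```python
import numpy as np

def greedy_rejection_sample(draft_tokens, target_logits):
    """Greedy (temperature=0) rejection sampling for one request (sketch).

    draft_tokens: proposed token ids from the draft model, length k.
    target_logits: [k + 1, vocab] target-model logits, one row per draft
        position plus one bonus position.
    """
    # Per-position target argmax -- in the PR this is the local argmax
    # gathered by _gather_draft_logits_and_target_argmax_kernel.
    target_argmax = target_logits.argmax(axis=-1)
    accepted = []
    for i, tok in enumerate(draft_tokens):
        if tok == target_argmax[i]:
            # Draft token equals the target argmax: accept and continue.
            accepted.append(int(tok))
        else:
            # First mismatch: emit the target argmax instead and stop.
            accepted.append(int(target_argmax[i]))
            return accepted
    # Every draft token accepted: append the bonus token from the extra row.
    accepted.append(int(target_argmax[len(draft_tokens)]))
    return accepted
```

Under greedy decoding the rejection test degenerates to an exact match against the target argmax, which is why only the argmax (not the full target distribution) needs to be materialized for these requests.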

Benchmarks

Server

VLLM_USE_V2_MODEL_RUNNER=<version> vllm serve openai/gpt-oss-20b --no-enable-prefix-caching --tensor-parallel-size=1 --data-parallel-size=1 --speculative-config '{"method": "eagle3", "model": "RedHatAI/gpt-oss-20b-speculator.eagle3", "num_speculative_tokens": 3, "rejection_sample_method": "<method>"}'

Client

vllm bench serve --model openai/gpt-oss-20b --tokenizer openai/gpt-oss-20b --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos --request-rate inf --max-concurrency 16 --temperature <temperature>

Base model: openai/gpt-oss-20b
Draft model: RedHatAI/gpt-oss-20b-speculator.eagle3
Benchmark: mt-bench · 1000 requests · max concurrency 16 · 256K total generated tokens

Benchmark Comparison

| Metric | MRV1 | MRV2 (strict) | MRV2 (probabilistic) |
|---|---|---|---|
| **General** | | | |
| Successful requests | 1000 | 1000 | 1000 |
| Failed requests | 0 | 0 | 0 |
| Benchmark duration (s) | 73.41 | 64.10 | 54.06 |
| Request throughput (req/s) | 13.62 | 15.60 | 18.50 |
| Output token throughput (tok/s) | 3487.49 | 3994.00 | 4735.59 |
| Peak output token throughput (tok/s) | 1780.00 | 1830.00 | 1760.00 |
| Total token throughput (tok/s) | 5341.86 | 6117.68 | 7253.59 |
| Peak concurrent requests | 33 | 34 | 42 |
| **Time to First Token** | | | |
| Mean TTFT (ms) | 34.76 | 32.42 | 33.90 |
| Median TTFT (ms) | 29.36 | 28.29 | 29.14 |
| P99 TTFT (ms) | 337.20 | 256.14 | 283.18 |
| **Time per Output Token (excl. 1st token)** | | | |
| Mean TPOT (ms) | 4.43 | 3.87 | 3.24 |
| Median TPOT (ms) | 4.47 | 3.87 | 3.22 |
| P99 TPOT (ms) | 5.82 | 4.94 | 4.10 |
| **Inter-token Latency** | | | |
| Mean ITL (ms) | 9.07 | 8.78 | 9.08 |
| Median ITL (ms) | 8.77 | 8.45 | 8.69 |
| P99 ITL (ms) | 12.68 | 13.46 | 13.98 |
| **Speculative Decoding** | | | |
| Acceptance rate (%) | 35.17 | 42.60 | 60.66 |
| Acceptance length | 2.06 | 2.28 | 2.82 |
| Drafts | 124,457 | 112,343 | 90,870 |
| Draft tokens | 373,371 | 337,029 | 272,610 |
| Accepted tokens | 131,313 | 143,575 | 165,364 |
| **Per-position acceptance (%)** | | | |
| Position 0 | 54.45 | 63.48 | 66.63 |
| Position 1 | 32.14 | 41.09 | 62.25 |
| Position 2 | 18.91 | 23.23 | 53.09 |

@mergify mergify bot added the v1 label Mar 16, 2026
@TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch from d0bbcc1 to f6618c3 on March 16, 2026 23:17
@TheEpicDolphin (Collaborator, Author) commented on:

    # [num_reqs]
    rejected_pos = pos.new_empty(num_reqs)
    rejected_steps = sampled.new_empty(num_reqs)
    _probabilistic_rejection_sample_kernel[(num_reqs,)](

I felt it made more sense to compute this in _probabilistic_rejection_kernel rather than _compute_residual_logits_kernel, so I moved it here. I also renamed it from residual_pos to rejected_pos.

mergify bot commented Mar 16, 2026

Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request adds support for greedy sampling (temperature=0) in the speculative decoding rejection sampler. The changes are well-structured, introducing new Triton kernels to handle greedy and probabilistic paths efficiently. The logic for rejection sampling and resampling in the greedy case is sound. I've found one potential issue in a newly added but currently unused kernel that should be addressed.

@TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch 3 times, most recently from 2c46096 to f188893 on March 17, 2026 00:04
@WoosukKwon (Collaborator) left a comment


Thanks for the PR!

I think we can fuse more kernels to minimize the materialization of *_logits tensors, but we can probably follow up after this.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 18, 2026
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch from f188893 to 47f633e on March 18, 2026 20:32
@WoosukKwon WoosukKwon enabled auto-merge (squash) March 18, 2026 21:04
@WoosukKwon WoosukKwon merged commit 04244fd into vllm-project:main Mar 18, 2026
60 checks passed
@TheEpicDolphin TheEpicDolphin deleted the gdelfin/mrv2-spec-decode-rejection-sample-greedy branch March 18, 2026 23:02
ikaadil pushed a commit to ikaadil/vllm that referenced this pull request Mar 19, 2026
…project#37238)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…project#37238)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…project#37238)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1
