
[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling#38045

Merged
WoosukKwon merged 1 commit into vllm-project:main from TheEpicDolphin:gdelfin/mrv2-forced-acceptance-rate
Mar 26, 2026

Conversation

@TheEpicDolphin
Collaborator

@TheEpicDolphin TheEpicDolphin commented Mar 24, 2026

Purpose

Enable testing/debugging with a fixed, expected acceptance rate. To easily support this in MRV2, I added a rejection sampling method (synthetic) to the speculative config, along with a field for setting the desired mean acceptance rate (synthetic_acceptance_rate). That field is only used if synthetic rejection sampling is selected.

In practice, with speculative decoding using more than one draft token, we see a drop-off in acceptance after each position. To simulate this drop-off, the base acceptance rate (at position 0) and a decay factor are computed from the desired mean acceptance rate. These yield the per-position conditional acceptance rates used for accepting/rejecting draft tokens, which adds realism to the results. The decay factor (minimum value of 0.85, which is generally what is seen in practice for well-tuned draft models) has to be selected carefully to ensure that the base acceptance rate does not exceed 1, which would prevent achieving the desired mean acceptance rate. That is what compute_base_acceptance_rate and min_valid_decay_factor do.
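The computation described above can be sketched as follows. This is a hypothetical reconstruction, not the PR's actual implementation: it assumes the conditional acceptance rate at draft position i decays geometrically as base * decay**i, and bisects for the base rate whose mean joint acceptance matches the target.

```python
def mean_joint_acceptance(base: float, decay: float, k: int) -> float:
    # Joint acceptance at position i is the product of the conditional
    # rates base * decay**j for all j <= i; return the mean over k positions.
    joint, total = 1.0, 0.0
    for i in range(k):
        joint *= base * decay**i
        total += joint
    return total / k


def compute_base_acceptance_rate(target_mean: float, decay: float, k: int) -> float:
    # Bisect for the position-0 rate whose mean joint acceptance hits the
    # target; mean_joint_acceptance is monotonically increasing in base.
    lo, hi = 0.0, 1.0
    for _ in range(60):  # enough iterations for double precision
        mid = (lo + hi) / 2.0
        if mean_joint_acceptance(mid, decay, k) < target_mean:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2.0
```

Note that with decay < 1, even base = 1 cannot reach a mean of 1.0 for k > 1, which illustrates why the decay factor itself must be constrained so that the solved base rate stays at or below 1 (the role the PR description assigns to min_valid_decay_factor).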

Test Plan

Added a test to ensure that the base acceptance rate and decay factor are computed correctly, and that the average of the joint probabilities in fact equals the desired value.

(vllm) gdelfin@h200-1:~/vllm-main$ pytest tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py -v
================================================== test session starts ==================================================
platform linux -- Python 3.12.12, pytest-9.0.2, pluggy-1.6.0 -- /home/gdelfin/vllm/.venv/bin/python3
cachedir: .pytest_cache
rootdir: /home/gdelfin/vllm-main
configfile: pyproject.toml
plugins: anyio-4.12.1
collected 7 items                                                                                                       

tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[1] PASSED [ 14%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[2] PASSED [ 28%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[3] PASSED [ 42%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[4] PASSED [ 57%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[5] PASSED [ 71%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[7] PASSED [ 85%]
tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py::test_compute_synthetic_rejection_sampler_params[10] PASSED [100%]

=================================================== 7 passed in 1.09s ===================================================

Additionally, I manually tested using speculative decoding with Llama 3:

Server

VLLM_USE_V2_MODEL_RUNNER=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct --no-enable-prefix-caching --tensor-parallel-size 1 --data-parallel-size 1 --speculative-config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 3, "rejection_sample_method": "synthetic", "synthetic_acceptance_rate": <test-value>}'

Benchmark

vllm bench serve --model meta-llama/Meta-Llama-3-8B-Instruct --tokenizer meta-llama/Meta-Llama-3-8B-Instruct --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos --request-rate inf --max-concurrency 16 --temperature 0

Results

Speculative Decoding Benchmark Results

| Metric | Target 0.24 | Target 0.69 | Target 1.0 |
|---|---|---|---|
| Acceptance rate (%) | 23.98 | 69.06 | 100.00 |
| Acceptance length | 1.72 | 3.07 | 4.00 |
| Drafts | 148,650 | 83,405 | 64,000 |
| Draft tokens | 445,950 | 250,215 | 192,000 |
| Accepted tokens | 106,952 | 172,806 | 192,000 |
| **Per-position acceptance (%)** | | | |
| Position 0 | 46.92 | 90.86 | 100.00 |
| Position 1 | 18.62 | 70.20 | 100.00 |
| Position 2 | 6.41 | 46.13 | 100.00 |
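As a sanity check on the results above (plain arithmetic over the table's numbers, not code from the PR): the reported overall acceptance rate is the mean of the per-position acceptance rates, and the acceptance length is one (the bonus token) plus their sum.

```python
# Per-position acceptance rates (%) from the benchmark results table.
results = {
    "0.24": [46.92, 18.62, 6.41],
    "0.69": [90.86, 70.20, 46.13],
    "1.0": [100.00, 100.00, 100.00],
}
for target, rates in results.items():
    mean_rate = sum(rates) / len(rates)               # overall acceptance rate (%)
    accept_len = 1.0 + sum(r / 100.0 for r in rates)  # bonus token + accepted drafts
    print(f"target {target}: rate={mean_rate:.2f}%, length={accept_len:.2f}")
```

For target 0.24 this reproduces 23.98% and 1.72; for target 0.69, 69.06% and 3.07; and for target 1.0, 100.00% and 4.00, matching the table.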

@mergify mergify bot added the v1 label Mar 24, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new "synthetic" rejection sampling method for speculative decoding, including its configuration and implementation. The RejectionSampler class has been updated to support this new method, alongside existing "strict" and "probabilistic" methods. A review comment suggests replacing an assert statement with a ValueError for validating the synthetic_acceptance_rate to ensure robust error handling.

@TheEpicDolphin TheEpicDolphin changed the title [Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling [WIP][Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling Mar 24, 2026
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch 3 times, most recently from b91ef17 to e91b99b on March 25, 2026 01:37
@TheEpicDolphin TheEpicDolphin changed the title [WIP][Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling [Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling Mar 25, 2026
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch from e91b99b to b918c9e on March 25, 2026 02:14
@TheEpicDolphin TheEpicDolphin marked this pull request as ready for review March 25, 2026 02:15
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch from b918c9e to 5873fdc on March 25, 2026 03:39
@zhewenl
Collaborator

zhewenl commented Mar 25, 2026

@claude review

Collaborator

@benchislett benchislett left a comment


  • Way too much complexity here for an optional feature.
  • I don't like the idea of skipping the main rejection sampling and replacing it with a synthetic kernel. If we are using synthetic sampling to approximate a workload, I would prefer to still call the main rejection sampler and then override the outputs, so that we never underestimate the cost of the sampling itself.

@WoosukKwon
Collaborator

WoosukKwon commented Mar 25, 2026

@benchislett Thanks for sharing your thoughts.

> Way too much complexity here for an optional feature.

I don't really agree. I think 99% of the "complexity" in this PR can be gone if we relocate the code for synthetic rejection sampling into a separate file. After that, one can clearly tell that the synthetic rejection sampling is not entangled with other options.

> I don't like the idea of skipping the main rejection sampling and replacing it with a synthetic kernel. If we are using synthetic sampling to approximate a workload, I would prefer to still call the main rejection sampler and then override the outputs so that we never underestimate the cost of the sampling itself.

I think the overhead matches that of strict rejection sampling. Probabilistic rejection sampling will be more expensive, but we probably don't need to model it precisely.

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch from 5873fdc to f94cfcf on March 25, 2026 23:21
@TheEpicDolphin
Collaborator Author

TheEpicDolphin commented Mar 25, 2026

Addressed the following:

  • Moved the synthetic rejection sampling functions into a separate file to reduce cognitive burden.
  • Set default synthetic_acceptance_rate to None, and raise exception if not set.

And echoing what Woosuk said, the overhead here should closely approximate that of the strict rejection sampling kernel, which is the main one we are focused on approximating.

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
Comment on lines +460 to +467
if (
synthetic_acceptance_rate is None
or not 0.0 <= synthetic_acceptance_rate <= 1.0
):
raise ValueError(
f"synthetic_acceptance_rate must be in [0, 1], "
f"but got {synthetic_acceptance_rate}"
)
Collaborator


nit: Maybe we can move this logic into compute_synthetic_rejection_sampler_params? This might look a bit cleaner

Collaborator Author


I'm tempted to, but we'd still want to raise a ValueError here if synthetic_acceptance_rate is None (because it would be weird for compute_synthetic_rejection_sampler_params to accept a nullable input). So all we'd gain in terms of cleanliness is losing the `not 0.0 <= synthetic_acceptance_rate <= 1.0` condition here.

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch from f94cfcf to b200a1c on March 26, 2026 00:22
@mergify mergify bot added the ci/build label Mar 26, 2026
@TheEpicDolphin
Collaborator Author

@WoosukKwon added test_synthetic_rejection_sampler_utils to the MRV2 spec decoding group, looks like it's running now as part of checks 👍

Collaborator

@WoosukKwon WoosukKwon left a comment


@benchislett The PR looks good to me. Can you take another look?

return sampled, num_sampled


def compute_synthetic_rejection_sampler_params(
Member


Would this be too complicated? I was thinking we could just impose a position-independent acceptance rate, like accepting every token with a constant ratio (e.g. 90%).

Collaborator Author

@TheEpicDolphin TheEpicDolphin Mar 26, 2026


So you're suggesting that the exposed parameter determine the fixed conditional acceptance rate at each draft position.

In this PR, the parameter determines the average of the joint per-position acceptance rates (the average acceptance rate you see in the benchmark metrics). That is the metric I generally find most useful and intuitive, but I'm open to hearing other people's thoughts.
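To illustrate the distinction being discussed here (a toy example, not code from the PR): with a constant conditional acceptance rate p at every position, the joint acceptance at position i is p**(i+1), so the mean joint rate reported by the benchmark comes out noticeably lower than p itself.

```python
p, k = 0.9, 3  # constant conditional rate, number of draft positions
joint = [p ** (i + 1) for i in range(k)]  # joint acceptance per position
mean_joint = sum(joint) / k
print(joint)       # approximately [0.9, 0.81, 0.729]
print(mean_joint)  # ~0.813, well below the 0.9 conditional rate
```

So exposing the conditional rate versus the mean joint rate yields meaningfully different benchmark numbers for the same parameter value.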

@WoosukKwon
Collaborator

@TheEpicDolphin Can you please rebase?

…jection sampling

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-forced-acceptance-rate branch from b200a1c to 860c81e on March 26, 2026 18:54
@WoosukKwon WoosukKwon merged commit c32e976 into vllm-project:main Mar 26, 2026
67 checks passed
@TheEpicDolphin TheEpicDolphin deleted the gdelfin/mrv2-forced-acceptance-rate branch March 26, 2026 21:25
malaiwah pushed a commit to malaiwah/vllm that referenced this pull request Mar 27, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Michel Belleau <michel.belleau@malaiwah.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
SandishKumarHN pushed a commit to SandishKumarHN/vllm that referenced this pull request Mar 27, 2026
…hold to 0.91

The LM Eval Large Models (H200) CI job was failing because the
NVIDIA-Nemotron-3-Super-120B-A12B-BF16 model scored slightly below the
0.93 accuracy threshold on GSM8K.

The model uses MTP speculative decoding with 5 speculative tokens. Recent
changes to the Model Runner V2 spec decode path (PRs vllm-project#38045 and vllm-project#38311)
adjusted rejection sampling behavior and rebuilt attention metadata before
eagle decode, which can marginally affect the acceptance rate and therefore
the final accuracy score.

Lower the threshold from 0.93 to 0.91 to reflect the current achievable
accuracy with the updated spec decode implementation. The model still
demonstrates strong GSM8K performance above 91%.

Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
bhargav-patel-29 pushed a commit to Bharatgen-Tech/vllm that referenced this pull request Apr 1, 2026
…jection sampling (vllm-project#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: bhargav-patel-29 <bhargav.patel@tihiitb.org>

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1


5 participants