
[Model Runner V2] Add probabilistic rejection sampling for spec decoding #35461

Merged
WoosukKwon merged 1 commit into vllm-project:main from TheEpicDolphin:gdelfin/mrv2-spec-decode-rejection-sample
Mar 11, 2026

Conversation

@TheEpicDolphin (Collaborator) commented Feb 27, 2026:

Purpose

Use probabilistic rejection sampling for MRV2 spec decoding to increase draft token acceptance rate.

Currently, strict rejection sampling is used: a draft token is accepted only if it exactly matches the token sampled from the target model. This guarantees that the target model's distribution is respected, but is suboptimal for acceptance rate.

NOTE: I added a SpeculativeConfig parameter that allows setting either strict or probabilistic rejection sampling. By default, strict rejection sampling is used. Probabilistic rejection sampling yields higher acceptance rates, but comes with the tradeoff of more memory usage from caching draft logits.
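For context, the scheme this enables can be sketched as follows. This is a minimal PyTorch sketch of standard speculative-decoding rejection sampling for a single request; the function name and shapes are illustrative, not the PR's Triton implementation (which operates on batched logits), and the final bonus token is omitted for brevity:

```python
import torch

def probabilistic_rejection_sample(draft_probs, target_probs, draft_token_ids,
                                   generator=None):
    """Sketch of probabilistic rejection sampling for a single request.

    draft_probs, target_probs: [num_spec_tokens, vocab_size] probability rows
    (draft_probs[i, t] must be > 0 for each proposed token t).
    draft_token_ids: [num_spec_tokens] tokens proposed by the draft model.
    Returns the accepted prefix of draft tokens, plus a recovery token if a
    draft token was rejected.
    """
    out = []
    for i, tok in enumerate(draft_token_ids.tolist()):
        p = target_probs[i, tok]  # target probability of the draft token
        q = draft_probs[i, tok]   # draft probability of the draft token
        # Accept with probability min(1, p / q), instead of requiring an
        # exact match with the target model's sampled token.
        if torch.rand((), generator=generator) < p / q:
            out.append(tok)
        else:
            # On rejection, resample from the residual distribution
            # max(0, p - q), renormalized. This preserves the target
            # model's output distribution exactly.
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            residual /= residual.sum()
            out.append(torch.multinomial(residual, 1, generator=generator).item())
            break
    return out
```

Note that when draft and target agree perfectly (p = q for the proposed token), every draft token is accepted; strict matching would still reject whenever the two samplers happened to draw different tokens.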

Benchmarks

Server

```shell
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve <target_model> --tensor-parallel-size=1 --data-parallel-size=1 --speculative-config '{"method": <spec_decode_method>, "model": <draft_model>, "num_speculative_tokens": 3, "rejection_sample_method": <method>}'
```

Client

```shell
vllm bench serve --model <target_model> --tokenizer <target_model> --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos --request-rate inf --max-concurrency 16
```

Speculative Decoding: Strict vs Probabilistic Rejection Sampling

GLM 4.7 (MTP), 3 Speculative Tokens

Model: zai-org/GLM-4.7-Flash

| Metric | Strict | Probabilistic | Delta |
| --- | --- | --- | --- |
| Benchmark duration (s) | 104.21 | 90.65 | -13.0% |
| Request throughput (req/s) | 9.60 | 11.03 | +14.9% |
| Output token throughput (tok/s) | 2,456.67 | 2,824.11 | +15.0% |
| Peak output token throughput (tok/s) | 1,403.00 | 1,367.00 | -2.6% |
| Total token throughput (tok/s) | 3,176.66 | 3,651.80 | +15.0% |
| Mean TTFT (ms) | 161.23 | 66.30 | -58.9% |
| Median TTFT (ms) | 35.06 | 35.34 | +0.8% |
| P99 TTFT (ms) | 8,407.65 | 2,245.87 | -73.3% |
| Mean TPOT (ms) | 5.87 | 5.39 | -8.2% |
| Median TPOT (ms) | 5.85 | 5.39 | -7.9% |
| P99 TPOT (ms) | 6.53 | 6.03 | -7.7% |
| Mean ITL (ms) | 11.58 | 11.71 | +1.1% |
| Median ITL (ms) | 11.42 | 11.56 | +1.2% |
| P99 ITL (ms) | 13.60 | 13.72 | +0.9% |
| Acceptance rate (%) | 32.60 | 39.29 | +6.69pp |
| Acceptance length | 1.98 | 2.18 | +10.1% |
| Drafts | 129,218 | 117,409 | -9.1% |
| Draft tokens | 387,654 | 352,227 | -9.1% |
| Accepted tokens | 126,372 | 138,378 | +9.5% |
| Position 0 acceptance (%) | 83.47 | 83.62 | +0.15pp |
| Position 1 acceptance (%) | 13.32 | 28.04 | +14.72pp |
| Position 2 acceptance (%) | 1.01 | 6.19 | +5.18pp |

Llama 3 (EAGLE), 3 Speculative Tokens

Target: meta-llama/Meta-Llama-3-8B-Instruct
Draft: yuhuili/EAGLE-LLaMA3.1-Instruct-8B

| Metric | Strict | Probabilistic | Delta |
| --- | --- | --- | --- |
| Benchmark duration (s) | 57.46 | 55.70 | -3.1% |
| Request throughput (req/s) | 17.40 | 17.95 | +3.2% |
| Output token throughput (tok/s) | 4,455.01 | 4,596.33 | +3.2% |
| Peak output token throughput (tok/s) | 2,848.00 | 2,816.00 | -1.1% |
| Total token throughput (tok/s) | 5,864.41 | 6,050.44 | +3.2% |
| Mean TTFT (ms) | 35.22 | 34.80 | -1.2% |
| Median TTFT (ms) | 29.24 | 28.18 | -3.6% |
| P99 TTFT (ms) | 360.00 | 426.68 | +18.5% |
| Mean TPOT (ms) | 3.44 | 3.34 | -2.9% |
| Median TPOT (ms) | 3.43 | 3.31 | -3.5% |
| P99 TPOT (ms) | 4.37 | 4.23 | -3.2% |
| Mean ITL (ms) | 5.64 | 5.66 | +0.4% |
| Median ITL (ms) | 5.00 | 5.10 | +2.0% |
| P99 ITL (ms) | 16.17 | 15.22 | -5.9% |
| Acceptance rate (%) | 21.45 | 23.36 | +1.91pp |
| Acceptance length | 1.64 | 1.70 | +3.7% |
| Drafts | 155,502 | 150,306 | -3.3% |
| Draft tokens | 466,506 | 450,918 | -3.3% |
| Accepted tokens | 100,046 | 105,339 | +5.3% |
| Position 0 acceptance (%) | 42.29 | 43.60 | +1.31pp |
| Position 1 acceptance (%) | 16.28 | 19.11 | +2.83pp |
| Position 2 acceptance (%) | 5.77 | 7.38 | +1.61pp |

Qwen3-8B (EAGLE-3), 5 Speculative Tokens

Target: Qwen/Qwen3-8B
Draft: RedHatAI/Qwen3-8B-speculator.eagle3

| Metric | Strict | Probabilistic | Delta |
| --- | --- | --- | --- |
| Benchmark duration (s) | 54.75 | 52.88 | -3.4% |
| Request throughput (req/s) | 18.26 | 18.91 | +3.6% |
| Output token throughput (tok/s) | 4,675.65 | 4,840.71 | +3.5% |
| Peak output token throughput (tok/s) | 2,340.00 | 2,232.00 | -4.6% |
| Total token throughput (tok/s) | 6,133.27 | 6,349.78 | +3.5% |
| Mean TTFT (ms) | 43.44 | 42.09 | -3.1% |
| Median TTFT (ms) | 35.36 | 34.13 | -3.5% |
| P99 TTFT (ms) | 510.63 | 408.70 | -20.0% |
| Mean TPOT (ms) | 3.24 | 3.13 | -3.4% |
| Median TPOT (ms) | 3.26 | 3.14 | -3.7% |
| P99 TPOT (ms) | 4.12 | 4.05 | -1.7% |
| Mean ITL (ms) | 6.92 | 7.11 | +2.7% |
| Median ITL (ms) | 5.85 | 5.98 | +2.2% |
| P99 ITL (ms) | 21.59 | 21.71 | +0.6% |
| Acceptance rate (%) | 22.84 | 25.61 | +2.77pp |
| Acceptance length | 2.14 | 2.28 | +6.5% |
| Drafts | 119,475 | 112,286 | -6.0% |
| Draft tokens | 597,375 | 561,430 | -6.0% |
| Accepted tokens | 136,466 | 143,805 | +5.4% |
| Position 0 acceptance (%) | 58.61 | 59.93 | +1.32pp |
| Position 1 acceptance (%) | 31.38 | 36.60 | +5.22pp |
| Position 2 acceptance (%) | 13.97 | 17.65 | +3.68pp |
| Position 3 acceptance (%) | 6.74 | 9.14 | +2.40pp |
| Position 4 acceptance (%) | 3.52 | 4.75 | +1.23pp |

Summary

Probabilistic rejection sampling improves acceptance rates across all model configurations (+1.91pp for Llama 3, +2.77pp for Qwen3-8B, +6.69pp for GLM-4.7-Flash), with gains largest at middle draft positions and compounding through later positions. All configurations show throughput improvements (+3.2% for Llama 3, +3.5% for Qwen3-8B, +15.0% for GLM-4.7-Flash), demonstrating that the acceptance rate gains more than offset the overhead of computing the residual distribution, particularly when sampling with higher temperatures. The GLM-4.7-Flash MTP result shows the largest gains, with position 1 acceptance more than doubling (13.32% → 28.04%), highlighting that probabilistic rejection sampling is especially effective when draft model quality drops off steeply at later positions.

Config: 1000 requests, max concurrency 16, temperature 1.0, MT-Bench dataset.
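As a sanity check, the reported acceptance lengths follow directly from the per-position acceptance rates: each draft contributes one token sampled from the target model (the bonus or recovery token) plus the expected number of accepted draft tokens. A quick check against the GLM-4.7-Flash numbers above:

```python
# Per-position acceptance rates from the GLM-4.7-Flash table, as fractions
# of all drafts (position k counts drafts whose k-th draft token was accepted).
strict = [0.8347, 0.1332, 0.0101]
probabilistic = [0.8362, 0.2804, 0.0619]

def acceptance_length(position_rates):
    # Every draft yields one token sampled from the target model (the bonus
    # or recovery token), plus the expected number of accepted draft tokens.
    return 1 + sum(position_rates)

print(acceptance_length(strict))         # ~1.98, matching the table
print(acceptance_length(probabilistic))  # ~2.18, matching the table
```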

Manual Testing

Served 8 requests concurrently. Below are the prompts and responses:

Test Prompts and Responses

Prompt 0: "Explain the theory of relativity in simple terms."

The theory of relativity! It's a mind-bending concept that can be tricky to wrap your head around, but I'll try to simplify it for you.

What is the theory of relativity?

The theory of relativity, developed by Albert Einstein, is a way of understanding how the universe works. It's based on two main ideas: special relativity and general relativity.

Special Relativity (1905)

Imagine you're on a train, and you throw a ball straight up in the air. What happens? The ball comes back down and lands in your hand, right? Now, imagine someone is standing outside the train, watching you throw the ball. From their perspective, the ball doesn't just go straight up and down – it also moves forward, because the train is moving really fast.

Einstein said that how we measure time and space depends on how fast we're moving and where we are. This is called "relativity." It means that:

  1. Time can seem to pass slower or faster depending on how fast you're moving.
  2. Distance can appear shorter or longer depending on how fast you're moving and where you are.

General Relativity (1915)

Now, imagine you're standing on a tr...

Prompt 1: "What is the capital of France?"

The capital of France is Paris.

Prompt 2: "Write a haiku about coding."

Here is a haiku about coding:

Lines of code unfold
Meaning hidden in the dark
Beauty in the bug

Prompt 3: "List three benefits of regular exercise."

Here are three benefits of regular exercise:

  1. Improves Physical Health: Regular exercise can help to strengthen your heart and lungs, improving circulation and reducing the risk of heart disease, stroke, and high blood pressure. Exercise can also help to manage your weight, reduce the risk of type 2 diabetes, and improve your overall physical fitness.
  2. Boosts Mental Health: Exercise has been shown to have a positive impact on mental health, reducing symptoms of anxiety and depression, and improving mood. Regular exercise can also help to improve sleep quality, reduce stress, and increase self-esteem.
  3. Increases Energy and Productivity: Regular exercise can help to increase energy levels and improve overall physical function. Exercise has also been shown to improve cognitive function, including concentration, memory, and problem-solving skills, which can help to boost productivity and overall well-being.

Overall, regular exercise can have numerous benefits for both physical and mental health, making it an important part of a healthy lifestyle.

Prompt 4: "How does a refrigerator keep food cold?"

A refrigerator keeps food cold by using a combination of technologies to transfer heat from the interior of the fridge to the exterior. Here's a simplified explanation:

  1. Refrigerant: The refrigerator contains a refrigerant, a liquid that changes state from liquid to gas as it absorbs heat. The most common refrigerant is Freon (R-22), although many modern refrigerators use alternative refrigerants like R-410A.
  2. Compressor: The compressor is the heart of the refrigerator. It compresses the refrigerant, causing its temperature and pressure to increase. This hot, high-pressure gas is then pumped into the condenser coils.
  3. Condenser coils: The condenser coils are located outside the refrigerator and are usually at the back or bottom. As the hot gas flows through these coils, it releases its heat to the surrounding air and condenses back into a liquid. This process is called condensation.
  4. Expansion valve: The liquid refrigerant then passes through an expansion valve, which reduces its pressure and allows it to expand. This causes the temperature to drop even further.
  5. Evaporator coils: The cold, low-pressure refrigerant then enters the evaporator coils, which are usually located inside the refrigerator. As...

Prompt 5: "What is the difference between HTTP and HTTPS?"

HTTP (Hypertext Transfer Protocol) and HTTPS (Hypertext Transfer Protocol Secure) are two protocols used to transfer data over the internet. The main difference between them is the level of security and encryption used to protect the data being transmitted.

HTTP:

HTTP is an unencrypted protocol, which means that any data transmitted over HTTP is sent in plain text. This makes it vulnerable to eavesdropping, tampering, and man-in-the-middle attacks. When you enter a URL starting with "http", your browser sends an unencrypted request to the server, and the server responds with unencrypted data.

HTTPS:

HTTPS is an encrypted protocol that adds an extra layer of security to HTTP. It uses Transport Layer Security (TLS) or Secure Sockets Layer (SSL) to encrypt the data being transmitted. This means that any data sent over HTTPS is encrypted, making it much harder for hackers to intercept and read the data.

Here are some key differences between HTTP and HTTPS:

  1. Encryption: HTTPS uses encryption to protect data, while HTTP does not.
  2. Security: HTTPS is more secure than HTTP, as it encrypts data and makes it harder for hackers to intercept and read it.
  3. Trust: HTTPS is trusted by most...

Prompt 6: "Suggest a short book to read on a rainy day."

A perfect excuse to stay indoors and get cozy with a good book! Here's a short book suggestion for a rainy day:

"The Little Paris Bookshop" by Nina George (192 pages)

This charming novel tells the story of Jean Perdu, a melancholic bookseller who sets out to find his lost love and rediscover himself. The story is set in the charming streets of Paris, where Jean's bookshop is a haven for book lovers. The book is a delightful blend of romance, literature, and self-discovery, making it a perfect companion for a rainy day.

Why it's perfect for a rainy day:

  1. Short and sweet: With only 192 pages, you can easily finish this book in one sitting.
  2. Cozy atmosphere: The book's setting in a Parisian bookshop and the protagonist's introspective nature will transport you to a cozy and intimate atmosphere.
  3. Uplifting story: The story is a heartwarming tale of love, loss, and self-discovery, which will leave you feeling uplifted and inspired.

So, grab a cup of hot tea, snuggle up with a blanket, and get lost in the charming world of "The Little Paris Bookshop" on your rainy day.

Prompt 7: "2+2=?"

The answer is 4!

@mergify mergify bot added the v1 label Feb 27, 2026
@gemini-code-assist bot (Contributor) left a comment:

Code Review

This pull request introduces probabilistic rejection sampling for speculative decoding, which is a nice enhancement. The implementation involves passing draft model logits through the system to the new sampling function. The core logic seems sound, but I've identified a critical issue in the resampling probability calculation that could lead to a crash due to division by zero or invalid input to torch.multinomial. My review includes a suggestion to fix this for improved numerical stability.

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch 12 times, most recently from 2e308c9 to 3e657f5 Compare March 2, 2026 20:32
@TheEpicDolphin TheEpicDolphin marked this pull request as ready for review March 2, 2026 20:32
mergify bot commented Mar 2, 2026:

Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```shell
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from 3e657f5 to ce91568 Compare March 2, 2026 20:41
@TheEpicDolphin TheEpicDolphin mentioned this pull request Mar 2, 2026
@TheEpicDolphin TheEpicDolphin changed the title [WIP][Model Runner V2] Add probabilistic rejection sampling for spec decoding [Model Runner V2] Add probabilistic rejection sampling for spec decoding Mar 2, 2026
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from ce91568 to 74a2273 Compare March 3, 2026 00:53
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch 2 times, most recently from 6d08ae6 to 3808b63 Compare March 4, 2026 23:05
@WoosukKwon (Collaborator) left a comment:

Thanks for the PR! The results look reasonable.

I found one critical issue: The code tries to gather draft_logits, which could cause unnecessary large tensor allocation. Please fix this.

Also, left some comments on the style.

Comment on lines +807 to +809:

```python
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
```

Collaborator:

We shouldn't directly gather draft_logits because the tensor is big. Please use idx_mapping inside the Triton kernel to get corresponding rows.

Collaborator (Author):

Pushing back on this: I did profiling yesterday and found that gathering the draft logits first and then passing them into the kernels is significantly faster than index-mapping in the kernel, due to cache locality.

Collaborator:

Can you please provide the profiler trace and the commands to reproduce it?
Also, besides performance, it increases the GPU memory usage, which could be concerning.

@TheEpicDolphin (Collaborator, Author) commented Mar 5, 2026:

I'm not able to get profiler traces at the moment, but I reproduced this by running the following on the server:

```shell
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct --no-enable-prefix-caching --tensor-parallel-size=1 --data-parallel-size=1 --speculative-config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3, "rejection_sample_method": "probabilistic"}'
```

And the following benchmark:

```shell
vllm bench serve --model meta-llama/Meta-Llama-3-8B-Instruct --tokenizer meta-llama/Meta-Llama-3-8B-Instruct --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos --request-rate inf --max-concurrency 16 --temperature 1.0
```

I ran it with the current implementation, and then again after index-mapping in the kernel. The changes for that are simple: apply the idx_mapping offset to each req_idx in _probabilistic_rejection_sample_kernel and _compute_residual_logits_kernel, and use that to index into the draft probs instead.

Looking at it again, the cache locality may not be the main reason for the perf difference (or even most of it). It's probably from the fact that we have to run torch.softmax(draft_logits, dim=-1). If we don't index-map beforehand, then this will operate on the entire draft logits buffer for max requests. So it seems best to keep it for now, and later optimize it by applying softmax in a custom kernel to avoid operation on all draft logits.
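The tradeoff discussed above can be illustrated with a small sketch (hypothetical sizes; `idx_mapping` stands in for the batch's active-request index mapping). Gathering the active rows before the softmax means the softmax touches only the active requests' logits, while softmaxing the full persistent buffer does work proportional to the maximum request count; the probabilities are identical either way:

```python
import torch

# Hypothetical sizes: a persistent buffer sized for the max number of requests,
# with only a few requests active in this step.
max_reqs, num_spec, vocab = 256, 3, 8000
draft_logits = torch.randn(max_reqs, num_spec, vocab)
idx_mapping = torch.tensor([5, 17, 42])  # rows of the active requests

# Gather first, then softmax: touches len(idx_mapping) * num_spec * vocab elements.
probs_gathered = torch.softmax(draft_logits[idx_mapping], dim=-1)

# Softmax over the whole buffer, then index: touches max_reqs * num_spec * vocab
# elements and allocates a probability tensor the size of the full buffer.
probs_full = torch.softmax(draft_logits, dim=-1)[idx_mapping]

# Both orders give the same probabilities; they differ only in work and memory traffic.
assert torch.allclose(probs_gathered, probs_full)
```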

Comment on lines +392 to +393:

```python
# TODO (TheEpicDolphin): Return logprobs for accepted token ids.
logprobs_tensors=None,
```

Collaborator:

QQ: What is needed to get this?

Collaborator:

At least for use_strict_rejection_sampling, I think logprobs are supported.

@TheEpicDolphin (Collaborator, Author) commented Mar 5, 2026:

A bit of extra work is needed to support logprobs_tensors for the probabilistic case. I need to gather the sampled tokens into the correct shape for compute_topk_logprobs.

I fixed it so that logprobs are supported for strict rejection sampling.

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from 8274aa6 to c8551d1 Compare March 5, 2026 18:56
@mergify mergify bot removed the needs-rebase label Mar 5, 2026
mergify bot commented Mar 9, 2026:

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @TheEpicDolphin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 9, 2026
@WoosukKwon (Collaborator):

@TheEpicDolphin Can you please rebase again? Sorry for the frequent changes 😅

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from c8551d1 to c52ce3c Compare March 10, 2026 19:26
@mergify mergify bot removed the needs-rebase label Mar 10, 2026
@TheEpicDolphin (Collaborator, Author):

@WoosukKwon rebased!

@WoosukKwon (Collaborator) left a comment:

@TheEpicDolphin Sorry for the delays in review. LGTM overall. I left some final minor comments on style. Please address them and merge this asap!

Also, please follow up with:

  1. Kernel fusion for draft_logits[idx_mapping] and the two following softmaxes.
  2. Logprob support.

Comment on lines +75 to +76:

```python
# Draft token logits.
self.draft_logits: torch.Tensor | None = None
```

Collaborator:

Please add a comment like the one below.

Suggested change:

```diff
-# Draft token logits.
-self.draft_logits: torch.Tensor | None = None
+# Draft token logits.
+# NOTE: This tensor maintains the "processed" logits after applying temperature, top-p, etc.
+self.draft_logits: torch.Tensor | None = None
```

Collaborator:

Also, should we use FP32?

Collaborator (Author):

I feel that model dtype is the right choice here because fp32 will double the already large memory allocation, for not much extra gain.

Collaborator:

@TheEpicDolphin How do you know it's not much of a gain? I think this could be a sensitive problem, given that it affects model accuracy now (not just the acceptance rate).
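The precision question raised here can be probed with a quick experiment. This is an illustrative sketch, not vLLM's actual code path: it simulates caching the draft logits in bfloat16 (a common model dtype) versus fp32, then compares the resulting softmax probabilities:

```python
import torch

torch.manual_seed(0)
vocab = 32000  # hypothetical vocab size
logits = torch.randn(vocab) * 5  # spread-out logits

# Probabilities from logits cached in full precision vs. a reduced model dtype.
p_fp32 = torch.softmax(logits, dim=-1)
p_bf16 = torch.softmax(logits.to(torch.bfloat16).float(), dim=-1)

# Rounding the cached logits perturbs the probabilities slightly.
max_err = (p_fp32 - p_bf16).abs().max().item()
print(f"max probability error from bf16-cached logits: {max_err:.2e}")
```

Whether a perturbation of this size matters for output quality is exactly the open question in this thread; it trades off against doubling the memory of an already large cache.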

@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from c52ce3c to a894279 Compare March 11, 2026 18:03
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 11, 2026
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-spec-decode-rejection-sample branch from a894279 to f25399a Compare March 11, 2026 19:24
@WoosukKwon WoosukKwon merged commit c77181e into vllm-project:main Mar 11, 2026
46 of 51 checks passed
@TheEpicDolphin TheEpicDolphin deleted the gdelfin/mrv2-spec-decode-rejection-sample branch March 12, 2026 01:45
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
…ing (vllm-project#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
…ing (vllm-project#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
…ing (vllm-project#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
…ing (vllm-project#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
…ing (vllm-project#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1