[Model Runner V2] Add probabilistic rejection sampling for spec decoding#35461
Conversation
Code Review
This pull request introduces probabilistic rejection sampling for speculative decoding, which is a nice enhancement. The implementation involves passing draft model logits through the system to the new sampling function. The core logic seems sound, but I've identified a critical issue in the resampling probability calculation that could lead to a crash due to division by zero or invalid input to torch.multinomial. My review includes a suggestion to fix this for improved numerical stability.
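The guard the review is asking for can be sketched in pure Python (the helper name `safe_residual` is hypothetical; the actual fix lives in the PR's torch/Triton code):

```python
def safe_residual(target_probs, draft_probs, eps=1e-10):
    """Residual distribution max(0, p - q), renormalized.

    Guards the z ~= 0 case that would otherwise divide by zero (or feed an
    all-zero row to torch.multinomial). Hypothetical helper, not PR code.
    """
    residual = [max(0.0, p - q) for p, q in zip(target_probs, draft_probs)]
    z = sum(residual)
    if z < eps:
        # p ~= q everywhere: the residual vanishes, so fall back to the
        # target distribution instead of dividing by ~0.
        return list(target_probs)
    return [r / z for r in residual]
```

When the draft and target distributions coincide, the residual is all zeros; without the fallback, normalizing it produces NaNs or an invalid sampling input.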
Force-pushed from 2e308c9 to 3e657f5.
Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.
Force-pushed from 3e657f5 to ce91568.
Force-pushed from ce91568 to 74a2273.
Force-pushed from 6d08ae6 to 3808b63.
WoosukKwon
left a comment
Thanks for the PR! The results look reasonable.
I found one critical issue: The code tries to gather draft_logits, which could cause unnecessary large tensor allocation. Please fix this.
Also, left some comments on the style.
vllm/v1/worker/gpu/model_runner.py
Outdated
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
We shouldn't directly gather draft_logits because the tensor is big. Please use idx_mapping inside the Triton kernel to get corresponding rows.
Pushing back on this, I did profiling yesterday and found that gathering the draft logits first, and then passing into the kernels is significantly faster than index-mapping in the kernel, due to cache locality.
Can you please provide the profiler trace and the commands to reproduce it?
Also, besides performance, it increases the GPU memory usage, which could be concerning.
Not able to get profiler traces at the moment, but I reproduced this by running the following on the server:
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct --no-enable-prefix-caching --tensor-parallel-size=1 --data-parallel-size=1 --speculative-config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 3, "rejection_sample_method": "probabilistic"}'
And the following benchmark:
vllm bench serve --model meta-llama/Meta-Llama-3-8B-Instruct --tokenizer meta-llama/Meta-Llama-3-8B-Instruct --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench --ignore-eos --request-rate inf --max-concurrency 16 --temperature 1.0
I ran it with the current implementation, and then again after index-mapping in the kernel. The changes for that are simple: just apply the idx_mapping offset to each req_idx in _probabilistic_rejection_sample_kernel and _compute_residual_logits_kernel, and use that to index into the draft probs instead.
Looking at it again, cache locality may not be the main reason for the perf difference (or even most of it). It's probably from the fact that we have to run torch.softmax(draft_logits, dim=-1). If we don't index-map beforehand, this operates on the entire draft logits buffer sized for the maximum number of requests. So it seems best to keep the gather for now, and later optimize it by applying softmax in a custom kernel to avoid touching all draft logits.
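For concreteness, the gather-first approach being defended can be sketched in pure Python (a stand-in for the torch/Triton code; buffer sizes and names here are illustrative, not PR code):

```python
import math

def softmax(row):
    # Numerically stable softmax over one row of logits.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

# Persistent buffer sized for the max number of requests; only a few rows
# are live in any given step (illustrative sizes, not PR code).
MAX_NUM_REQS, VOCAB = 8, 4
draft_logits = [[float(i + j) for j in range(VOCAB)] for i in range(MAX_NUM_REQS)]
idx_mapping = [5, 2]  # rows belonging to this step's batch

# Gather first: softmax then touches len(idx_mapping) rows, not MAX_NUM_REQS.
gathered = [draft_logits[i] for i in idx_mapping]
draft_probs = [softmax(row) for row in gathered]
```

The alternative discussed above keeps the buffer in place and applies idx_mapping per row inside the kernel, which avoids the gather allocation but leaves the preceding softmax running over all MAX_NUM_REQS rows unless it is fused as well.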
# TODO (TheEpicDolphin): Return logprobs for accepted token ids.
logprobs_tensors=None,
QQ: What is needed to get this?
At least for use_strict_rejection_sampling, I think logprobs are supported.
A bit of extra work is needed to support logprobs_tensors for the probabilistic case. I need to gather the sampled tokens into the correct shape for compute_topk_logprobs.
I fixed it so that logprobs are supported for strict rejection sampling.
Force-pushed from 8274aa6 to c8551d1.
This pull request has merge conflicts that must be resolved before it can be merged.
@TheEpicDolphin Can you please rebase again? Sorry for the frequent changes 😅
Force-pushed from c8551d1 to c52ce3c.
@WoosukKwon rebased!
WoosukKwon
left a comment
@TheEpicDolphin Sorry for the delays in review. LGTM overall. I left some final minor comments on style. Please address them and merge this asap!
Also, please follow up with:
- Kernel fusion for draft_logits[idx_mapping] and the two following softmax calls.
- Logprob support.
# Draft token logits.
self.draft_logits: torch.Tensor | None = None
Please add a comment like below:

# Draft token logits.
# NOTE: This tensor maintains the "processed" logits after applying temperature, top-p, etc.
self.draft_logits: torch.Tensor | None = None
Also, should we use FP32?
I feel that the model dtype is the right choice here, because fp32 would double the already large memory allocation for not much extra gain.
@TheEpicDolphin How do you know it's not much of a gain? I think this could be a sensitive problem given that it affects model accuracy now (not just the acceptance rate).
Force-pushed from c52ce3c to a894279.
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Force-pushed from a894279 to f25399a.
Purpose
Use probabilistic rejection sampling for MRV2 spec decoding to increase draft token acceptance rate.
Currently, strict rejection sampling is being used, where draft tokens are accepted only if they exactly match the target model's sampled tokens. This ensures that the target model's distribution is being respected, but is suboptimal for acceptance rate.
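Probabilistic rejection sampling instead accepts a draft token with probability min(1, p/q) and, on rejection, resamples from the residual distribution max(0, p - q), which also preserves the target distribution exactly. A minimal pure-Python sketch of this standard acceptance rule (names and the degenerate-case fallback are illustrative, not the PR's actual kernel):

```python
import random

def probabilistic_rejection_sample(draft_token, draft_probs, target_probs, rng=random):
    """Accept draft_token with probability min(1, p/q); on rejection, resample
    from the residual distribution max(0, p - q), renormalized.
    Illustrative sketch, not the PR's Triton implementation."""
    q = draft_probs[draft_token]   # > 0, since the draft model sampled this token
    p = target_probs[draft_token]
    if rng.random() < min(1.0, p / q):
        return draft_token, True   # accepted
    residual = [max(0.0, pi - qi) for pi, qi in zip(target_probs, draft_probs)]
    z = sum(residual)
    if z <= 0.0:
        # Degenerate p == q case: fall back to the target distribution.
        residual, z = list(target_probs), sum(target_probs)
    # Inverse-CDF sample from the residual (scaling u by z avoids dividing).
    u, acc = rng.random() * z, 0.0
    for token, r in enumerate(residual):
        acc += r
        if u < acc:
            return token, False
    return len(residual) - 1, False
```

When p == q the acceptance probability is 1, so a draft model that matches the target perfectly never loses a token, unlike strict matching, where the two samplers must draw the same token by chance.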
NOTE: I added a SpeculativeConfig parameter that allows setting either strict or probabilistic rejection sampling. By default, strict rejection sampling is used. Probabilistic rejection sampling yields higher acceptance rates, but comes with the tradeoff of more memory usage from caching draft logits.

Benchmarks
Server
Client
Speculative Decoding: Strict vs Probabilistic Rejection Sampling
GLM 4.7 (MTP), 3 Speculative Tokens
Model: zai-org/GLM-4.7-Flash

Llama 3 (EAGLE), 3 Speculative Tokens
Target: meta-llama/Meta-Llama-3-8B-Instruct
Draft: yuhuili/EAGLE-LLaMA3.1-Instruct-8B

Qwen3-8B (EAGLE-3), 5 Speculative Tokens
Target: Qwen/Qwen3-8B
Draft: RedHatAI/Qwen3-8B-speculator.eagle3

Summary
Probabilistic rejection sampling improves acceptance rates across all model configurations (+1.91pp for Llama 3, +2.77pp for Qwen3-8B, +6.69pp for GLM-4.7-Flash), with gains largest at middle draft positions and compounding through later positions. All configurations show throughput improvements (+3.2% for Llama 3, +3.5% for Qwen3-8B, +15.0% for GLM-4.7-Flash), demonstrating that the acceptance rate gains more than offset the overhead of computing the residual distribution, particularly when sampling with higher temperatures. The GLM-4.7-Flash MTP result shows the largest gains, with position 1 acceptance more than doubling (13.32% → 28.04%), highlighting that probabilistic rejection sampling is especially effective when draft model quality drops off steeply at later positions.
Config: 1000 requests, max concurrency 16, temperature 1.0, MT-Bench dataset.
Manual Testing
Served 8 requests concurrently. Below are the prompts and responses:
Test Prompts and Responses
Prompt 0: "Explain the theory of relativity in simple terms."
The theory of relativity! It's a mind-bending concept that can be tricky to wrap your head around, but I'll try to simplify it for you.
What is the theory of relativity?
The theory of relativity, developed by Albert Einstein, is a way of understanding how the universe works. It's based on two main ideas: special relativity and general relativity.
Special Relativity (1905)
Imagine you're on a train, and you throw a ball straight up in the air. What happens? The ball comes back down and lands in your hand, right? Now, imagine someone is standing outside the train, watching you throw the ball. From their perspective, the ball doesn't just go straight up and down – it also moves forward, because the train is moving really fast.
Einstein said that how we measure time and space depends on how fast we're moving and where we are. This is called "relativity." It means that:
General Relativity (1915)
Now, imagine you're standing on a tr...
Prompt 1: "What is the capital of France?"
The capital of France is Paris.
Prompt 2: "Write a haiku about coding."
Here is a haiku about coding:
Lines of code unfold
Meaning hidden in the dark
Beauty in the bug
Prompt 3: "List three benefits of regular exercise."
Here are three benefits of regular exercise:
Overall, regular exercise can have numerous benefits for both physical and mental health, making it an important part of a healthy lifestyle.
Prompt 4: "How does a refrigerator keep food cold?"
A refrigerator keeps food cold by using a combination of technologies to transfer heat from the interior of the fridge to the exterior. Here's a simplified explanation:
Prompt 5: "What is the difference between HTTP and HTTPS?"
HTTP (Hypertext Transfer Protocol) and HTTPS (Hypertext Transfer Protocol Secure) are two protocols used to transfer data over the internet. The main difference between them is the level of security and encryption used to protect the data being transmitted.
HTTP:
HTTP is an unencrypted protocol, which means that any data transmitted over HTTP is sent in plain text. This makes it vulnerable to eavesdropping, tampering, and man-in-the-middle attacks. When you enter a URL starting with "http", your browser sends an unencrypted request to the server, and the server responds with unencrypted data.
HTTPS:
HTTPS is an encrypted protocol that adds an extra layer of security to HTTP. It uses Transport Layer Security (TLS) or Secure Sockets Layer (SSL) to encrypt the data being transmitted. This means that any data sent over HTTPS is encrypted, making it much harder for hackers to intercept and read the data.
Here are some key differences between HTTP and HTTPS:
Prompt 6: "Suggest a short book to read on a rainy day."
A perfect excuse to stay indoors and get cozy with a good book! Here's a short book suggestion for a rainy day:
"The Little Paris Bookshop" by Nina George (192 pages)
This charming novel tells the story of Jean Perdu, a melancholic bookseller who sets out to find his lost love and rediscover himself. The story is set in the charming streets of Paris, where Jean's bookshop is a haven for book lovers. The book is a delightful blend of romance, literature, and self-discovery, making it a perfect companion for a rainy day.
Why it's perfect for a rainy day:
So, grab a cup of hot tea, snuggle up with a blanket, and get lost in the charming world of "The Little Paris Bookshop" on your rainy day.
Prompt 7: "2+2=?"
The answer is 4!