
Andy/spec probs #35398

Closed
andylolu2 wants to merge 4 commits into vllm-project:main from andylolu2:andy/spec-probs

Conversation

@andylolu2
Contributor

@andylolu2 andylolu2 commented Feb 26, 2026

Purpose

Support drafter probabilities in ModelRunnerV2.

Test Plan

TBD

Test Result

TBD

Signed-off-by: Andy Lo <andy@mistral.ai>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for using draft model probabilities in speculative decoding, a significant enhancement. The changes involve refactoring the sampling logic and updating data structures like InputBatch and RequestState to handle draft logits. While the overall direction is good, I've identified two critical bugs related to incorrect tensor shapes and arguments being passed to kernel functions. These issues will lead to incorrect behavior and must be addressed.

return draft_tokens.view(-1, 1)
return Speculation(
draft_tokens=draft_tokens.view(-1, 1),
draft_logits=draft_logits.view(-1, 1),

critical

The draft_logits tensor is being reshaped incorrectly. draft_logits has a shape of [num_reqs, vocab_size]. Using .view(-1, 1) results in a tensor of shape [num_reqs * vocab_size, 1], which is incorrect.

The Speculation dataclass expects draft_logits to have a shape of [num_reqs, num_speculative_steps, vocab_size]. For num_speculative_steps=1, the correct shape is [num_reqs, 1, vocab_size]. You should use .unsqueeze(1) to add a dimension.

Suggested change
draft_logits=draft_logits.view(-1, 1),
draft_logits=draft_logits.unsqueeze(1),
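To make the shape mismatch concrete, here is a standalone PyTorch sketch with made-up sizes (illustrative only, not vLLM code):

```python
import torch

# Hypothetical sizes for illustration; the real values depend on the batch.
num_reqs, vocab_size = 4, 32000
draft_logits = torch.randn(num_reqs, vocab_size)

# .view(-1, 1) flattens across the vocab dimension, producing the wrong shape.
wrong = draft_logits.view(-1, 1)
assert wrong.shape == (num_reqs * vocab_size, 1)

# .unsqueeze(1) inserts the num_speculative_steps axis (here 1), giving the
# [num_reqs, num_speculative_steps, vocab_size] layout the review describes.
right = draft_logits.unsqueeze(1)
assert right.shape == (num_reqs, 1, vocab_size)
```

Note that `.view(-1, 1)` is correct for `draft_tokens` (shape `[num_reqs]`), which is why the bug is easy to introduce by analogy.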

Signed-off-by: Andy Lo <andy@mistral.ai>
@andylolu2
Contributor Author

/gemini review I think there's a bug when cudagraph is enabled. Do you see it?

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for using draft model probabilities in speculative decoding, which is a significant enhancement. The changes involve propagating draft logits through the system, refactoring the sampler for probabilistic rejection sampling, and updating the speculator accordingly. A new correctness test for sampling recovered and bonus tokens is a valuable addition. While the overall direction is good, I've identified a critical issue with the log-probability calculation in the speculative decoding path. The logic for computing logprobs for accepted tokens appears to be incorrect due to tensor shape mismatches and not accounting for different token origins (accepted draft vs. recovered), which will likely lead to incorrect output.

Comment on lines +166 to +172
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
cu_num_logits_list = cu_num_logits_np.tolist() if expanded_logits else None
logprobs_tensors = compute_topk_logprobs(
processed_logits, max_num_logprobs, sampled, cu_num_logits_list
)

critical

The log-probability calculation for speculative decoding appears to be incorrect. The compute_topk_logprobs function is called with a sampled tensor of shape [num_reqs, num_speculative_steps + 1] and processed_logits of shape [num_draft_tokens + num_reqs, vocab_size]. The sampled tensor is 2D and padded, while compute_topk_logprobs likely expects a 1D tensor of token IDs that aligns with the provided logits.

Furthermore, the sampled tokens are a mix of accepted draft tokens and "recovered" tokens. The logprobs for these two types of tokens come from different distributions (p_target vs. renormalize(max(0, p_target - p_draft))). The current implementation does not seem to distinguish between them when computing logprobs, passing only processed_logits (derived from p_target for proposed tokens). This will lead to incorrect logprob values.
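To make the two distributions concrete, the following is a standalone toy sketch of single-token rejection sampling (names and sizes are hypothetical; this is not vLLM's kernel):

```python
import torch

# Toy distributions over a small vocabulary.
torch.manual_seed(0)
vocab_size = 8
p_target = torch.softmax(torch.randn(vocab_size), dim=-1)
p_draft = torch.softmax(torch.randn(vocab_size), dim=-1)

# A proposed draft token t is accepted with probability
# min(1, p_target[t] / p_draft[t]); accepted tokens are scored under p_target.
t = torch.multinomial(p_draft, 1).item()
accept_prob = torch.clamp(p_target[t] / p_draft[t], max=1.0)

# On rejection, the "recovered" token is drawn from
# renormalize(max(0, p_target - p_draft)), a different distribution,
# so its logprob cannot come from the target logits alone.
recovered = torch.clamp(p_target - p_draft, min=0.0)
recovered = recovered / recovered.sum()
assert torch.isclose(recovered.sum(), torch.tensor(1.0))
```

This illustrates why a single `processed_logits` tensor is insufficient for logprobs once recovered tokens enter the mix.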

Signed-off-by: Andy Lo <andy@mistral.ai>
@TheEpicDolphin
Collaborator

Hi @andylolu2, I was tasked with working on probabilistic rejection sampling for MRV2 in #35461, and stumbled upon your PR here. Looks like we took similar approaches. I have some benchmark results on my PR showing some acceptance rate improvements. I'm happy to compare implementations and figure out the best path forward for us to unlock this feature. Perhaps we can combine the best parts of both!

cc: @WoosukKwon

@mergify

mergify bot commented Mar 3, 2026

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @andylolu2.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 3, 2026
@andylolu2
Contributor Author

andylolu2 commented Mar 3, 2026

Hi @andylolu2, I was tasked with working on probabilistic rejection sampling for MRV2 in #35461, and stumbled upon your PR here. Looks like we took similar approaches. I have some benchmark results on my PR showing some acceptance rate improvements. I'm happy to compare implementations and figure out the best path forward for us to unlock this feature. Perhaps we can combine the best parts of both!

cc: @WoosukKwon

Oooo amazing, I am a bit short on bandwidth at the moment, so please feel free to take over!

I have some takeaways from implementing this; let me comment on your PR.
