
[Model Runner V2] Fix draft logits not populated during cudagraph replay#37639

Merged
WoosukKwon merged 1 commit into vllm-project:main from
TheEpicDolphin:gdelfin/mrv2-fix-draft-logits-during-cudagraph-replay
Mar 20, 2026

Conversation

Collaborator

@TheEpicDolphin TheEpicDolphin commented Mar 20, 2026

TLDR

When using probabilistic rejection sampling with Eagle speculative decoding and CUDA graphs enabled, the draft logits for speculative steps 1+ were not being written, causing incorrect rejection sampling behavior.
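For context, probabilistic rejection sampling accepts each draft token with probability min(1, p_target(t) / p_draft(t)) and, on rejection, resamples from the normalized residual distribution max(p_target − p_draft, 0). Because the draft probabilities enter every acceptance test, stale logits at steps 1+ skew the whole procedure. A minimal sketch with illustrative names (not vLLM's actual sampler API):

```python
import torch

def rejection_sample(draft_token_ids, draft_probs, target_probs, generator=None):
    """Sketch of probabilistic rejection sampling for speculative decoding.

    draft_probs / target_probs: [num_steps, vocab] probability rows.
    Accept draft token t with probability min(1, p_target(t) / p_draft(t));
    on rejection, resample once from the normalized residual and stop.
    """
    accepted = []
    for i, t in enumerate(draft_token_ids):
        p_t = target_probs[i, t]
        q_t = draft_probs[i, t]
        u = torch.rand((), generator=generator)
        if u < torch.clamp(p_t / q_t, max=1.0):
            accepted.append(int(t))
        else:
            # Resample from the residual distribution max(p - q, 0), renormalized.
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            residual /= residual.sum()
            accepted.append(int(torch.multinomial(residual, 1)))
            break
    return accepted
```

When the draft and target distributions agree exactly, the acceptance ratio is 1 and every draft token is kept; garbage draft probabilities instead produce arbitrary ratios, which is why the bug silently degrades acceptance rather than crashing.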

Root Cause

The draft logits buffer (draft_logits_out) passed into EagleSpeculator was not forwarded to EagleCudaGraphManager, and was therefore not included in the CUDA graph capture. During replay, speculative steps 1+ never wrote into it.

Fix

Move draft_logits from RequestState to EagleSpeculator, matching the existing pattern used by draft_tokens. This also paves the way for my upcoming PR adding FULL cudagraph support for Eagle prefill: #37588
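The underlying constraint is that a replayed CUDA graph only writes to tensors that were live at capture time; a caller-allocated tensor that never entered capture is simply never touched during replay. A sketch of the buffer-ownership pattern the fix adopts (illustrative class and method names, not vLLM's real ones):

```python
import torch

class EagleSpeculatorSketch:
    """Illustrative sketch: the component that captures the CUDA graph
    also owns the output buffers, so every capture/replay path writes
    into the same persistent memory."""

    def __init__(self, max_tokens, hidden_size, vocab_size, device="cuda"):
        # Persistent buffers, allocated once and reused by every replay.
        self.weight = torch.randn(hidden_size, vocab_size, device=device)
        self.draft_logits = torch.empty(max_tokens, vocab_size, device=device)
        self.draft_tokens = torch.empty(max_tokens, dtype=torch.long, device=device)
        self._graph = None

    def _step(self, hidden):
        # Stand-in for the real drafting computation. Crucially, it
        # writes in place into the persistent buffers.
        torch.matmul(hidden, self.weight, out=self.draft_logits)
        self.draft_tokens.copy_(self.draft_logits.argmax(dim=-1))

    def capture(self, hidden):
        # `hidden` must itself be a persistent input buffer.
        self._graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self._graph):
            self._step(hidden)

    def replay(self):
        # Results land in draft_logits/draft_tokens because those
        # tensors were part of the capture.
        self._graph.replay()
```

If draft_logits instead lived in per-request state handed in after capture, replay would keep writing to the originally captured memory and the caller's tensor would stay stale, which is exactly the observed bug.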

Benchmark

Comprehensive accuracy & performance benchmark results across multiple models (llama3, mimo, qwen, glm), spec decode methods (eagle-1, eagle-3, MTP), parallelisms (TP, EP), output lengths (1, 1024), and concurrencies (1, 8, 64): https://gistpreview.github.io/?1dc71a0fa70ae78b1aa2a70e635a5a01.

We see clear improvements in acceptance rates, and increased output token throughput for some models. For other models, however, the bump in acceptance rates doesn't offset the per-step overhead of probabilistic rejection sampling.

Server

```shell
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve openai/gpt-oss-20b \
  --no-enable-prefix-caching --tensor-parallel-size=1 --data-parallel-size=1 \
  --speculative-config '{"method": "eagle3", "model": "RedHatAI/gpt-oss-20b-speculator.eagle3", "num_speculative_tokens": 3}'
```

Client

```shell
vllm bench serve --model openai/gpt-oss-20b --tokenizer openai/gpt-oss-20b \
  --host 0.0.0.0 --dataset-name hf --dataset-path philschmid/mt-bench \
  --ignore-eos --request-rate inf --max-concurrency 16 --temperature 1.0
```

MRV2 + probabilistic rejection sampling yields better acceptance rates at all draft positions compared to MRV2 + strict and MRV1. The improvement is not as dramatic as previously reported, since the earlier numbers were inflated by the incorrect draft logits for steps 1+. In retrospect, the draft acceptances for positions 1 and 2 in my previous PR (#37364) were sus... I should probably add some accuracy tests soon.

| Metric | MRV1 | MRV2 (strict) | MRV2 (probabilistic) |
| --- | --- | --- | --- |
| Benchmark duration (s) | 141.13 | 108.38 | 107.00 |
| Successful requests | 1000 | 1000 | 1000 |
| Failed requests | 0 | 0 | 0 |
| Max request concurrency | 16 | 16 | 16 |
| Peak concurrent requests | 29 | 31 | 31 |
| Total input tokens | 136,120 | 136,120 | 136,120 |
| Total generated tokens | 256,000 | 256,000 | 256,000 |
| Request throughput (req/s) | 7.09 | 9.23 | 9.35 |
| Output token throughput (tok/s) | 1,813.92 | 2,362.12 | 2,392.50 |
| Peak output token throughput (tok/s) | 966.00 | 1,131.00 | 1,100.00 |
| Total token throughput (tok/s) | 2,778.41 | 3,618.10 | 3,664.64 |
| Mean TTFT (ms) | 66.28 | 56.23 | 58.23 |
| Median TTFT (ms) | 59.19 | 50.36 | 51.28 |
| P99 TTFT (ms) | 386.53 | 327.76 | 384.18 |
| Mean TPOT (ms) | 8.53 | 6.53 | 6.44 |
| Median TPOT (ms) | 8.60 | 6.52 | 6.43 |
| P99 TPOT (ms) | 11.17 | 8.37 | 8.29 |
| Mean ITL (ms) | 17.52 | 14.79 | 15.24 |
| Median ITL (ms) | 16.53 | 13.73 | 14.15 |
| P99 ITL (ms) | 28.09 | 24.69 | 24.91 |
| Acceptance rate (%) | 35.33 | 42.44 | 45.79 |
| Acceptance length | 2.06 | 2.27 | 2.37 |
| Drafts | 124,189 | 112,584 | 107,821 |
| Draft tokens | 372,567 | 337,752 | 323,463 |
| Accepted tokens | 131,632 | 143,358 | 148,111 |
| Acceptance pos 0 (%) | 54.56 | 63.10 | 66.39 |
| Acceptance pos 1 (%) | 32.44 | 40.92 | 44.73 |
| Acceptance pos 2 (%) | 18.99 | 23.31 | 26.24 |
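As a quick internal-consistency check on the benchmark numbers: the per-position acceptance rates share the same denominator (Drafts), so the mean acceptance length should equal one bonus token plus their sum, and the overall acceptance rate should equal Accepted tokens over Draft tokens. Verifying against the MRV2 (probabilistic) column:

```python
# Sanity-checking the MRV2 (probabilistic) column of the benchmark table.
pos_rates = [0.6639, 0.4473, 0.2624]  # acceptance at positions 0, 1, 2

# Each decoding step emits one bonus token plus the accepted draft tokens.
acceptance_length = 1 + sum(pos_rates)
print(round(acceptance_length, 2))  # 2.37, matching the table

# Overall acceptance rate = accepted tokens / draft tokens.
acceptance_rate = 100 * 148_111 / 323_463
print(round(acceptance_rate, 2))  # 45.79, matching the table
```

The same identity holds for the MRV1 and MRV2 (strict) columns, which suggests the per-position counters and the aggregate metrics are being accumulated consistently.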

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request fixes an issue with draft logits not being populated during CUDA graph replay in Eagle speculative decoding. The fix involves moving draft_logits from RequestState to EagleSpeculator, which is a sound approach. The implementation correctly updates the code to reflect this change. However, I've identified a potential precision issue with the data type of the newly located draft_logits tensor that could affect the correctness of probabilistic rejection sampling.

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
@TheEpicDolphin TheEpicDolphin force-pushed the gdelfin/mrv2-fix-draft-logits-during-cudagraph-replay branch from 18c2f79 to 33b9e44 Compare March 20, 2026 04:15
@TheEpicDolphin TheEpicDolphin marked this pull request as ready for review March 20, 2026 04:17
@TheEpicDolphin TheEpicDolphin requested a review from njhill as a code owner March 20, 2026 04:17
@WoosukKwon WoosukKwon added the `ready` label (ONLY add when PR is ready to merge/full CI is needed) Mar 20, 2026
Collaborator

@WoosukKwon WoosukKwon left a comment


Thanks for the fix!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 20, 2026
@WoosukKwon WoosukKwon enabled auto-merge (squash) March 20, 2026 05:53
@WoosukKwon WoosukKwon merged commit dcee9be into vllm-project:main Mar 20, 2026
61 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 20, 2026
@TheEpicDolphin TheEpicDolphin deleted the gdelfin/mrv2-fix-draft-logits-during-cudagraph-replay branch March 20, 2026 16:56
This pull request was referenced by downstream commits in:

- intellistream/vllm-hust (chooper26, Mar 21, 2026)
- SouthWest7/vllm (Mar 27, 2026)
- khairulkabir1661/vllm (Mar 27, 2026)
- Monishver11/vllm (Mar 27, 2026)
- nithinvc/vllm (Mar 27, 2026)
- JiantaoXu/vllm (Mar 28, 2026)

Labels

nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done


2 participants