
Conversation

@zixi-qi (Collaborator) commented Jul 23, 2025

Purpose

Support running the eagle speculative decoding draft model in full cudagraph mode for v1 (referencing #16072). On H100 TP1 with the Llama 3.1 8B model, the full cudagraph version shows on-par performance (or a marginal improvement).
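For reference, enabling this from the offline LLM API looks roughly like the sketch below. The compilation_config key matches the commands in the test plan; the speculative_config keys and the EAGLE draft checkpoint name are assumptions and may differ per setup / vLLM version.

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    # the draft checkpoint below is an assumed example; substitute your own EAGLE head
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 5,
    },
    # full_cuda_graph=True asks for the draft model to be captured in full cudagraph mode
    compilation_config={"full_cuda_graph": True},
)

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)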

  • trace with piecewise cudagraph: [screenshot of profiler trace not included]
  • benchmark with piecewise cudagraph (1 * H100, Llama 3.1 8B, max bs = 2, acceptance rate = 0):
============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  135.05    
Total input tokens:                      9568      
Total generated tokens:                  21269     
Request throughput (req/s):              0.74      
Output token throughput (tok/s):         157.49    
Total Token throughput (tok/s):          228.33    
---------------Time to First Token----------------
Mean TTFT (ms):                          27.53     
Median TTFT (ms):                        24.96     
P99 TTFT (ms):                           46.25     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.47     
Median TPOT (ms):                        12.40     
P99 TPOT (ms):                           13.37     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.47     
Median ITL (ms):                         12.37     
P99 ITL (ms):                            13.71     
==================================================
  • trace with full cudagraph: [screenshot of profiler trace not included]
  • benchmark with full cudagraph (1 * H100, Llama 3.1 8B, max bs = 2, acceptance rate = 0):
============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  132.64    
Total input tokens:                      9568      
Total generated tokens:                  21087     
Request throughput (req/s):              0.75      
Output token throughput (tok/s):         158.97    
Total Token throughput (tok/s):          231.11    
---------------Time to First Token----------------
Mean TTFT (ms):                          27.23     
Median TTFT (ms):                        24.89     
P99 TTFT (ms):                           45.50     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.39     
Median TPOT (ms):                        12.30     
P99 TPOT (ms):                           13.36     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.39     
Median ITL (ms):                         12.26     
P99 ITL (ms):                            13.80     
==================================================

Test Plan & Result

  • Acceptance rate comparison between full and piecewise cudagraph (a sanity check on these numbers follows the unit test results below):

VLLM_USE_V1=1 python examples/offline_inference/spec_decode.py --num_spec_tokens 5 --num_prompts 80 --dataset-name hf --dataset-path philschmid/mt-bench --compilation_config '{"full_cuda_graph": false}'

--------------------------------------------------
total_num_output_tokens: 17059
num_drafts: 6988
num_draft_tokens: 34940
num_accepted_tokens: 10091
mean acceptance length: 2.44
--------------------------------------------------
acceptance at token 0: 0.68
acceptance at token 1: 0.39
acceptance at token 2: 0.21
acceptance at token 3: 0.11
acceptance at token 4: 0.05

VLLM_USE_V1=1 python examples/offline_inference/spec_decode.py --num_spec_tokens 5 --num_prompts 80 --dataset-name hf --dataset-path philschmid/mt-bench --compilation_config '{"full_cuda_graph": true}'

--------------------------------------------------
total_num_output_tokens: 17006
num_drafts: 6986
num_draft_tokens: 34930
num_accepted_tokens: 10038
mean acceptance length: 2.44
--------------------------------------------------
acceptance at token 0: 0.68
acceptance at token 1: 0.39
acceptance at token 2: 0.21
acceptance at token 3: 0.11
acceptance at token 4: 0.05
  • unit test with full cudagraph {True, False}

pytest -v tests/v1/e2e/test_spec_decode.py::test_eagle_correctness

tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[True-llama3_eagle] PASSED                                                                                                                                            [ 16%]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[True-llama3_eagle3] PASSED                                                                                                                                           [ 33%]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[True-llama4_eagle] SKIPPED (Skipping due to CI OOM issues)                                                                                                           [ 50%]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[False-llama3_eagle] PASSED                                                                                                                                           [ 66%]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[False-llama3_eagle3] PASSED                                                                                                                                          [ 83%]
tests/v1/e2e/test_spec_decode.py::test_eagle_correctness[False-llama4_eagle] SKIPPED (Skipping due to CI OOM issues)                                                                                                          [100%]
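As a sanity check on the acceptance-rate numbers above (assuming the script reports mean acceptance length as accepted draft tokens per drafting step plus the one bonus token): 1 + 10091 / 6988 ≈ 2.44 for the piecewise run and 1 + 10038 / 6986 ≈ 2.44 for the full cudagraph run, matching the reported values, so acceptance behavior is effectively identical across the two modes.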

Side Note

In a previous draft PR (#20190 (comment)), a numerical difference between full and piecewise cudagraph mode was identified. @zou3519 helped identify the root cause: the rms_norm kernel selected by inductor has slightly better numerics than the eager version (pytorch/pytorch#158699), which is expected behavior.

@github-actions bot commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the documentation, speculative-decoding, and v1 labels Jul 23, 2025
@gemini-code-assist bot left a comment

Code Review

This pull request introduces support for full CUDA graph execution with the Eagle speculative decoding model. The changes are well-structured, adding a new compilation configuration and updating the necessary components in the model runner and the Eagle proposer to handle CUDA graph-compatible attention metadata. The tests have also been extended to cover this new functionality.

I've identified a critical issue where an IndexError could occur if cudagraph_batch_sizes becomes empty, which can happen under certain configurations with sequence parallelism. I've provided suggestions to fix this by adding a check before accessing the list.

Other than that, the implementation looks solid and correctly follows the patterns for enabling CUDA graphs in vLLM.
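To illustrate the kind of guard the review is suggesting, here is a minimal standalone sketch; it reuses the cudagraph_batch_sizes name from the comment, while the real check would live wherever the model runner picks a capture size.

def max_cudagraph_capture_size(cudagraph_batch_sizes: list[int]) -> int:
    # cudagraph_batch_sizes can end up empty, e.g. when sequence parallelism
    # filters out every candidate size, so guard before taking the largest entry.
    if not cudagraph_batch_sizes:
        return 0  # treat as "nothing to capture" instead of raising IndexError
    return max(cudagraph_batch_sizes)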

@zixi-qi zixi-qi force-pushed the eagle-full-cudagraph branch 4 times, most recently from 99e87fe to 7251b9e on July 23, 2025 22:39
@zixi-qi zixi-qi marked this pull request as ready for review July 23, 2025 22:42
@zixi-qi zixi-qi requested review from houseroad and zou3519 July 23, 2025 22:46
@zixi-qi zixi-qi changed the title from "run eagle with full cudagraph support" to "[v1][spec decode] Run eagle with full cudagraph support" Jul 25, 2025
@mergify bot commented Aug 1, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @zixi-qi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 1, 2025
@fhl2000 (Contributor) commented Aug 10, 2025

Thanks for the great work in addressing the output discrepancies between full cudagraph and piecewise cudagraph. May I ask how the numerical difference in rms_norm was addressed compared to the previous #20190? I couldn't see how this was done after reading through the code. It would be very helpful if you could point it out, as I encountered a similar issue in the CI tests of #20059, where some output discrepancies were identified between full cudagraph and piecewise cudagraph.

@zixi-qi (Collaborator, Author) commented Aug 10, 2025

(quoting @fhl2000's question above)

Hi @fhl2000, really impressed by your great work in #20059!

For the accuracy issue, what I learned from @zou3519 was that when using full cudagraph mode, inductor would pick an rms_norm kernel that has different numerics than what's in the piecewise version (pytorch/pytorch#158699). You can confirm whether this is the same issue by turning off inductor ("use_inductor": false) and/or turning off compilation for rms_norm ("custom_ops": ["+rms_norm"]) in the compilation config.
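For example, the two knobs being suggested can be expressed as compilation configs along the lines of the sketch below (key names as used in the comment above and in vLLM's CompilationConfig); either dict can be passed as JSON via --compilation_config, as in the commands in the test plan.

# Variant 1: keep full cudagraph but skip inductor codegen, so the eager rms_norm kernel is used.
no_inductor = {"full_cuda_graph": True, "use_inductor": False}

# Variant 2: keep inductor but pin the custom rms_norm op so inductor does not swap in its own fused kernel.
keep_custom_rms_norm = {"full_cuda_graph": True, "custom_ops": ["+rms_norm"]}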

@zixi-qi zixi-qi force-pushed the eagle-full-cudagraph branch 2 times, most recently from eea737f to fa49837 on August 11, 2025 18:28
@mergify mergify bot removed the needs-rebase label Aug 11, 2025
@WoosukKwon (Collaborator) left a comment

Thanks for the PR and sorry for the delayed review.
LGTM overall. Left some questions.

Comment on lines 334 to 349 (Collaborator):

Do we really need this? This looks redundant as the loop updates attn_metadata in place.

@zixi-qi (Collaborator, Author) replied:

Yeah, the loop updates the eager-mode attn_metadata in place; however, for the cudagraph case we need to update attn_metadata_cudagraph with buffered tensors.

Directly copying over the eager-mode attn_metadata is the easiest way to do this, but let me also see if I can refactor it to be cleaner.
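To make the buffered-tensor point concrete, a minimal standalone sketch (hypothetical, trimmed-down metadata fields; the real attention metadata has many more): the cudagraph is captured against persistent buffers, so the eager-mode values have to be copied into those buffers in place rather than rebinding the tensors.

from dataclasses import dataclass
import torch

@dataclass
class DraftAttnMeta:  # hypothetical stand-in for the real attention metadata
    seq_lens: torch.Tensor
    slot_mapping: torch.Tensor

def fill_cudagraph_buffers(eager: DraftAttnMeta, persistent: DraftAttnMeta) -> None:
    # The captured graph reads from `persistent`'s storage, so copy values in
    # place; assigning new tensors would leave the graph pointing at stale data.
    num_reqs = eager.seq_lens.shape[0]
    persistent.seq_lens[:num_reqs].copy_(eager.seq_lens)
    num_tokens = eager.slot_mapping.shape[0]
    persistent.slot_mapping[:num_tokens].copy_(eager.slot_mapping)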

@zixi-qi (Collaborator, Author) replied:

Refactored and removed this code block, thanks for the suggestion!

@zixi-qi zixi-qi force-pushed the eagle-full-cudagraph branch from aa14306 to 935ee07 on August 12, 2025 21:28
@mergify bot commented Aug 13, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @zixi-qi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@WoosukKwon (Collaborator) left a comment

LGTM! Thanks for the refactoring!

@WoosukKwon added the ready label (ONLY add when PR is ready to merge/full CI is needed) Aug 13, 2025
@WoosukKwon (Collaborator) commented

@zixi-qi Can you please rebase the PR with main? While I think the failing tests are not relevant to this PR, it'd be nice to double-check.

@zixi-qi zixi-qi force-pushed the eagle-full-cudagraph branch from 4299dfd to 51939f0 on August 13, 2025 22:14
@zixi-qi (Collaborator, Author) commented Aug 13, 2025

(quoting @WoosukKwon's request above)

Done! BTW, I also rebased #22691, which is a small fix and shouldn't trigger any test failures on its own. Maybe we can also use it to cross-validate whether any test failures are due to this PR or not.

Signed-off-by: qizixi <[email protected]>
@zixi-qi (Collaborator, Author) commented Aug 14, 2025

(quoting @WoosukKwon's request above)

It seems the test_eagle_correctness test is broken on trunk, and adding another dimension to it won't help, so I moved the full cudagraph validation into a separate test.

Signed-off-by: qizixi <[email protected]>
@popsiclexu commented

@zixi-qi Hi, thanks for the great work. I have a quick question: have you compared the acceptance rates between full CUDA graph mode and eager mode? If so, do you observe any differences in acceptance behavior between them?

@mergify bot commented Aug 20, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @zixi-qi.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@fhl2000 (Contributor) commented Aug 26, 2025

@zixi-qi Sorry for the delay. Previously, the cudagraph support of the eagle drafter was broken by #20059. Now #23679 should address that. Can you try porting most of the current code on top of #23679 to support full cudagraph mode? I would be very appreciative of that!

@zixi-qi (Collaborator, Author) commented Aug 26, 2025

(quoting @fhl2000's request above)

Awesome! Let me rebase and try!

@geaned commented Sep 10, 2025

@zixi-qi Hi, thank you for the great work! What's the state of this PR? Or is the proposed feature going to be implemented in #23679? Would be awesome and very useful if CUDA graph support was finally added!

@zixi-qi (Collaborator, Author) commented Sep 10, 2025

(quoting @geaned's question above)

Thanks for the follow-up! CUDA graph itself should be supported in #23679; this change adds full graph support on top of that, but it still needs to be rebased.

@zixi-qi (Collaborator, Author) commented Sep 11, 2025

After #23679, this PR is no longer relevant, and draft-model full cudagraph needs to be supported based on the new design. Hence, closing this one to reduce confusion.

@zixi-qi zixi-qi closed this Sep 11, 2025