
[Model Runner V2] Enable piecewise & full CUDA graphs for pipeline parallelism#35162

Merged
WoosukKwon merged 11 commits into vllm-project:main from ZhanqiuHu:feature/pp-piecewise-cudagraph
Mar 22, 2026

Conversation


@ZhanqiuHu ZhanqiuHu commented Feb 24, 2026

Summary

Add piecewise CUDA graph capture/replay support for PP in V2 model runner.

Related: #33960

model_runner.py:

  • Enable PP cudagraph mode handling
  • Add IntermediateTensors buffer during graph replay
  • Copy received tensors into the buffer at runtime

cudagraph_utils.py:

  • Add intermediate_tensors through capture pipeline
  • Handle IntermediateTensors output on non-last PP ranks
  • Fix num_reqs divisibility for uniform query length backends
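The persistent-buffer idea behind the model_runner.py changes can be sketched in miniature. This is illustrative pseudocode, not vLLM's actual implementation: plain Python lists stand in for preallocated GPU tensors, and `copy_from` plays the role of `torch.Tensor.copy_`. CUDA graph replay re-executes recorded kernels against fixed memory addresses, so tensors received from the previous PP rank must be copied into a buffer whose address never changes:

```python
# Illustrative sketch only (not vLLM code): why received intermediate tensors
# are copied into a persistent buffer rather than swapped in by reference.
class PersistentIntermediateBuffer:
    """Stands in for the IntermediateTensors buffer allocated at capture time."""

    def __init__(self, size: int):
        # Allocated once, before graph capture; its identity ("address")
        # must stay stable across every replay.
        self.hidden_states = [0.0] * size

    def copy_from(self, received: list) -> None:
        # In-place slice assignment updates contents without rebinding,
        # analogous to tensor.copy_(received) on a captured GPU buffer.
        assert len(received) == len(self.hidden_states)
        self.hidden_states[:] = received

buf = PersistentIntermediateBuffer(4)
captured_address = id(buf.hidden_states)  # the "address" the graph was captured with
buf.copy_from([1.0, 2.0, 3.0, 4.0])       # runtime copy-in from the previous PP rank
assert id(buf.hidden_states) == captured_address  # replay still sees the same buffer
```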

Purpose

The V2 model runner did not support CUDA graph capture with pipeline parallelism and fell back to eager mode. This PR adds piecewise CUDA graph capture and replay for PP.

Question

With CUDA graph capture, I ran into AssertionError: TRTLLM decode requires uniform query lengths per request (flashinfer.py:1109) when num_tokens % num_reqs != 0. I added the workaround below to ensure divisibility, but I'm not sure this is the right approach.

if num_reqs > 0 and num_tokens > num_reqs and num_tokens % num_reqs != 0:
    # cdiv is ceiling division (a vLLM helper).
    tokens_per_req = cdiv(num_tokens, num_reqs)
    num_reqs = num_tokens // tokens_per_req
    if num_tokens % num_reqs != 0:
        num_reqs = 1  # no even split found; fall back to a single request
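A standalone, runnable version of the workaround above, for reference. `cdiv` (ceiling division) is reimplemented here for illustration, and `adjust_num_reqs` is a hypothetical wrapper name, not a function in vLLM:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division; vLLM provides an equivalent helper.
    return -(-a // b)

def adjust_num_reqs(num_tokens: int, num_reqs: int) -> int:
    # Shrink num_reqs so num_tokens divides evenly across requests,
    # satisfying backends that require uniform query lengths per request
    # (e.g. the TRTLLM decode path in flashinfer).
    if num_reqs > 0 and num_tokens > num_reqs and num_tokens % num_reqs != 0:
        tokens_per_req = cdiv(num_tokens, num_reqs)
        num_reqs = num_tokens // tokens_per_req
        if num_tokens % num_reqs != 0:
            num_reqs = 1  # no even split exists; fall back to one request
    return num_reqs

assert adjust_num_reqs(10, 3) == 2  # 10 tokens -> 2 reqs x 5 tokens each
assert adjust_num_reqs(7, 3) == 1   # no even split -> single request
```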

Test plan

export MODEL=Qwen/Qwen3-30B-A3B-Thinking-2507-FP8

# 1. Serve with V2 + PP=2 (piecewise CG, the default)
VLLM_USE_V2_MODEL_RUNNER=1 vllm serve $MODEL -pp 2 --enable-expert-parallel --max-num-seqs 128 --no-enable-prefix-caching

# 2. Throughput benchmark
vllm bench serve --model $MODEL --dataset-name random \
  --random-input-len 2 --random-output-len 512 --num-prompts 128 --num-warmups 16

# 3. Accuracy eval (GSM8K 5-shot)
lm_eval --model local-completions \
  --model_args "model=$MODEL,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=1024" \
  --tasks gsm8k --num_fewshot 5

Test results

Config: Qwen3-30B-A3B-Thinking-2507-FP8, PP=2, 2×B200, --max-num-seqs 128

Performance (128 prompts, input=2, output=512)

| Model Runner | CUDA Graph | Req/s | Output tok/s | TTFT (ms) | TPOT (ms) |
|---|---|---|---|---|---|
| V2 | Eager | 13.89 | 7,111 | 231 | 17.5 |
| V2 | Piecewise | 23.07 | 11,810 | 167 | 10.4 |
| V1 (baseline) | Piecewise | 23.23 | 11,896 | 165 | 10.3 |

Accuracy (GSM8K, 5-shot)

| Model Runner | CUDA Graph | strict-match | flexible-extract |
|---|---|---|---|
| V2 | Eager | 0.7824 | 0.6732 |
| V2 | Piecewise | 0.7862 | 0.6778 |
| V1 (baseline) | Piecewise | 0.7824 | 0.6717 |



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables piecewise CUDA graph capture for pipeline parallelism in the V2 model runner, which was previously unsupported and forced eager mode execution. The changes are well-structured and correctly handle the complexities of this feature. Key changes include downgrading CUDA graph modes for compatibility with pipeline parallelism, introducing a persistent buffer for intermediate tensors to ensure stable memory addresses for graph replay, and updating the graph capture utilities to handle pipeline parallelism constructs. I have one suggestion to improve robustness by adding an assertion to verify the consistency of intermediate tensors between pipeline stages.


@yewentao256 yewentao256 left a comment


Thanks for the work!


@yewentao256 yewentao256 left a comment


Could you also try #34903 this PR?
Not sure if it is the same issue for enabling full cuda graph


ZhanqiuHu commented Feb 25, 2026

Could you also try #34903 this PR? Not sure if it is the same issue for enabling full cuda graph

Cool! will take a look 👍


mergify bot commented Feb 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 26, 2026
@ZhanqiuHu

Could you also try #34903 this PR? Not sure if it is the same issue for enabling full cuda graph

For now, the only error I hit was the uniform query lengths assertion with TRTLLM decode; I don't think it involves the illegal memory access from misaligned attention metadata at the moment.

I will keep that issue in mind, and I think I might run into it later with full cuda graph support.

@ZhanqiuHu ZhanqiuHu force-pushed the feature/pp-piecewise-cudagraph branch from 0602419 to 17cc226 Compare March 2, 2026 18:31
@ZhanqiuHu ZhanqiuHu requested a review from njhill as a code owner March 2, 2026 18:31
@mergify mergify bot removed the needs-rebase label Mar 2, 2026
Add piecewise CUDA graph capture/replay support for PP in V2 model runner.

model_runner.py:
- Replace eager-only PP guard with PP-aware cudagraph mode handling
- Create persistent IntermediateTensors buffer during capture
- Copy received tensors into the buffer at runtime for address stability

cudagraph_utils.py:
- Thread intermediate_tensors through the capture pipeline
- Handle IntermediateTensors output on non-last PP ranks
- Fix num_reqs divisibility for uniform query length backends
- Assert FULL cudagraph mode is not used with PP

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
@ZhanqiuHu ZhanqiuHu requested a review from yewentao256 March 3, 2026 19:30
@ZhanqiuHu

Friendly ping @WoosukKwon @njhill: This PR has been rebased and is ready for review. Would appreciate any feedback when you get a chance. Thanks!


mergify bot commented Mar 10, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @ZhanqiuHu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@mergify mergify bot removed the needs-rebase label Mar 22, 2026
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 22, 2026

@WoosukKwon WoosukKwon left a comment


LGTM! Thanks for the PR and sorry for the late review.

I've edited the PR for faster merge. Will follow up with the FULL graph support.

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 22, 2026
@WoosukKwon WoosukKwon enabled auto-merge (squash) March 22, 2026 17:25
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
@WoosukKwon

Actually, I ended up merging with #37821 for easier testing.

@WoosukKwon WoosukKwon changed the title [Model Runner V2] Enable piecewise CUDA graphs for pipeline parallelism [Model Runner V2] Enable piecewise & full CUDA graphs for pipeline parallelism Mar 22, 2026

ZhanqiuHu commented Mar 22, 2026

LGTM! Thanks for the PR and sorry for the late review.

I've edited the PR for faster merge. Will follow up with the FULL graph support.

That sounds great! Thank you!

@WoosukKwon WoosukKwon merged commit 63f49b8 into vllm-project:main Mar 22, 2026
55 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 22, 2026
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Mar 23, 2026
…sm (vllm-project#35162)

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Mar 27, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026

Labels

nvidia · ready · v1

Projects

Status: Done

3 participants