
[Feature] ViT Full CUDA Graph #35963

Merged
Isotr0py merged 36 commits into vllm-project:main from CentML:bmu/vit-full-cudagraph-with-dp-fi on Mar 23, 2026

Conversation

@b-mu (Contributor) commented Mar 4, 2026

Purpose

Add full CUDA graph for the ViT to reduce kernel launch overheads.

Features:

  • Budget-based graphs with a maximum batch size:
    • Capture CUDA graphs at configurable token budgets (e.g., [256, 512, 1024, 2048, 4096]).
    • Pad sequence metadata (e.g. cu_seqlen) so that the same budget-based graph can be reused for varying numbers
      of images during replay.
  • Greedy bin-packing:
    • Sort images in a batch in ascending order to reduce the number of graphs (see the sketch after this list).
  • Data-parallel (DP) support:
    • When mm_encoder_tp_mode=data, each TP rank runs the ViT independently via data parallelism.
  • FlashInfer cuDNN attention support:
    • Override FlashInfer buckets in the CUDA graph path.
  • Model-agnostic protocol:
    • SupportsEncoderCudaGraph protocol in interfaces.py — models opt in by implementing 9 protocol methods for input
      handling, metadata computation, and forward dispatch.
    • EncoderCudaGraphManager is fully model-agnostic; all model-specific logic (grid config, dummy inputs, embedding
      computation) lives in the model class.
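
To make the budget selection, packing, and metadata padding concrete, here is a minimal sketch. It is illustrative only, not the PR's code: the function names, the padding convention (one trailing pad segment absorbing leftover tokens, then repeated boundaries for a fixed-length cu_seqlens), and the eager fallback for oversized inputs are all assumptions.

from bisect import bisect_left

# Illustrative values matching the flags below; not necessarily the defaults.
TOKEN_BUDGETS = [256, 512, 1024, 2048, 4096]  # ascending
MAX_IMAGES_PER_BATCH = 8


def pack_images(image_token_counts: list[int]) -> list[list[int]]:
    """Greedily pack per-image token counts into budget-sized bins."""
    bins: list[list[int]] = []
    cur: list[int] = []
    cur_tokens = 0
    for n in sorted(image_token_counts):  # ascending order, as in the PR
        if cur and (cur_tokens + n > TOKEN_BUDGETS[-1]
                    or len(cur) == MAX_IMAGES_PER_BATCH):
            bins.append(cur)
            cur, cur_tokens = [], 0
        cur.append(n)
        cur_tokens += n
    if cur:
        bins.append(cur)
    return bins


def padded_cu_seqlens(bin_tokens: list[int]) -> tuple[int, list[int]]:
    """Pick the smallest captured budget that fits the bin and pad
    cu_seqlens to a fixed length so one graph serves any image mix."""
    total = sum(bin_tokens)
    assert total <= TOKEN_BUDGETS[-1]  # larger inputs would fall back to eager
    budget = TOKEN_BUDGETS[bisect_left(TOKEN_BUDGETS, total)]
    cu = [0]
    for n in bin_tokens:
        cu.append(cu[-1] + n)
    if total < budget:
        cu.append(budget)  # one pad segment absorbs the leftover tokens
    cu += [cu[-1]] * (MAX_IMAGES_PER_BATCH + 2 - len(cu))  # zero-length pads
    return budget, cu


# Example: four images totalling 1284 tokens replay the 2048-token graph.
for b in pack_images([300, 120, 800, 64]):
    print(padded_cu_seqlens(b))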

New config flags (via --compilation-config):

  • cudagraph_mm_encoder: true — enable encoder CUDA graph
  • encoder_cudagraph_token_budgets: [...] — list of token budget sizes to capture
  • encoder_cudagraph_max_images_per_batch: N — max images per graph replay
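
For example, a serve invocation with the feature enabled might look like this (the model choice and budget values here are illustrative; the flags are the ones above):

vllm serve Qwen/Qwen3-VL-30B-A3B-Instruct \
  --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [256, 512, 1024, 2048, 4096], "encoder_cudagraph_max_images_per_batch": 8}'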

Files changed:

  • vllm/config/compilation.py — new config flags
  • vllm/model_executor/models/interfaces.py — SupportsEncoderCudaGraph protocol and supports_encoder_cudagraph()
    type guard
  • vllm/model_executor/models/qwen3_vl.py — implement SupportsEncoderCudaGraph on Qwen3VLForConditionalGeneration
  • vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py — EncoderCudaGraphConfig, EncoderCudaGraphCaptureInputs,
    EncoderCudaGraphReplayBuffers dataclasses
  • vllm/v1/worker/gpu/mm/encoder_cudagraph.py — EncoderCudaGraphManager (capture, replay, packing, DP)
  • vllm/v1/worker/gpu_model_runner.py — integration into V1 model runner
  • tests/v1/cudagraph/test_encoder_cudagraph.py — unit and GPU tests

cc @maxyanghu @wangshangsam @Anerudhan

Test Plan

Unit Tests:

pytest tests/v1/cudagraph/test_encoder_cudagraph.py -v

End-to-End Tests:

  • Single GPU: Qwen3-VL-30B-A3B-Instruct, VisionArena-Chat dataset, 3000 prompts + 300 warmup
vllm bench mm-processor \
  --model Qwen/Qwen3-VL-30B-A3B-Instruct \
  --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat \
  --num-prompts 3000 --num-warmups 300 \
  --max-model-len 32768 --seed 42 \
  --mm-encoder-attn-backend FLASHINFER \
  --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'
  • Multi GPU: Qwen3-VL-32B-Instruct, 4×GB200 TP=4 + ViT DP=4, random-mm dataset (20 imgs/req, 336×336), 1000 prompts + 200 warmup
vllm bench mm-processor \
  --model Qwen/Qwen3-VL-32B-Instruct \
  --dataset-name random-mm \
  --random-mm-base-items-per-request 20 \
  --random-mm-num-mm-items-range-ratio 0.0 \
  --random-mm-bucket-config '{"(336,336,1)": 1.0}' \
  --num-prompts 1000 --num-warmups 200 \
  --max-model-len 8192 --seed 42 \
  --mm-encoder-attn-backend FLASHINFER \
  --tensor-parallel-size 4 --mm-encoder-tp-mode data \
  --compilation-config '{"cudagraph_mm_encoder": true, "encoder_cudagraph_token_budgets": [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4864], "encoder_cudagraph_max_images_per_batch": 8}'

Test Result

Single GPU (Qwen3-VL-30B, 1×GB200, VisionArena-Chat, 3000 prompts):

Backend      Mean                       P99
FLASH_ATTN   +11.8% (5.13 → 4.52 ms)    +31.6% (9.16 → 6.26 ms)
FLASHINFER   +19.6% (5.42 → 4.36 ms)    +40.3% (10.87 → 6.49 ms)

Multi GPU (Qwen3-VL-32B, 4×GB200 TP=4 DP=4, random-mm 20img/req, 1000 prompts):

Backend      Mean                         P99
FLASH_ATTN   +18.4% (28.39 → 23.16 ms)    +14.0% (238.78 → 205.28 ms)
FLASHINFER   +44.4% (23.24 → 12.91 ms)    +84.9% (172.41 → 26.05 ms)


@github-actions Bot commented Mar 4, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify Bot added the multi-modality (Related to multi-modality, #4194), qwen (Related to Qwen models), nvidia, and v1 labels on Mar 4, 2026
@gemini-code-assist Bot left a comment

Code Review

This pull request introduces full CUDA graph support for the Vision Transformer (ViT) encoder, aiming to reduce kernel launch overhead and improve performance, particularly in multi-GPU scenarios.

The implementation is well-structured, featuring a budget-based graph capture system with greedy bin-packing for efficient batching of images. Key additions include new configuration flags for enabling and tuning the feature, data-parallel sharding utilities for multi-GPU vision processing, and a new EncoderCudaGraphManager that encapsulates the CUDA graph lifecycle management. The integration into the existing model runner is clean, and the necessary modifications to the Qwen3-VL model are minimal and well-justified.

Furthermore, a comprehensive new test suite has been added to validate the functionality of the encoder CUDA graph manager, covering various scenarios including capture, replay, fallbacks, and data parallelism. Overall, this is a high-quality contribution that brings a significant performance enhancement.

Note: Security Review did not run due to the size of the PR.

@mergify Bot commented Mar 4, 2026

Hi @b-mu, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

1 similar comment (mergify Bot, Mar 4, 2026)

@Isotr0py (Member) left a comment

At first glance, my concern is the generality of the current encoder CG manager. I feel the current implementation is too qwen3vl-specific (MRoPE + ViT RoPE), and it would be difficult to extend this CG support to other ViTs.

Comment thread vllm/v1/worker/gpu_model_runner.py Outdated
Comment on lines +2504 to +2506
and modality == "image"
and "pixel_values" in mm_kwargs_group
and "image_grid_thw" in mm_kwargs_group
@Isotr0py (Member):

Hmmm, I feel this is too model-specific, and it's difficult to use for other models with different mm_kwargs naming.

Can we invoke the encoder_cudagraph_manager with mm_kwargs_group directly?

@b-mu (Contributor, Author) replied:

To address the general concern about the encoder cudagraph manager being too model-specific, I'm thinking about having a class SupportsEncoderCudaGraph(Protocol) in vllm/model_executor/models/interfaces.py. To use encoder cudagraph support, a model would implement a list of methods telling the encoder cudagraph manager how to extract inputs, sequence metadata, etc. The encoder cudagraph manager itself would be model-agnostic. The specific concern here in gpu_model_runner.py can use mm_kwargs_group directly. What do you think?

@Isotr0py (Member) replied:

Sounds great.

@b-mu (Contributor, Author) replied on Mar 10, 2026:

We introduced SupportsEncoderCudaGraph protocol in interfaces.py with 9 protocol methods. The manager is now model-agnostic — all Qwen3-VL-specific logic lives in qwen3_vl.py implementing the protocol.

The specific concern here in gpu_model_runner.py now uses mm_kwargs_batch directly and uses self.encoder_cudagraph_manager.supports_modality(modality) instead of checking for specific keys.
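
For illustration, a minimal sketch of the protocol shape being described. Only prepare_encoder_cudagraph_capture_inputs and supports_encoder_cudagraph() are named in this PR; the other members and all signatures here are hypothetical stand-ins, not the merged API.

from typing import Any, Protocol, runtime_checkable

import torch


@runtime_checkable
class SupportsEncoderCudaGraph(Protocol):
    """Hypothetical sketch; the real protocol defines 9 methods for input
    handling, metadata computation, and forward dispatch."""

    def prepare_encoder_cudagraph_capture_inputs(
        self, token_budget: int, max_images: int
    ) -> dict[str, torch.Tensor]:
        """Build dummy inputs that produce exactly token_budget tokens."""
        ...

    def get_encoder_seq_metadata(self, mm_kwargs: dict[str, Any]) -> torch.Tensor:
        """Extract per-image sequence metadata (e.g. cu_seqlens)."""
        ...

    def encoder_forward(self, **kwargs: Any) -> torch.Tensor:
        """Run the vision encoder on (possibly padded) inputs."""
        ...


def supports_encoder_cudagraph(model: object) -> bool:
    """Type guard, mirroring the one added in interfaces.py."""
    return isinstance(model, SupportsEncoderCudaGraph)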

Comment on lines +177 to +188
# Generate dummy grid config for capture only
# (not used for runtime batching). This is just one arbitrary
# example configuration that produces token_budget tokens.
# At runtime, actual images will be packed in any
# combination that fits the budget.
dummy_grid_config = self._generate_grid_config_for_budget(
    token_budget, self.max_batch_size
)

dummy_pixel_values, dummy_grid_thw = self._prepare_dummy_inputs(
    dummy_grid_config
)
@Isotr0py (Member) commented Mar 4, 2026:

I feel this is something that belongs in DummyInputsBuilder; otherwise the CG manager's dummy data creation could become quite complicated if we want to support other models.

@b-mu (Contributor, Author) replied:

Dummy input generation has been moved out of the manager and into the protocol method prepare_encoder_cudagraph_capture_inputs(). Each model implements its own dummy input logic (e.g., in qwen3_vl.py).
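
As a rough sketch of what a per-model implementation might do (hypothetical: the feature dimension, the grid-factoring heuristic, and the return shapes are placeholders; Qwen3-VL's real patch/merge logic is more involved):

import torch


def prepare_encoder_cudagraph_capture_inputs(
    token_budget: int, patch_feature_dim: int = 1176
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build one arbitrary (t, h, w) grid whose product equals
    token_budget; at runtime any image mix fitting the budget is packed."""
    h = int(token_budget**0.5)
    while token_budget % h:  # find a factor pair h * w == token_budget
        h -= 1
    w = token_budget // h
    grid_thw = torch.tensor([[1, h, w]], dtype=torch.long)
    pixel_values = torch.randn(token_budget, patch_feature_dim)
    return pixel_values, grid_thw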

@mergify Bot commented Mar 5, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @b-mu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify Bot added the needs-rebase label on Mar 5, 2026
@b-mu force-pushed the bmu/vit-full-cudagraph-with-dp-fi branch from 52c4ab7 to 9521d31 on March 10, 2026 00:26
@mergify Bot removed the needs-rebase label on Mar 10, 2026
@mergify Bot commented Mar 12, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @b-mu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@tjtanaa (Collaborator) commented Mar 24, 2026

@b-mu @Isotr0py can we have follow-up benchmarking:

  1. no DP VIT + cudagraph vs no DP VIT + no cudagraph
  2. DP VIT + cudagraph vs DP VIT + no cudagraph
  3. With different workloads, to know whether it always provides a performance boost.

Past attempts suggest that CUDAGraph might not be beneficial for all workloads
#33232 (comment)

However, this PR specifically targets scenarios with fewer images or when images are distributed via ViT DP. In these cases, the computational load per rank is smaller, and the execution time is dominated by "bubbles" caused by kernel launch overhead rather than the operator execution itself. The performance gain from using CUDA Graphs to eliminate these bubbles outweighs the slight regression introduced by torch.compile.

Regarding the root cause of torch.compile regression: We do not fully understand why torch.compile negatively impacts ViT operator performance at this stage, and investigating this is outside the scope of this PR. We would greatly appreciate any insights or contributions from the community to help identify and resolve this underlying torch.compile issue.

b-mu added a commit to CentML/vllm that referenced this pull request Mar 25, 2026
Add documentation for the encoder CUDA graph feature (PR vllm-project#35963),
covering budget-based capture/replay, greedy bin-packing, data-parallel
support, SupportsEncoderCudaGraph protocol, configuration, and usage.

Signed-off-by: Baorun Mu <bmu@nvidia.com>
@wangshangsam added the performance (Performance-related issues) label on Mar 25, 2026
@wangshangsam (Collaborator) commented Mar 25, 2026

@tjtanaa

  1. no DP VIT + cudagraph vs no DP VIT + no cudagraph
  2. DP VIT + cudagraph vs DP VIT + no cudagraph

I believe those are in the "Test Result" section of this PR description (i.e., single-GPU is no DP and the 4-GPU is DP+TP)? If what you are looking for is 4-GPU pure TP, I don't see how the (impact of) allreduces are relevant to this feature.

  1. With different workloads to know if it always provides performance boost.
    Past attempts suggest that CUDAGraph might not be beneficial for all workloads

I think most optimization techniques (despite what marketing or academic papers might claim) can only provide a perf boost in some scenarios, and you always have to make trade-offs. Sweeping this feature across other workloads is beyond the scope of this PR and we have other priorities at the moment, but you are very welcome to give it a try on the workloads you care about, and let us know if something breaks and we will fix it.

@github-project-automation Bot moved this to Done in Qwen3.5 on Mar 25, 2026
RhizoNymph pushed a commit to RhizoNymph/vllm that referenced this pull request Mar 26, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>
@tjtanaa (Collaborator) commented Mar 27, 2026

(Quoting @wangshangsam's reply above in full.)

Thanks for the feedback @wangshangsam. My intention is to understand whether we should always turn this feature on by default, because past PRs that attempted to enable CUDA Graph for the ViT reported speedups only for specific use cases, not others.

What are your thoughts on this feature that you and your team have integrated? Did you manage to test it on a variety of workloads?

It is fine even if there is no conclusion, since this feature is still experimental. :)
Thanks for this amazing work ~

I will also try to benchmark this feature when I have time.

nithinvc pushed a commit to nithinvc/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>

Signed-off-by: Nithin Chalapathi <nithin.ch10@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
Signed-off-by: Baorun Mu <bmu@nvidia.com>

Labels

  • multi-modality: Related to multi-modality (#4194)
  • nvidia
  • performance: Performance-related issues
  • qwen: Related to Qwen models
  • ready: ONLY add when PR is ready to merge/full CI is needed
  • v1

Projects

Status: Done


5 participants