
[Misc][ViT Cuda Graphs] Enable Piecewise CUDA Graphs for Qwen3-VL and Qwen2.5-VL ViT to Improve Performance#33232

Open
HirokenOvo wants to merge 35 commits into vllm-project:main from rednote-ai:feat/add_vit_piecewise_cudaGraph

Conversation

@HirokenOvo
Contributor

@HirokenOvo HirokenOvo commented Jan 28, 2026

Purpose

Based on the torch.compile mechanism for generic nn.Module from PR #23207 and PR #27741, this PR implements piecewise CUDA graph support for the Vision Transformer (ViT) encoder in Qwen2.5-VL and Qwen3-VL models. The primary goal is to eliminate kernel launch bubbles for operators other than attention, thereby reducing overhead and improving inference performance.

This optimization is particularly effective in low-concurrency scenarios where the number of images is insufficient to fully saturate the ViT's computational capacity. By reducing kernel launch overhead, it helps lower the time to first token (TTFT).

Key Features

  1. Supported Models: Enabled for Qwen2.5-VL and Qwen3-VL series. Modified Qwen2_5_VisionTransformer and Qwen3_VisionTransformer to be graph-friendly, including handling of persistent buffers for hidden states and RoPE.
  2. Parallelism Modes: Supports both ViT TP mode and DP mode for ViT CUDA graphs.

Usage

Enable ViT CUDA Graph

To enable this feature, you need to set compile_mm_encoder to True via the --compilation-config argument.

--compilation-config '{"compile_mm_encoder": true}'

Configure Capture Sizes

You can specify the image patch counts for which to capture ViT CUDA graphs using mm_encoder_cudagraph_capture_sizes.

--compilation-config '{"compile_mm_encoder": true, "mm_encoder_cudagraph_capture_sizes": [512, 1024]}'

If not specified, vLLM will automatically select a set of sizes based on the model's encoder budget.

Alternatively, you can specify max_mm_encoder_cudagraph_capture_size to generate a default list of capture sizes up to the given value:

--compilation-config '{"compile_mm_encoder": true, "max_mm_encoder_cudagraph_capture_size": 2048}'

Limitations & Notes

  • Image Only: This feature currently only supports image inference. Video inference is not supported yet.
  • torch.compile Issues: torch.compile consumes a significant amount of GPU memory for compiling the ViT layer and may introduce negative optimization, leading to a slight increase in latency for individual operators. However, this issue is orthogonal to this PR, which focuses on enabling CUDA graph capture to reduce launch overhead.

Performance

The performance was benchmarked under the following configuration:

  • Hardware: H800
  • Model: Qwen3-VL 32B
  • Parallelism: TP=4, ViT DP=4
  • Input: The total number of image patches processed per rank is calculated as: (20 images / 4 DP ranks) * 128 tokens/image * 4 merge_size = 2560.
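The per-rank patch count above can be checked directly (values taken from the benchmark configuration):

```python
# Benchmark configuration from above: 20 images balanced across 4 DP ranks,
# 128 output tokens per image, spatial merge factor of 4.
images, dp_ranks, tokens_per_image, merge_size = 20, 4, 128, 4
patches_per_rank = (images // dp_ranks) * tokens_per_image * merge_size
print(patches_per_rank)  # 2560
```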

without ViT cudagraph: (profiler trace screenshot)

with ViT cudagraph: (profiler trace screenshots)

The remaining bubbles are due to the GEMM being too small to fully hide the kernel launch overhead of the attention operator. Future work could involve implementing a full CUDA graph with FlashAttention-3 for the ViT to completely eliminate this overhead.


Test Plan

Run the following test to verify the consistency between ViT CUDA graph execution and eager mode:

pytest tests/compile/piecewise/test_qwenvl_vit_cudagraph.py

Test Result

6 passed, confirming the correctness of the piecewise CUDA graph implementation.


Essential Elements of an Effective PR Description Checklist
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify

mergify bot commented Jan 28, 2026

Documentation preview: https://vllm--33232.org.readthedocs.build/en/33232/

@mergify mergify bot added the documentation (Improvements or additions to documentation), qwen (Related to Qwen models), nvidia, and v1 labels Jan 28, 2026
@mergify

mergify bot commented Jan 28, 2026

Hi @HirokenOvo, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant performance optimization by enabling piecewise CUDA graphs for the Vision Transformer (ViT) encoder in Qwen2.5-VL and Qwen3-VL models. The changes are comprehensive, spanning model definitions, configuration, data-parallel handling, and the CUDA graph dispatching mechanism. The implementation is well-tested and demonstrates clear performance benefits, especially for low-concurrency scenarios. My primary concern is the introduction of global state to manage graph compilation options, which could affect maintainability and introduce subtle bugs in more complex execution scenarios.

Comment on lines 51 to 89
# A global flag to indicate if the current graph being compiled
# is the last one in a sequence of graphs (e.g., a sequence of blocks).
# This is a workaround to control CUDAGraph weak_ref_output behavior
# in **vit** piecewise compilation.
_is_last_graph_in_vit_sequence: bool = True


@contextmanager
def set_is_last_graph_in_vit_sequence(is_last: bool) -> Iterator[None]:
    """Context manager to indicate if the current graph being compiled
    is the last one in a sequence of graphs (e.g., a sequence of blocks).
    """
    global _is_last_graph_in_vit_sequence
    original_value = _is_last_graph_in_vit_sequence
    _is_last_graph_in_vit_sequence = is_last
    try:
        yield
    finally:
        _is_last_graph_in_vit_sequence = original_value


# A global flag to indicate if the current graph being compiled
# is the first one in a sequence of graphs (e.g., a sequence of blocks).
_is_first_graph_in_vit_sequence: bool = True


@contextmanager
def set_is_first_graph_in_vit_sequence(is_first: bool) -> Iterator[None]:
    """Context manager to indicate if the current graph being compiled
    is the first one in a sequence of graphs (e.g., a sequence of blocks).
    """
    global _is_first_graph_in_vit_sequence
    original_value = _is_first_graph_in_vit_sequence
    _is_first_graph_in_vit_sequence = is_first
    try:
        yield
    finally:
        _is_first_graph_in_vit_sequence = original_value

Contributor

high

The introduction of global flags _is_last_graph_in_vit_sequence and _is_first_graph_in_vit_sequence to control CUDA graph options is a significant design concern. While the use of context managers helps to scope their usage, relying on global state makes the code harder to reason about and debug, and it is not thread-safe, which could lead to subtle bugs if compilations were to ever run concurrently.

Consider exploring alternatives to pass this state explicitly through the call stack, for instance, by extending the ForwardContext or another context object. This would make the data flow explicit and improve the overall maintainability and robustness of the compilation backend.
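The suggestion above amounts to threading the sequence-position state through the call stack explicitly. A minimal sketch of what that could look like, with hypothetical class and field names (not vLLM's actual ForwardContext API):

```python
from dataclasses import dataclass

@dataclass
class VitCompileContext:
    """Carries graph-sequence position explicitly instead of via globals."""
    is_first_graph: bool = True
    is_last_graph: bool = True

def build_block_contexts(num_blocks: int) -> list[VitCompileContext]:
    """One context per compiled sub-module, so the backend reads the
    sequence boundaries from its argument rather than from module state."""
    return [
        VitCompileContext(is_first_graph=(i == 0),
                          is_last_graph=(i == num_blocks - 1))
        for i in range(num_blocks)
    ]

ctxs = build_block_contexts(3)
print(ctxs[0].is_first_graph, ctxs[0].is_last_graph)  # True False
print(ctxs[2].is_first_graph, ctxs[2].is_last_graph)  # False True
```

Because each context is an ordinary function argument, concurrent compilations cannot clobber each other's flags, which addresses the thread-safety concern raised here.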

Contributor

I want to second this comment - I have had issues reasoning about and seen bugs introduced from these global variables from context managers. If we can explore another way of passing this information, I think it would be much better for the code

Contributor Author

@HirokenOvo HirokenOvo Jan 29, 2026

I want to second this comment - I have had issues reasoning about and seen bugs introduced from these global variables from context managers. If we can explore another way of passing this information, I think it would be much better for the code

@Lucaskabela The original reason for using global variables was that ViT torch.compile is applied piecewise to individual sub-modules, rather than the top-level module. This means the compilation backend only sees one sub-module's graph at a time, not the full ViT graph, rendering the existing first/last graph detection logic (used for LLMs) incorrect.

In the latest commit, I have refactored this to pass the information via the forward context, eliminating the need for these global variables. cc @wangxingran222

@cursor cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 6 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with Cloud Agents, enable Autofix in the Cursor dashboard.

Comment @cursor review or bugbot run to trigger another review on this PR

@ywang96
Member

ywang96 commented Jan 28, 2026

Hey @HirokenOvo Thank you very much for contributing this PR! 🚀

This is definitely one of the ideas that we are aware of but didn't get time to experiment yet, so very much appreciated your effort here! I'm a bit busy recently so prob don't have time to review this in the next two days, but I'll definitely try to get to this PR this weekend!

cc @ProExpertProg @DarkLight1337 in case you want to take a first pass.

@tjtanaa
Collaborator

tjtanaa commented Jan 28, 2026

@HirokenOvo is the torch.compile graph enabled by default? The last time we tested the ViT torch.compile feature (before your PR) on ROCm, there were performance regressions in certain cases, such as when the DP ViT feature is enabled.

Can you also try running your feature with the DP ViT feature enabled, and check whether DP ViT + torch.compile is faster than DP ViT without torch.compile?

@HirokenOvo
Contributor Author

@HirokenOvo is the torch.compile graph enabled by default? The last time we tested the ViT torch.compile feature (before your PR) on ROCm, there were performance regressions in certain cases, such as when the DP ViT feature is enabled.

Can you also try running your feature with the DP ViT feature enabled, and check whether DP ViT + torch.compile is faster than DP ViT without torch.compile?

@tjtanaa Sorry for the late reply.

First question: No, this feature is not enabled by default. The Piecewise CUDA Graph implementation relies on torch.compile to trace the computation graph and separate the attention operators. Therefore, users must explicitly enable ViT compilation via the argument --compilation-config '{"compile_mm_encoder": true}' to activate this feature.
Once enabled, if vit_cudagraph_capture_sizes is not specified, vLLM will use a default set of sizes for capture. Since compile_mm_encoder is False by default, this feature remains inactive unless configured. If you only want to enable torch.compile for ViT without using the CUDA Graph feature from this PR, you can explicitly set the capture sizes to empty: --compilation-config '{"compile_mm_encoder": true, "vit_cudagraph_capture_sizes": []}'.

Second question: As mentioned in the Limitations & Notes section, torch.compile can indeed cause performance regression on individual operators in certain scenarios. PR #27741 also noted similar regression when introducing torch.compile for Qwen3-VL, but this is orthogonal to the issue of this PR.

However, this PR specifically targets scenarios with fewer images or when images are distributed via ViT DP. In these cases, the computational load per rank is smaller, and the execution time is dominated by "bubbles" caused by kernel launch overhead rather than the operator execution itself. The performance gain from using CUDA Graphs to eliminate these bubbles outweighs the slight regression introduced by torch.compile.

Regarding the root cause of torch.compile regression: We do not fully understand why torch.compile negatively impacts ViT operator performance at this stage, and investigating this is outside the scope of this PR. We would greatly appreciate any insights or contributions from the community to help identify and resolve this underlying torch.compile issue.

In the future, we plan to develop a Full CUDA Graph based on FA3, which will not require torch.compile to separate attention operators (and thus won't need to compile the graph), thereby avoiding these negative side effects entirely.


3. If no sizes are provided by the user, a default list of sizes is
generated up to a maximum of 5120. The default sizes are:
[512, 1024, 1536] + list(range(2048, 2048, 128)) + list(
Contributor

Let's add a comment explaining why these are the default ranges we are using (i.e image sizes are usually one of these or something to that effect)

Contributor Author

@HirokenOvo HirokenOvo Jan 29, 2026

Let's add a comment explaining why these are the default ranges we are using (i.e image sizes are usually one of these or something to that effect)

@Lucaskabela Actually, in our internal usage, we specify vit_cudagraph_capture_sizes based on specific workload requirements. For the general community default, we designed it with the following considerations:

  1. Patch Variance: Unlike the LLM part (which might use a step size of 8), image patch counts vary significantly.
  2. Small Inputs: For scenarios with very few images, the GEMM computation is too small to effectively hide kernel launch overheads. Therefore, padding to a larger starting size is acceptable.
  3. Step Size: Increasing the stride doesn't significantly add to the computation time for the ViT part. However, a larger stride reduces the number of graphs to be captured, which saves VRAM and reduces startup time.

We are not certain if this default fits all general community scenarios, so we are open to discussion if there are better configurations.
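The three design points above (coarse starting sizes for small inputs, then a fixed stride up to a cap) can be sketched as a generator function. The start, stride, and cap values here are illustrative only, not the PR's actual defaults:

```python
def default_capture_sizes(max_size: int,
                          small: tuple[int, ...] = (512, 1024, 1536),
                          stride_start: int = 2048,
                          stride: int = 256) -> list[int]:
    """Coarse sizes for few-image batches, then a fixed stride up to
    max_size. A larger stride means fewer captured graphs, which saves
    VRAM and reduces startup time at the cost of more padding."""
    sizes = [s for s in small if s <= max_size]
    sizes += list(range(stride_start, max_size + 1, stride))
    return sizes

print(default_capture_sizes(3072))
# [512, 1024, 1536, 2048, 2304, 2560, 2816, 3072]
```

The trade-off encoded here mirrors the comment: each extra capture size costs graph memory and capture time, while a sparser list costs padded compute per inference.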

@github-project-automation github-project-automation bot moved this to In review in NVIDIA Jan 28, 2026
@Lucaskabela
Contributor

Hi @HirokenOvo from the torch.compile team - this is awesome work! We were intending to investigate this enablement in the coming few months, so it is great to see it here working today!

In regards to torch.compile negatively impacting VIT operator performance, we are planning on looking into this very soon - I would not expect this to be a fundamental limitation of torch.compile, but perhaps some oversight in how one of our passes impacts multimodal encoders

@HirokenOvo
Contributor Author

@cursor review

HirokenOvo and others added 9 commits February 3, 2026 16:25
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Xingran Wang <wangxingran123456@outlook.com>
Co-authored-by: Xingran Wang <wangxingran123456@outlook.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>

Signed-off-by: Hongjian Zhang <zhanghongjian@xiaohongshu.com>
… multimodal input handling

Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
@HirokenOvo HirokenOvo force-pushed the feat/add_vit_piecewise_cudaGraph branch from b667b32 to 53814ec Compare February 3, 2026 10:27
@mergify mergify bot added multi-modality Related to multi-modality (#4194) and removed needs-rebase labels Feb 3, 2026
@HirokenOvo
Contributor Author

Overall I'm not a fan of how much complexity this adds to both the model implementation and model runner. We should try to integrate this using similar principles as model runner v2.

@DarkLight1337 Thank you for the feedback. Here's the rationale for the current design:

1. Model Runner Refactoring

I extracted MMEncoderCudagraphManager to encapsulate CUDA graph capture, dispatch, and padding logic. However, since this PR targets V1, it must integrate with V1's current logic, which prevents full parity with V2.

2. Why Model-Layer Modifications Are Necessary

The model-layer changes (persistent buffers and mm_cudagraph_manager parameter) cannot be avoided due to two architectural constraints:

  • Persistent Buffers: torch.compile is applied per-layer (PatchEmbed, VisionBlock, PatchMerger) rather than as a monolithic graph. This means operations between layers (e.g., between PatchEmbed and VisionBlock) are outside the CUDA graph capture scope, causing intermediate tensors like hidden_states and rotary_pos_emb to have non-deterministic memory addresses across invocations. Unlike the LLM where we can simply copy input x to a fixed buffer at the model runner level, here we must add persistent buffers at the model layer and manually copy intermediate results during forward. Each model architecture requires different buffers, making extraction into a common abstraction impractical.

  • DP Mode Padding: In ViT DP mode, run_dp_sharded_mrope_vision_model() performs load balancing to distribute images across ranks before the vision model runs. If padding were applied at the model runner layer, dummy images would be included in load balancing and incorrectly distributed across GPUs. By passing mm_cudagraph_manager to the model layer, padding is applied after load balancing, ensuring each rank pads only its local batch independently.
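The persistent-buffer constraint described above — every replay of a captured graph must read its inputs from the same memory addresses as at capture time — can be illustrated with numpy standing in for torch (class and method names are illustrative):

```python
import numpy as np

class PersistentBuffer:
    """Preallocated storage whose base address never changes. CUDA graph
    replay requires inputs at the addresses recorded during capture, so
    intermediates are copied into this buffer before each captured call."""
    def __init__(self, max_tokens: int, hidden: int):
        self.buf = np.zeros((max_tokens, hidden), dtype=np.float32)

    def copy_in(self, x: np.ndarray) -> np.ndarray:
        n = x.shape[0]
        np.copyto(self.buf[:n], x)  # in-place copy; storage address is stable
        return self.buf[:n]          # view into the fixed storage

pb = PersistentBuffer(max_tokens=1024, hidden=8)
a = pb.copy_in(np.ones((100, 8), dtype=np.float32))
b = pb.copy_in(np.full((100, 8), 2.0, dtype=np.float32))
# Both views alias the same underlying storage, unlike freshly allocated
# intermediates whose addresses vary between invocations.
print(a.__array_interface__['data'][0] == b.__array_interface__['data'][0])  # True
```

In the PR itself the equivalent copies happen inside the model's forward, between the independently compiled sub-modules, because that is the only place the intermediate tensors exist.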

has_lora: bool = False,
disable_full: bool = False,
num_active_loras: int = 0,
is_mm_encoder: bool = False,
Collaborator

@LucasWilkinson LucasWilkinson Feb 3, 2026

Thanks for the contribution, the performance gains are impressive!

I agree with @DarkLight1337, the complexity is too high right now, in particular inside the model runner and CudagraphDispatcher. On that note:

why not just instantiate another cudagraph_dispatcher in the gpu_model_runner

        # Cudagraph dispatcher for runtime cudagraph dispatching.
        self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

        if self.supports_mm_inputs:
              self.mm_cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
              ...

and initialize these with different keys. Then these flags are not required. We can change initialize_cudagraph_keys to accept capture_sizes instead of fetching it directly from self.compilation_config.cudagraph_capture_sizes to assist with this. We do this for eagle (it has a separate cudagraph_dispatcher instance).

Ideally we'd keep the CudagraphDispatcher as simple as possible.

Collaborator

see:

# Cudagraph dispatcher for PIECEWISE-only dispatching in eagle.
# Keys are initialized later via initialize_cudagraph_keys() called from
# gpu_model_runner._check_and_update_cudagraph_mode after
# adjust_cudagraph_sizes_for_spec_decode is called.
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)

def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
    """Initialize cudagraph dispatcher keys for eagle.
    Eagle only supports PIECEWISE cudagraphs (via mixed_mode).
    This should be called after adjust_cudagraph_sizes_for_spec_decode.
    """
    if (
        not self.speculative_config.enforce_eager
        and cudagraph_mode.mixed_mode()
        in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
    ):
        eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE
    else:
        eagle_cudagraph_mode = CUDAGraphMode.NONE
    self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)

Contributor Author

Thanks for your suggestion. Following your advice has indeed significantly reduced the modifications required in CudagraphDispatcher. Regarding gpu_model_runner, I have encapsulated the relevant logic into MMEncoderCudagraphManager to minimize changes to the runner itself.

Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
"""

-    def __init__(self, vllm_config: VllmConfig):
+    def __init__(self, vllm_config: VllmConfig, is_mm_encoder: bool = False):
Collaborator

Thanks for cleaning this up! Let's just pass max_capture_size and capture_sizes so we can avoid this and make it more extensible.

Collaborator

(will be helpful if we have separate sizes for drafters too)

Collaborator

We should pass it into initialize_cudagraph_keys instead of __init__ so that it happens after adjust_cudagraph_sizes_for_spec_decode.

Contributor Author

Thanks for the suggestions! I have refactored CudagraphDispatcher as requested.

Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
@mergify

mergify bot commented Feb 5, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @HirokenOvo.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 5, 2026
all_moe_layers: list[str] | None = None
moe_layer_index: int = 0

# mm_encoder Multi-Modal Encoder flags used by backend compiler
Member

Am I missing something? Where are these flags being read?

Contributor

Am I missing something? Where are these flags being read?

@DarkLight1337 Sorry about the confusing comment. These flags are actually consumed by the Dynamo post-processing logic in vllm/compilation/backends.py.

Since ViT uses torch.compile on submodules (unlike the top-level approach in LLMs), we end up with multiple Dynamo graphs. The backend incorrectly treats each submodule as a standalone sequence, leading to wrong gc_disable and weak_ref_output settings. This causes intermediate tensors to be prematurely garbage collected during capture.

To fix this, we introduced these flags and manually call set_is_first_graph_in_mm_encoder_sequence and set_is_last_graph_in_mm_encoder_sequence inside the ViT forward method. This explicitly hints the correct global sequence boundaries to the backend wrapper.
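In spirit, the ViT forward brackets each compiled sub-module with these hints so that only the globally first and last graphs receive the boundary treatment. A simplified, self-contained stand-in for the mechanism (the real context managers live in vllm/compilation/backends.py and carry more state):

```python
from contextlib import contextmanager

_is_first = True
_is_last = True

@contextmanager
def set_is_first_graph_in_mm_encoder_sequence(value: bool):
    global _is_first
    prev, _is_first = _is_first, value
    try:
        yield
    finally:
        _is_first = prev

@contextmanager
def set_is_last_graph_in_mm_encoder_sequence(value: bool):
    global _is_last
    prev, _is_last = _is_last, value
    try:
        yield
    finally:
        _is_last = prev

def run_blocks(blocks):
    """Each compiled block observes whether it is the global first/last
    graph in the encoder sequence, not just within its own sub-module."""
    seen = []
    for i, block in enumerate(blocks):
        with set_is_first_graph_in_mm_encoder_sequence(i == 0), \
             set_is_last_graph_in_mm_encoder_sequence(i == len(blocks) - 1):
            seen.append((_is_first, _is_last))
            block()
    return seen

flags = run_blocks([lambda: None] * 3)
print(flags)  # [(True, False), (False, False), (False, True)]
```

Without these hints, each sub-module's compiled graph would consider itself both first and last, producing the wrong gc_disable and weak_ref_output settings described above.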

@tjtanaa
Collaborator

tjtanaa commented Feb 5, 2026

@HirokenOvo @Lucaskabela @DarkLight1337 @LucasWilkinson

I agree with @DarkLight1337 that the PR introduces a lot of complexity. My concerns are as follows:

Concern 1:

Second question: As mentioned in the Limitations & Notes section, torch.compile can indeed cause performance regression on individual operators in certain scenarios. PR #27741 also noted similar regression when introducing torch.compile for Qwen3-VL, but this is orthogonal to the issue of this PR.

However, this PR specifically targets scenarios with fewer images or when images are distributed via ViT DP. In these cases, the computational load per rank is smaller, and the execution time is dominated by "bubbles" caused by kernel launch overhead rather than the operator execution itself. The performance gain from using CUDA Graphs to eliminate these bubbles outweighs the slight regression introduced by torch.compile.

Regarding the root cause of torch.compile regression: We do not fully understand why torch.compile negatively impacts ViT operator performance at this stage, and investigating this is outside the scope of this PR. We would greatly appreciate any insights or contributions from the community to help identify and resolve this underlying torch.compile issue.

In the future, we plan to develop a Full CUDA Graph based on FA3, which will not require torch.compile to separate attention operators (and thus won't need to compile the graph), thereby avoiding these negative side effects entirely.

As mentioned above, this feature is restricted to scenarios with **fewer images** or where images are **distributed via ViT DP**, while adding significant complexity that directly modifies the model definition files.

Concern 2: Lack of extensibility

2. Why Model-Layer Modifications Are Necessary

The model-layer changes (persistent buffers and mm_cudagraph_manager parameter) cannot be avoided due to two architectural constraints:

  • Persistent Buffers: torch.compile is applied per-layer (PatchEmbed, VisionBlock, PatchMerger) rather than as a monolithic graph. This means operations between layers (e.g., between PatchEmbed and VisionBlock) are outside the CUDA graph capture scope, causing intermediate tensors like hidden_states and rotary_pos_emb to have non-deterministic memory addresses across invocations. Unlike the LLM where we can simply copy input x to a fixed buffer at the model runner level, here we must add persistent buffers at the model layer and manually copy intermediate results during forward. Each model architecture requires different buffers, making extraction into a common abstraction impractical.

We would have to reimplement the same complex logic for all other models. The management of buffers in the model definition file will bloat the code. I am not very familiar with torch.compile; can we do the buffer management elsewhere, such as at the Layer class, rather than in the model definition file?

@tjtanaa
Collaborator

tjtanaa commented Feb 5, 2026

@HirokenOvo @Lucaskabela When you were planning and implementing your solution, did you find the framework restrictive? IIRC there have been quite a number of attempts to add torch.compile support to the ViT/multimodal part, and many introduce large complexity for limited usage. It makes me wonder whether we need to redesign the framework for ViT/multimodality so that torch.compile can be supported natively, without heavily modifying the model definition files and manually managing persistent buffers.

And given that it also requires changes to the model runner, maybe we can get more of your thoughts on model runner V2, to make multimodality torch.compile- and cudagraph-compatible there.

What is everyone's thoughts on this?

CC @ProExpertProg to the discussion as he is working on the torch compile features e.g. fusion passes. Maybe we can have a new perspective on this.

@HirokenOvo
Contributor Author

@tjtanaa During our development process, we indeed encountered several framework constraints, primarily in two areas:

Constraint 1: ViT DP Load Balancing at the Model Layer
Currently, the ViT DP is executed within the model layer (via run_dp_sharded_mrope_vision_model). This means the Model Runner cannot anticipate the data distribution across ranks beforehand and thus cannot handle padding in advance. Consequently, we are forced to pass the mm_cudagraph_manager (managed by the gpu model runner) down to the model layer to apply padding after load balancing but before computation begins.

Constraint 2:
Previous explorations have validated that applying torch.compile on submodules significantly outperforms top-level compilation for ViT. The primary reason is the presence of CPU/Python computations (such as positional embedding generation). Using torch.compile at the top level triggers Graph Breaks due to this logic, resulting in worse performance compared to using torch.compile on submodules (see #23207 (comment)).

To correctly chain independently compiled graph segments generated by torch.compile on submodules, we must ensure stable input addresses via buffers in forward and use context managers like set_is_first_graph_in_mm_encoder_sequence to explicitly hint the correct global sequence boundaries to the backend wrapper.

@LucasWilkinson
Collaborator

To avoid torch.compile we could maybe try going straight to FULL-cudagraphs? or is there something preventing that?

@HirokenOvo
Contributor Author

To avoid torch.compile we could maybe try going straight to FULL-cudagraphs? or is there something preventing that?

@LucasWilkinson There are some CPU operations within the forward pass that cannot be captured by CUDA Graphs.

pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    axis=0, dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
cu_seqlens = torch.from_numpy(cu_seqlens)

This prevents us from wrapping the entire ViT in a single full CUDA graph, unlike the LLM part. We still need to wrap it at the sub-modules. Therefore, we still need to manually persist memory addresses before calling each sub-module in the forward pass. While using FULL mode would strictly eliminate kernel launch overheads for performance, it does not avoid the code complexity associated with this piecewise capturing and manual buffer management.

@Lucaskabela
Contributor

Lucaskabela commented Feb 10, 2026

@HirokenOvo @Lucaskabela When you were planning and implementing your solution, did you find the framework restrictive? IIRC there have been quite a number of attempts to add torch.compile support to the ViT/multimodal part, and many introduce large complexity for limited usage. It makes me wonder whether we need to redesign the framework for ViT/multimodality so that torch.compile can be supported natively, without heavily modifying the model definition files and manually managing persistent buffers.

And given that it also requires changes to the model runner, maybe we can get more of your thoughts in the model runner V2, trying to make multimodality torch compile and cudagraph compatible.

What is everyone's thoughts on this?

CC @ProExpertProg to the discussion as he is working on the torch compile features e.g. fusion passes. Maybe we can have a new perspective on this.

For torch.compile integration on its own, I think it is quite unrestrictive, which perhaps leads to the confusion — we can apply it almost anywhere with minimal complexity (there are some support structures like set_forward_context and set_model_tag that I am going to focus on polishing and try to eliminate, but these are very minimal).

That said, I do think many of the current implementations for ViT models are not written with the utmost efficiency in mind, as a number of ops have CPU-sync and/or buffer issues. Additionally, a clear and consistent integration point (like vision_tower) would make this integration much more straightforward (similar to support_torch_compile on LLM backbones).

@wangshangsam
Collaborator

To avoid torch.compile we could maybe try going straight to FULL-cudagraphs? or is there something preventing that?

For the MLPerf v6.0 Qwen3-VL submission from NVIDIA, our team actually developed full cudagraph (and piece-wise graph and torch.compile too, though slightly differently from this PR). We are happy to upstream this feature.

cc @b-mu @maxyanghu who developed it, and @ywang96 who's helping to upstream some other coming-from-MLPerf features currently.


Labels

documentation (Improvements or additions to documentation), multi-modality (Related to multi-modality (#4194)), needs-rebase, nvidia, qwen (Related to Qwen models), v1

Projects

Status: In review

Development

Successfully merging this pull request may close these issues.

9 participants