[VLM] Support ViT Piecewise CUDA Graph for Qwen3-VL#15320

Merged
BBuf merged 1 commit into sgl-project:main from antgroup:vit_cuda_graph_qwen3_vl
Dec 20, 2025
Conversation

@yuan-luo
Collaborator

@yuan-luo yuan-luo commented Dec 17, 2025

Motivation

This PR enables ViT Piecewise CUDA Graph for Qwen3-VL.
It builds on ViTCudaGraphRunner so that both Qwen2.5-VL and Qwen3-VL are supported.
TP>1 is supported.

Benchmark on 8xH20, Qwen3-VL-8B-Instruct, TP=4:
TTFT 1384.53 ms --> 1120.68 ms

Meanwhile, this PR fixes a bug where torch-symm-mem was not enabled for out-of-place allreduce, which yields about a 4% end-to-end improvement over NCCL. (Since custom allreduce does not yet support CUDA Graph, we must pass --disable-custom-all-reduce; torch-symm-mem recovers an extra ~4% over legacy NCCL TP. Note that the TTFT comparison in this PR is torch-symm-mem vs. torch-symm-mem; the only difference is enabling/disabling ViT Piecewise CUDA Graph.)
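The backend-selection logic described above can be sketched as follows. This is a hypothetical illustration, not SGLang's actual API: the function name `pick_allreduce_backend` and its flags are invented for clarity; only the priority order (custom allreduce when possible, torch-symm-mem as the CUDA-graph-safe fallback, plain NCCL last) reflects the PR description.

```python
def pick_allreduce_backend(custom_ar_enabled: bool,
                           symm_mem_enabled: bool,
                           capturing_cuda_graph: bool) -> str:
    """Illustrative priority order for TP allreduce backends.

    Custom allreduce is fastest but cannot be captured inside a CUDA
    graph, so graph capture forces a fallback; torch symmetric memory
    is preferred over legacy NCCL (~4% e2e in the PR's benchmark).
    """
    if custom_ar_enabled and not capturing_cuda_graph:
        return "custom_allreduce"
    if symm_mem_enabled:
        return "torch_symm_mem"
    return "nccl"
```

With --disable-custom-all-reduce and --enable-torch-symm-mem (as in the server commands below), every allreduce in this setup takes the `torch_symm_mem` path.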

The sweet spot for this feature is when each rank's compute is relatively small (e.g., TP=4, image size not large, prefill not compute-bound), so kernel launch overhead occupies a large portion of runtime. Running the ViT under a piecewise CUDA graph saves this cost.
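The intuition can be made concrete with a toy cost model (illustrative only; the numbers and the function are not measurements from this PR). Without a graph, every kernel pays a per-launch CPU overhead; a captured CUDA graph replays the whole kernel sequence with a single submission, so the launch cost is paid roughly once instead of once per kernel.

```python
def prefill_time_us(n_kernels: int, launch_us: float,
                    compute_us: float, use_graph: bool) -> float:
    """Toy cost model for one ViT forward pass.

    use_graph=False: each kernel pays its own launch overhead.
    use_graph=True:  one replay submission covers all kernels.
    """
    if use_graph:
        return launch_us + n_kernels * compute_us
    return n_kernels * (launch_us + compute_us)

# When per-kernel compute is small (small images, high TP), launch
# overhead dominates and the graph saves a large fraction of the time.
eager = prefill_time_us(1000, launch_us=5, compute_us=2, use_graph=False)
graph = prefill_time_us(1000, launch_us=5, compute_us=2, use_graph=True)
```

Here `eager` is 7000 us vs. `graph` at 2005 us; with compute-bound kernels (large `compute_us`) the relative gain shrinks, matching the "sweet spot" described above.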

Detailed Design

(Design diagram omitted; see the PR page for the ViT piecewise CUDA graph architecture figure.)

Accuracy Tests

The updated test case in test/manual/nightly/test_vlms_vit_cuda_graph.py covers accuracy.

Benchmarking and Profiling

8xH20, Qwen3-VL-8B-Instruct, TP=4:
TTFT 1384.53 ms --> 1120.68 ms

PR:
Server:

SGLANG_MM_FEATURE_CACHE_MB=4096 \
SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
SGLANG_VLM_CACHE_SIZE_MB=0 \
SGLANG_VIT_ENABLE_CUDA_GRAPH=1 \
python3 -m sglang.launch_server --host 127.0.0.1 --mem-fraction-static 0.7 --port 30000 --max-running-requests 64 --chunked-prefill-size 8192 --attention-backend fa3 --mm-attention-backend fa3 --enable-multimodal --model /home/admin/Qwen3-VL-8B-Instruct --disable-radix-cache --piecewise-cuda-graph-max-tokens 4096 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager --tp-size 4 --disable-custom-all-reduce --enable-torch-symm-mem

Client:

python3 -m sglang.bench_serving \
  --backend sglang-oai-chat \
  --dataset-name image \
  --num-prompts 256 \
  --apply-chat-template \
  --random-input-len 128 \
  --random-output-len 1 \
  --image-resolution 560x560 \
  --image-format jpeg \
  --image-count 1 \
  --image-content random \
  --random-range-ratio 0.1 \
  --port 30000 \
  --max-concurrency 32 \
  --warmup-requests 5

Benchmark Result:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 32        
Successful requests:                     256       
Benchmark duration (s):                  9.27      
Total input tokens:                      103986    
Total input text tokens:                 20530     
Total input vision tokens:               83456     
Total generated tokens:                  256       
Total generated tokens (retokenized):    256       
Request throughput (req/s):              27.61     
Input token throughput (tok/s):          11213.80  
Output token throughput (tok/s):         27.61     
Peak output token throughput (tok/s):    47.00     
Peak concurrent requests:                65        
Total token throughput (tok/s):          11241.41  
Concurrency:                             30.94     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1120.68   
Median E2E Latency (ms):                 1072.01   
---------------Time to First Token----------------
Mean TTFT (ms):                          1120.67   
Median TTFT (ms):                        1072.00   
P99 TTFT (ms):                           1836.13   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           0.00      
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P95 ITL (ms):                            0.00      
P99 ITL (ms):                            0.00      
Max ITL (ms):                            0.00      
==================================================

Baseline:
Server:

SGLANG_MM_FEATURE_CACHE_MB=4096 \
SGLANG_USE_CUDA_IPC_TRANSPORT=1 \
SGLANG_VLM_CACHE_SIZE_MB=0 \
python3 -m sglang.launch_server --host 127.0.0.1 --mem-fraction-static 0.7 --port 30000 --max-running-requests 64 --chunked-prefill-size 8192 --attention-backend fa3 --mm-attention-backend fa3 --enable-multimodal --model /home/admin/Qwen3-VL-8B-Instruct --disable-radix-cache --piecewise-cuda-graph-max-tokens 4096 --enable-piecewise-cuda-graph --piecewise-cuda-graph-compiler eager --tp-size 4 --disable-custom-all-reduce --enable-torch-symm-mem

Client:
Same as above.

Result:

============ Serving Benchmark Result ============
Backend:                                 sglang-oai-chat
Traffic request rate:                    inf       
Max request concurrency:                 32        
Successful requests:                     256       
Benchmark duration (s):                  11.33     
Total input tokens:                      104004    
Total input text tokens:                 20548     
Total input vision tokens:               83456     
Total generated tokens:                  256       
Total generated tokens (retokenized):    252       
Request throughput (req/s):              22.59     
Input token throughput (tok/s):          9178.06   
Output token throughput (tok/s):         22.59     
Peak output token throughput (tok/s):    32.00     
Peak concurrent requests:                64        
Total token throughput (tok/s):          9200.65   
Concurrency:                             31.28     
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   1384.53   
Median E2E Latency (ms):                 1351.33   
---------------Time to First Token----------------
Mean TTFT (ms):                          1361.75   
Median TTFT (ms):                        1350.20   
P99 TTFT (ms):                           1982.05   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00      
Median TPOT (ms):                        0.00      
P99 TPOT (ms):                           0.00      
---------------Inter-Token Latency----------------
Mean ITL (ms):                           0.00      
Median ITL (ms):                         0.00      
P95 ITL (ms):                            0.00      
P99 ITL (ms):                            0.00      
Max ITL (ms):                            0.00      
==================================================

@gemini-code-assist
Contributor

Summary of Changes

Hello @yuan-luo, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Vision Transformer (ViT) operations within the SGLang framework by integrating Piecewise CUDA Graph support for the Qwen3-VL model. It refines the existing ViTCudaGraphRunner to generically handle both Qwen2.5-VL's windowed attention and Qwen3-VL's deepstack components, ensuring optimized execution of the core vision processing blocks. The changes involve structural modifications to the Qwen3-VL model to enable CUDA graph execution and updates to the graph runner to manage different model architectures effectively.

Highlights

  • Qwen3-VL CUDA Graph Support: Enabled ViT Piecewise CUDA Graph for the Qwen3-VL model, leveraging the existing ViTCudaGraphRunner for optimized performance.
  • Generic ViTCudaGraphRunner Enhancements: Extended the ViTCudaGraphRunner to generically support both Qwen2.5-VL's windowed attention and Qwen3-VL's deepstack architecture, making it more versatile for different model structures.
  • Qwen3-VL Integration Details: Implemented a new forward_with_cuda_graph method in Qwen3VisionModel to utilize the CUDA graph runner for efficient vision block processing, activated via an environment variable.
  • Test Suite Expansion and Adjustment: Added Qwen3-VL-8B-Instruct to the VLM CUDA graph test suite and adjusted the piecewise-cuda-graph-max-tokens parameter in the test script for compatibility and performance tuning.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request enables ViT Piecewise CUDA Graph for Qwen3-VL by generalizing the ViTCudaGraphRunner. The changes look good overall, but I've identified a critical issue with missing optional type hints that would cause a TypeError, a high-severity issue with a missing tensor-parallelism check, and a few medium-severity suggestions to improve code clarity and maintainability. Addressing these points will make the implementation more robust and easier to maintain.

@yuan-luo
Collaborator Author

#15321

@yhyang201
Collaborator

yhyang201 commented Dec 17, 2025

Good job!
Could you please share insights on the performance gains and the types of workloads under which ViT PCG delivers the most benefit? I’m quite curious about this—thank you~

@yuan-luo
Collaborator Author

Could you please share insights on the performance gains and the types of workloads under which ViT PCG delivers the most benefit? I’m quite curious about this—thank you~

@yhyang201 Thanks, will do it soon.

@yuan-luo yuan-luo force-pushed the vit_cuda_graph_qwen3_vl branch from 1e2b1d3 to 9bfbd0f on December 18, 2025 02:20
@yuan-luo
Collaborator Author

yuan-luo commented Dec 18, 2025

After rebasing on main and resolving conflicts, the return value of the Qwen3-VL Qwen3VLMoeVisionModel rot_pos_emb() function was changed in #15205, so the call site is broken. Fixing in progress; marking this WIP for the moment.

[2025-12-18 11:31:16 TP0] Prefill batch, #new-seq: 1, #new-token: 78, #cached-token: 0, token usage: 0.00, #running-req: 0, #queue-req: 0, 
[2025-12-18 11:31:17 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2796, in run_scheduler_process
    scheduler.event_loop_overlap()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 1065, in event_loop_overlap
    batch_result = self.run_batch(batch)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2103, in run_batch
    batch_result = self.model_worker.forward_batch_generation(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 400, in forward_batch_generation
    logits_output, can_run_cuda_graph = self.model_runner.forward(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 2736, in forward
    output = self._forward_raw(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 2806, in _forward_raw
    ret = self.forward_extend(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 2676, in forward_extend
    return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/model_executor/piecewise_cuda_graph_runner.py", line 614, in replay
    output = self.model_runner.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_vl.py", line 930, in forward
    hidden_states = general_mm_embed_routine(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/mm_utils.py", line 1056, in general_mm_embed_routine
    input_embeds, other_info = embed_mm_inputs(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/mm_utils.py", line 931, in embed_mm_inputs
    embedding, mask = get_embedding_and_mask(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/mm_utils.py", line 836, in get_embedding_and_mask
    embedding = _get_chunked_prefill_embedding(
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/managers/mm_utils.py", line 542, in _get_chunked_prefill_embedding
    embedding_per_req = data_embedding_func(embedding_items_per_req)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_vl.py", line 788, in get_image_feature
    return self.visual(pixel_values, grid_thw=image_grid_thw)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_vl.py", line 470, in forward
    return self.forward_with_cuda_graph(x, grid_thw)
  File "/opt/conda/lib/python3.10/site-packages/sglang/srt/models/qwen3_vl.py", line 528, in forward_with_cuda_graph
    rotary_pos_emb = self.rot_pos_emb(grid_thw).to(device=x.device, dtype=x.dtype)
AttributeError: 'tuple' object has no attribute 'to'
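The traceback above comes from calling `.to(...)` on a value that is now a tuple. A minimal sketch of the kind of fix needed, assuming (hypothetically) that the new `rot_pos_emb()` returns a tuple of components while older code returned a single value: normalize the return shape at the call site before moving/casting. The helper name `normalize_rot_pos_emb` is invented for illustration and is not the actual code in this PR.

```python
def normalize_rot_pos_emb(result):
    """Accept either the legacy single-value return or the new tuple
    return of rot_pos_emb(), and always hand back a tuple so the caller
    can move/cast each component uniformly instead of calling .to() on
    a tuple (which raises AttributeError, as in the traceback above)."""
    if isinstance(result, tuple):
        return result
    return (result,)

# Caller side (sketch): apply .to(...) per component, not on the tuple.
# components = normalize_rot_pos_emb(self.rot_pos_emb(grid_thw))
# components = tuple(c.to(device=x.device, dtype=x.dtype) for c in components)
```

This keeps one call site compatible with both the pre- and post-#15205 return conventions.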

@yuan-luo yuan-luo changed the title [VLM] Support ViT Piecewise CUDA Graph for Qwen3-VL [WIP][VLM] Support ViT Piecewise CUDA Graph for Qwen3-VL Dec 18, 2025
@yuan-luo yuan-luo force-pushed the vit_cuda_graph_qwen3_vl branch from 9bfbd0f to 9d46a2f on December 18, 2025 15:14
@yuan-luo
Collaborator Author

Refactored ViTCudaGraphRunner to support new interface. Removing [WIP].

@yuan-luo yuan-luo changed the title [WIP][VLM] Support ViT Piecewise CUDA Graph for Qwen3-VL [VLM] Support ViT Piecewise CUDA Graph for Qwen3-VL Dec 18, 2025
@yuan-luo
Collaborator Author

/tag-and-rerun-ci

@yuan-luo
Collaborator Author

yuan-luo commented Dec 18, 2025

Good job! Could you please share insights on the performance gains and the types of workloads under which ViT PCG delivers the most benefit? I’m quite curious about this—thank you~

@yhyang201 Benchmark has been updated.

@JustinTong0323
Collaborator

/rerun-failed-ci


@BBuf
Collaborator

BBuf commented Dec 20, 2025

/rerun-failed-ci

Collaborator

@BBuf BBuf left a comment


LGTM

@yuan-luo
Collaborator Author

/rerun-failed-ci


@BBuf
Collaborator

BBuf commented Dec 20, 2025

@BBuf BBuf merged commit 019517a into sgl-project:main Dec 20, 2025
250 of 289 checks passed
@yuan-luo yuan-luo deleted the vit_cuda_graph_qwen3_vl branch December 22, 2025 10:52
jiaming1130 pushed a commit to zhuyijie88/sglang that referenced this pull request Dec 25, 2025
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>