
support qwen3_vl vision model dp#13724

Merged
mickqian merged 7 commits into sgl-project:main from Lzhang-hub:qwen3_vl_vision_dp
Nov 28, 2025

Conversation

@Lzhang-hub (Contributor) commented on Nov 21, 2025

Motivation

Based on PR 13126, this adds data-parallel (DP) support for the Qwen3_VL vision encoder.

Modifications

Accuracy Tests

Qwen/Qwen3-VL-32B-Instruct

- without dp
| Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  |0.6067|±  |   N/A|

- with dp
| Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  |0.6078|±  |   N/A|

Qwen/Qwen3-VL-235B-A22B-Instruct

- without dp
| Tasks  |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|--------|------:|------|-----:|--------|---|-----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  |0.6456|±  |   N/A|

- with dp
| Tasks  |Version|Filter|n-shot| Metric |   |Value|   |Stderr|
|--------|------:|------|-----:|--------|---|----:|---|------|
|mmmu_val|      0|none  |     0|mmmu_acc|↑  | 0.64|±  |   N/A|

Benchmarking and Profiling

server cmd:

    python -m sglang.launch_server --host 0.0.0.0 --port 8080 --model-path Qwen/Qwen3-VL-32B-Instruct --served-model-name test --trust-remote-code --disable-radix-cache --tp 4 --mem-fraction-static 0.85 --mm-attention-backend fa3 --attention-backend flashinfer --mm-enable-dp-encoder

bench cmd:


    python -m sglang.bench_serving \
        --dataset-name image \
        --host 0.0.0.0 \
        --port 8080 \
        --random-input-len 200 \
        --image-count 4 \
        --image-resolution 1280x720 \
        --random-output-len 1 \
        --backend vllm-chat \
        --num-prompts 50 \
        --request-rate 0.4 \
        --warmup-requests=0 \
        --model Qwen/Qwen3-VL-32B-Instruct

| model | tp | img count | img resolution | base (ms) | dp (ms) | latency change |
| --- | --- | --- | --- | --- | --- | --- |
| Qwen3-VL-32B-Instruct | 4 | 4 | 720p | 1205.37 | 1131.91 | -6.2% |
| Qwen3-VL-235B-A22B-Instruct | 4 | 4 | 720p | 1112.01 | 1075.99 | -3.4% |
| Qwen3-VL-235B-A22B-Instruct | 4 | 4 | 1280×1280 | 4803.98 | 4522.80 | -6.2% |

Checklist

@gemini-code-assist (bot)
Summary of Changes

Hello @Lzhang-hub, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Qwen3_VL vision model by implementing data parallelism capabilities. This integration of distributed computing mechanisms allows the vision encoder to process multimodal data more efficiently, leading to improved inference speeds as validated by the provided benchmark results.

Highlights

  • Data Parallelism for Qwen3_VL: This pull request introduces explicit support for data parallelism (DP) within the Qwen3_VL vision model, building upon previous work.
  • Performance Improvement: Benchmarking results indicate a notable performance enhancement, with processing time decreasing from 1205.37ms to 1131.91ms when data parallelism is enabled for image processing.
  • Distributed Utilities Integration: Distributed processing utilities, including functions for tensor model parallel rank and world size, and a sharded vision model runner, have been integrated into the Qwen3_VL model's architecture.
  • Modular DP Configuration: Key vision model components such as Qwen3_VisionMLP, Qwen3VLVisionBlock, and Qwen3VLVisionMerger are now configured to accept a use_data_parallel flag, enabling conditional application of tensor parallelism parameters.
  • Dynamic DP Activation: The Qwen3VLMoeVisionModel dynamically activates data parallelism based on the mm_enable_dp_encoder setting retrieved from global server arguments.
  • Attribute Renaming for Consistency: The output dimension attribute has been standardized from out_hidden_size to out_dim across relevant classes in qwen2_5_vl.py, qwen3_vl.py, and mm_utils.py for improved consistency.
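The sharded-execution idea behind these highlights can be sketched in miniature: each DP rank encodes a slice of the images, and the per-rank embeddings are gathered afterwards. The sketch below shows only the work-splitting step; the function name and round-robin policy are illustrative, not sglang's actual implementation (in sglang, `run_dp_sharded_mrope_vision_model` handles both the split and the gather).

```python
# Hypothetical illustration of DP work-splitting for a vision encoder.
# Only the split is shown; the real helper also encodes each shard and
# all-gathers the resulting embeddings across ranks.

def shard_for_rank(items, dp_rank, dp_size):
    """Assign items round-robin so each DP rank gets roughly equal work."""
    return [item for i, item in enumerate(items) if i % dp_size == dp_rank]

images = ["img0", "img1", "img2", "img3", "img4"]
print(shard_for_rank(images, dp_rank=0, dp_size=2))  # ['img0', 'img2', 'img4']
print(shard_for_rank(images, dp_rank=1, dp_size=2))  # ['img1', 'img3']
```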

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds data parallelism support for the Qwen3-VL vision model. The changes correctly propagate a use_data_parallel flag to disable tensor parallelism in the vision encoder components when DP is enabled. The core logic for DP is added to get_image_feature and get_video_feature methods, which now use run_dp_sharded_mrope_vision_model for sharded execution. The changes look correct and follow a standard pattern for adding DP support. I have a couple of suggestions to improve maintainability by reducing code duplication.

Comment on lines +77 to +78:

    self.tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
    self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
Severity: medium

This logic for determining tp_size and tp_rank is also used in Qwen3VLMoeVisionPatchMerger. To avoid code duplication and improve maintainability, consider extracting this into a helper function.
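A possible shape for that helper (the name `get_vision_tp_info` is hypothetical, and the two distributed getters are stubbed here so the sketch is self-contained — in sglang they come from the tensor-parallel runtime):

```python
# Sketch of the suggested helper. The two getters are stubs standing in
# for sglang's tensor-parallel utilities.

def get_tensor_model_parallel_world_size() -> int:
    return 4  # stub: pretend we launched with tp=4

def get_tensor_model_parallel_rank() -> int:
    return 2  # stub

def get_vision_tp_info(use_data_parallel: bool) -> tuple[int, int]:
    """Return (tp_size, tp_rank); in DP mode each rank is a full replica."""
    if use_data_parallel:
        return 1, 0
    return get_tensor_model_parallel_world_size(), get_tensor_model_parallel_rank()

print(get_vision_tp_info(True))   # (1, 0)
print(get_vision_tp_info(False))  # (4, 2)
```

Both `Qwen3_VisionMLP` and `Qwen3VLMoeVisionPatchMerger` could then call the one helper instead of repeating the two ternaries.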

Comment on lines +703 to +709:

    if self.use_data_parallel:
        return run_dp_sharded_mrope_vision_model(
            self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
        )
    else:
        video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
        return video_embeds
Severity: medium

The logic in this if/else block is almost identical to the one in get_image_feature. The surrounding functions get_image_feature and get_video_feature are also very similar. To improve maintainability and reduce redundancy, you could extract this logic into a private helper method that accepts the grid_thw_attr as an argument.

For example:

    def _get_media_feature(self, items: List[MultimodalDataItem], grid_thw_attr: str) -> torch.Tensor:
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        grid_thw = torch.concat([getattr(item, grid_thw_attr) for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert grid_thw.dim() == 2, grid_thw.dim()
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
            )
        
        return self.visual(pixel_values, grid_thw=grid_thw)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        return self._get_media_feature(items, "image_grid_thw")

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        return self._get_media_feature(items, "video_grid_thw")

@vincentzed (Contributor)

Is it tested for MoE?

@Lzhang-hub (Author)

> Is it tested for MoE?

Added accuracy and performance benchmarks for Qwen3-VL-235B-A22B-Instruct.

@Lzhang-hub (Author)

> @Lzhang-hub could you tag the label precisely?

The labels were added by the github-actions bot; it seems I don't have permission to remove them. If you could help remove them, I'd really appreciate it.

@ShangmingCai removed the following labels on Nov 26, 2025: documentation, quant, amd, dependencies, lora, deepseek, hicache, sgl-kernel, blackwell, npu, piecewise-cuda-graph, diffusion, model-gateway
@Lzhang-hub (Author)

Thanks to @ShangmingCai for removing the labels; I have rebased onto main.
@yuan-luo Do you think it's ready to be merged?

    else:
        # Handle empty case
        image_embeds_local = torch.empty(
            (0, vision_model.out_hidden_size),
@yuan-luo (Collaborator) commented on Nov 26, 2025
This is a bug. run_dp_sharded_mrope_vision_model is used not only by qwen2_5_vl and qwen3_vl; more and more VLM models with mrope may adopt this function to support DP, and they do not necessarily define vision_model.out_dim, so all of those models would be affected.
Change this back to vision_model.out_hidden_size here, and likewise revert out_dim back to out_hidden_size in qwen3_vl.py.
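A defensive alternative (purely illustrative — the review above asks for a plain revert instead) would be to read the dimension under either attribute name:

```python
# Illustrative fallback, not what the PR does: accept either attribute
# name so vision models that only define out_hidden_size keep working.

class LegacyVisionModel:
    out_hidden_size = 1152  # older encoders expose only this name

def output_dim(vision_model) -> int:
    """Read the encoder output width under either attribute name."""
    for attr in ("out_hidden_size", "out_dim"):
        dim = getattr(vision_model, attr, None)
        if dim is not None:
            return dim
    raise AttributeError("vision model defines neither out_hidden_size nor out_dim")

print(output_dim(LegacyVisionModel()))  # 1152
```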

@Lzhang-hub (Author)

OK

@yuan-luo (Collaborator) left a comment

LGTM.

@yuan-luo (Collaborator)

/tag-and-rerun-ci

@yhyang201 (Collaborator)

/tag-and-rerun-ci

@mickqian mickqian merged commit ea1e9f6 into sgl-project:main Nov 28, 2025
282 of 312 checks passed
harvenstar pushed a commit to harvenstar/sglang that referenced this pull request Dec 4, 2025