diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 7026777129ac..593bfceca76d 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -28,6 +28,10 @@ ) from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -47,6 +51,8 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen3 import Qwen3Model from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy +from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix from sglang.srt.utils.hf_transformers_utils import get_processor @@ -66,14 +72,21 @@ def __init__( hidden_act="silu", quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.linear_fc1 = ColumnParallelLinear( in_features, hidden_features, bias=bias, quant_config=quant_config, prefix=add_prefix("linear_fc1", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) self.linear_fc2 = RowParallelLinear( hidden_features, @@ -81,6 +94,8 @@ def __init__( bias=bias, quant_config=quant_config, prefix=add_prefix("linear_fc2", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) self.act = ACT2FN[hidden_act] @@ -133,6 +148,7 @@ def __init__( norm_layer: Optional[Callable[[int], nn.Module]] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() if norm_layer is None: @@ -149,6 +165,7 @@ def __init__( flatten_batch=True, quant_config=quant_config, prefix=add_prefix("attn", prefix), + use_data_parallel=use_data_parallel, ) self.mlp = Qwen3_VisionMLP( dim, @@ -157,6 +174,7 @@ def __init__( bias=True, quant_config=quant_config, prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, ) def forward( @@ -191,6 +209,7 @@ def __init__( use_postshuffle_norm: bool = False, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -202,12 +221,18 @@ def __init__( self.norm = norm_layer( self.hidden_size if use_postshuffle_norm else context_dim ) + self.tp_size = ( + 1 if use_data_parallel else get_tensor_model_parallel_world_size() + ) + self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() self.linear_fc1 = ColumnParallelLinear( self.hidden_size, self.hidden_size, bias=True, quant_config=quant_config, prefix=add_prefix("linear_fc1", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) self.act_fn = nn.GELU() self.linear_fc2 = RowParallelLinear( @@ -216,6 +241,8 @@ def __init__( bias=True, quant_config=quant_config, prefix=add_prefix("linear_fc2", prefix), + tp_size=self.tp_size, + tp_rank=self.tp_rank, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -238,6 +265,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -247,8 +275,12 @@ def __init__( self.spatial_merge_size = vision_config.spatial_merge_size self.spatial_merge_unit = self.spatial_merge_size**2 self.temporal_patch_size = vision_config.temporal_patch_size + self.use_data_parallel = use_data_parallel # layer indexes of which layer's output should be deep-stacked self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes + self.out_hidden_size = vision_config.out_hidden_size * ( + 1 + len(self.deepstack_visual_indexes) + ) self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config) self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size) norm_layer = partial(nn.LayerNorm, eps=norm_eps) @@ -265,6 +297,7 @@ def __init__( norm_layer=norm_layer, quant_config=quant_config, prefix=add_prefix(f"blocks.{layer_idx}", prefix), + use_data_parallel=use_data_parallel, ) for layer_idx in range(vision_config.depth) ] @@ -276,6 +309,7 @@ def __init__( spatial_merge_size=self.spatial_merge_size, quant_config=quant_config, prefix=add_prefix("merger", prefix), + use_data_parallel=use_data_parallel, ) self.deepstack_merger_list = nn.ModuleList( @@ -288,6 +322,7 @@ def __init__( norm_layer=norm_layer, quant_config=quant_config, prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix), + use_data_parallel=use_data_parallel, ) for layer_idx in range(len(self.deepstack_visual_indexes)) ] @@ -582,6 +617,7 @@ def __init__( ) -> None: super().__init__() + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder self.visual = Qwen3VLMoeVisionModel( config.vision_config, # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization. @@ -589,6 +625,7 @@ def __init__( quant_config=quant_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, ) # TODO: make it more elegant @@ -646,7 +683,12 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert image_grid_thw.dim() == 2, image_grid_thw.dim() - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d" + ) + else: + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) return image_embeds def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: @@ -657,7 +699,12 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor: video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0) assert pixel_values.dim() == 2, pixel_values.dim() assert video_grid_thw.dim() == 2, video_grid_thw.dim() - video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) + if self.use_data_parallel: + return run_dp_sharded_mrope_vision_model( + self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d" + ) + else: + video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw) return video_embeds def get_input_embeddings(self): diff --git a/test/nightly/test_encoder_dp.py b/test/nightly/test_encoder_dp.py index ed82320745e5..c3c8b4513652 100644 --- a/test/nightly/test_encoder_dp.py +++ b/test/nightly/test_encoder_dp.py @@ -19,6 +19,7 @@ MODELS = [ SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55), + SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55), SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52), ]