Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions python/sglang/srt/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
)

from sglang.srt.configs.qwen3_vl import Qwen3VLConfig, Qwen3VLVisionConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
Expand All @@ -47,6 +51,8 @@
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen3 import Qwen3Model
from sglang.srt.models.utils import compute_cu_seqlens_from_grid_numpy
from sglang.srt.multimodal.mm_utils import run_dp_sharded_mrope_vision_model
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix
from sglang.srt.utils.hf_transformers_utils import get_processor

Expand All @@ -66,21 +72,30 @@ def __init__(
hidden_act="silu",
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
self.linear_fc1 = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("linear_fc1", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.linear_fc2 = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("linear_fc2", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.act = ACT2FN[hidden_act]

Expand Down Expand Up @@ -133,6 +148,7 @@ def __init__(
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
if norm_layer is None:
Expand All @@ -149,6 +165,7 @@ def __init__(
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
use_data_parallel=use_data_parallel,
)
self.mlp = Qwen3_VisionMLP(
dim,
Expand All @@ -157,6 +174,7 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel,
)

def forward(
Expand Down Expand Up @@ -191,6 +209,7 @@ def __init__(
use_postshuffle_norm: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
Expand All @@ -202,12 +221,18 @@ def __init__(
self.norm = norm_layer(
self.hidden_size if use_postshuffle_norm else context_dim
)
self.tp_size = (
1 if use_data_parallel else get_tensor_model_parallel_world_size()
)
self.tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()
self.linear_fc1 = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
prefix=add_prefix("linear_fc1", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.act_fn = nn.GELU()
self.linear_fc2 = RowParallelLinear(
Expand All @@ -216,6 +241,8 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=add_prefix("linear_fc2", prefix),
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand All @@ -238,6 +265,7 @@ def __init__(
norm_eps: float = 1e-6,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = vision_config.hidden_size
Expand All @@ -247,8 +275,12 @@ def __init__(
self.spatial_merge_size = vision_config.spatial_merge_size
self.spatial_merge_unit = self.spatial_merge_size**2
self.temporal_patch_size = vision_config.temporal_patch_size
self.use_data_parallel = use_data_parallel
# indexes of the layers whose outputs should be deep-stacked
self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes
self.out_hidden_size = vision_config.out_hidden_size * (
1 + len(self.deepstack_visual_indexes)
)
self.patch_embed = Qwen3VLVisionPatchEmbed(config=vision_config)
self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)
norm_layer = partial(nn.LayerNorm, eps=norm_eps)
Expand All @@ -265,6 +297,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=add_prefix(f"blocks.{layer_idx}", prefix),
use_data_parallel=use_data_parallel,
)
for layer_idx in range(vision_config.depth)
]
Expand All @@ -276,6 +309,7 @@ def __init__(
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=add_prefix("merger", prefix),
use_data_parallel=use_data_parallel,
)

self.deepstack_merger_list = nn.ModuleList(
Expand All @@ -288,6 +322,7 @@ def __init__(
norm_layer=norm_layer,
quant_config=quant_config,
prefix=add_prefix(f"deepstack_merger_list.{layer_idx}", prefix),
use_data_parallel=use_data_parallel,
)
for layer_idx in range(len(self.deepstack_visual_indexes))
]
Expand Down Expand Up @@ -582,13 +617,15 @@ def __init__(
) -> None:
super().__init__()

self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
self.visual = Qwen3VLMoeVisionModel(
config.vision_config,
# NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
# Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
use_data_parallel=self.use_data_parallel,
)

# TODO: make it more elegant
Expand Down Expand Up @@ -646,7 +683,12 @@ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert image_grid_thw.dim() == 2, image_grid_thw.dim()
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, image_grid_thw.tolist(), rope_type="rope_3d"
)
else:
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
return image_embeds

def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
Expand All @@ -657,7 +699,12 @@ def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
video_grid_thw = torch.concat([item.video_grid_thw for item in items], dim=0)
assert pixel_values.dim() == 2, pixel_values.dim()
assert video_grid_thw.dim() == 2, video_grid_thw.dim()
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
if self.use_data_parallel:
return run_dp_sharded_mrope_vision_model(
self.visual, pixel_values, video_grid_thw.tolist(), rope_type="rope_3d"
)
else:
video_embeds = self.visual(pixel_values, grid_thw=video_grid_thw)
return video_embeds
Comment on lines +702 to 708
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic in this `if`/`else` block is almost identical to the one in `get_image_feature`, and the surrounding functions `get_image_feature` and `get_video_feature` are themselves very similar. To improve maintainability and reduce redundancy, you could extract the shared logic into a private helper method that takes the name of the grid attribute (`grid_thw_attr`) as an argument.

For example:

    def _get_media_feature(self, items: List[MultimodalDataItem], grid_thw_attr: str) -> torch.Tensor:
        # in qwen-vl, last dim is the same
        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
            self.visual.dtype
        )
        grid_thw = torch.concat([getattr(item, grid_thw_attr) for item in items], dim=0)
        assert pixel_values.dim() == 2, pixel_values.dim()
        assert grid_thw.dim() == 2, grid_thw.dim()
        if self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.visual, pixel_values, grid_thw.tolist(), rope_type="rope_3d"
            )
        
        return self.visual(pixel_values, grid_thw=grid_thw)

    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Return the features for image ``items``.

        Delegates to ``_get_media_feature``, reading each item's grid from
        its ``image_grid_thw`` attribute.
        """
        return self._get_media_feature(items, grid_thw_attr="image_grid_thw")

    def get_video_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """Return the features for video ``items``.

        Delegates to ``_get_media_feature``, reading each item's grid from
        its ``video_grid_thw`` attribute.
        """
        return self._get_media_feature(items, grid_thw_attr="video_grid_thw")


def get_input_embeddings(self):
Expand Down
1 change: 1 addition & 0 deletions test/nightly/test_encoder_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

MODELS = [
SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55),
SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55),
SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52),
]

Expand Down
Loading