diff --git a/python/sglang/srt/models/qwen3_omni_moe.py b/python/sglang/srt/models/qwen3_omni_moe.py index ae5b8332d222..5d964ebdc3dd 100644 --- a/python/sglang/srt/models/qwen3_omni_moe.py +++ b/python/sglang/srt/models/qwen3_omni_moe.py @@ -30,6 +30,10 @@ Qwen3OmniMoeVisionEncoderConfig, ) from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE @@ -42,6 +46,7 @@ Qwen3VLMoeForConditionalGeneration, load_fused_expert_weights, ) +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, logger @@ -51,6 +56,7 @@ def __init__( config: Qwen3OmniMoeAudioEncoderConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ): super().__init__() embed_dim = config.d_model @@ -64,6 +70,7 @@ def __init__( flatten_batch=True, quant_config=quant_config, prefix=add_prefix("attn", prefix), + use_data_parallel=use_data_parallel, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -151,7 +158,9 @@ def _get_feat_extract_output_lengths(input_lengths): class Qwen3OmniMoeAudioEncoder(PreTrainedModel): config: Qwen3OmniMoeAudioEncoderConfig - def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig): + def __init__( + self, config: Qwen3OmniMoeAudioEncoderConfig, use_data_parallel: bool = False + ): super().__init__(config) self.dropout = config.dropout @@ -165,7 +174,9 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig): ) self.layers = nn.ModuleList( [ - Qwen3OmniMoeAudioEncoderLayer(config) + Qwen3OmniMoeAudioEncoderLayer( + config=config, use_data_parallel=use_data_parallel + ) for _ in range(config.encoder_layers) ] ) @@ -313,6 +324,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", use_postshuffle_norm=False, + use_data_parallel: bool = False, ) -> None: super().__init__() self.hidden_size = context_dim * (spatial_merge_size**2) @@ -320,6 +332,10 @@ def __init__( self.ln_q = nn.LayerNorm( self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6 ) + + tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size() + tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank() + self.mlp = nn.ModuleList( [ ColumnParallelLinear( @@ -328,6 +344,8 @@ def __init__( bias=True, quant_config=quant_config, prefix=add_prefix("mlp.0", prefix), + tp_size=tp_size, + tp_rank=tp_rank, ), nn.GELU(), RowParallelLinear( @@ -336,6 +354,8 @@ def __init__( bias=True, quant_config=quant_config, prefix=add_prefix("mlp.2", prefix), + tp_size=tp_size, + tp_rank=tp_rank, ), ] ) @@ -366,12 +386,14 @@ def __init__( config: Qwen3OmniMoeVisionEncoderConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = None, + use_data_parallel: bool = False, **kwargs, ): super().__init__( vision_config=config, quant_config=quant_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), + use_data_parallel=use_data_parallel, ) self.merger = Qwen3OmniMoeVisionPatchMerger( @@ -381,6 +403,7 @@ def __init__( quant_config=quant_config, use_postshuffle_norm=False, prefix=add_prefix("merger", prefix), + use_data_parallel=use_data_parallel, ) self.merger_list = nn.ModuleList( [ @@ -391,6 +414,7 @@ def __init__( use_postshuffle_norm=True, quant_config=quant_config, prefix=add_prefix("merger_list", prefix), + use_data_parallel=use_data_parallel, ) for _ in range(len(config.deepstack_visual_indexes)) ] @@ -422,12 +446,16 @@ def __init__( super().__init__( config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel ) - self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config) + self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder + self.audio_tower = Qwen3OmniMoeAudioEncoder( + config.audio_config, self.use_data_parallel + ) self.visual = Qwen3OmniMoeVisionEncoder( config.vision_config, quant_config=quant_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), prefix=add_prefix("visual", prefix), + use_data_parallel=self.use_data_parallel, ) self.pad_token_id = ( self.config.pad_token_id if self.config.pad_token_id is not None else -1 diff --git a/test/nightly/test_encoder_dp.py b/test/nightly/test_encoder_dp.py index a18075f71e7c..ef81ea24e1af 100644 --- a/test/nightly/test_encoder_dp.py +++ b/test/nightly/test_encoder_dp.py @@ -21,6 +21,7 @@ register_cuda_ci(est_time=500, suite="nightly-4-gpu", nightly=True) MODELS = [ + SimpleNamespace(model="Qwen/Qwen3-Omni-30B-A3B-Instruct", mmmu_accuracy=0.55), SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55), SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55), SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52),