Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions python/sglang/srt/models/qwen3_omni_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
Qwen3OmniMoeVisionEncoderConfig,
)
from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
Expand All @@ -42,6 +46,7 @@
Qwen3VLMoeForConditionalGeneration,
load_fused_expert_weights,
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import add_prefix, logger


Expand All @@ -51,6 +56,7 @@ def __init__(
config: Qwen3OmniMoeAudioEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
embed_dim = config.d_model
Expand All @@ -64,6 +70,7 @@ def __init__(
flatten_batch=True,
quant_config=quant_config,
prefix=add_prefix("attn", prefix),
use_data_parallel=use_data_parallel,
)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
Expand Down Expand Up @@ -151,7 +158,9 @@ def _get_feat_extract_output_lengths(input_lengths):
class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
config: Qwen3OmniMoeAudioEncoderConfig

def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
def __init__(
self, config: Qwen3OmniMoeAudioEncoderConfig, use_data_parallel: bool = False
):
super().__init__(config)
self.dropout = config.dropout

Expand All @@ -165,7 +174,9 @@ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
)
self.layers = nn.ModuleList(
[
Qwen3OmniMoeAudioEncoderLayer(config)
Qwen3OmniMoeAudioEncoderLayer(
config=config, use_data_parallel=use_data_parallel
)
for _ in range(config.encoder_layers)
]
)
Expand Down Expand Up @@ -313,13 +324,18 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_postshuffle_norm=False,
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size**2)
self.use_postshuffle_norm = use_postshuffle_norm
self.ln_q = nn.LayerNorm(
self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
)

tp_size = 1 if use_data_parallel else get_tensor_model_parallel_world_size()
tp_rank = 0 if use_data_parallel else get_tensor_model_parallel_rank()

self.mlp = nn.ModuleList(
[
ColumnParallelLinear(
Expand All @@ -328,6 +344,8 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.0", prefix),
tp_size=tp_size,
tp_rank=tp_rank,
),
nn.GELU(),
RowParallelLinear(
Expand All @@ -336,6 +354,8 @@ def __init__(
bias=True,
quant_config=quant_config,
prefix=add_prefix("mlp.2", prefix),
tp_size=tp_size,
tp_rank=tp_rank,
),
]
)
Expand Down Expand Up @@ -366,12 +386,14 @@ def __init__(
config: Qwen3OmniMoeVisionEncoderConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = None,
use_data_parallel: bool = False,
**kwargs,
):
super().__init__(
vision_config=config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
use_data_parallel=use_data_parallel,
)

self.merger = Qwen3OmniMoeVisionPatchMerger(
Expand All @@ -381,6 +403,7 @@ def __init__(
quant_config=quant_config,
use_postshuffle_norm=False,
prefix=add_prefix("merger", prefix),
use_data_parallel=use_data_parallel,
)
self.merger_list = nn.ModuleList(
[
Expand All @@ -391,6 +414,7 @@ def __init__(
use_postshuffle_norm=True,
quant_config=quant_config,
prefix=add_prefix("merger_list", prefix),
use_data_parallel=use_data_parallel,
)
for _ in range(len(config.deepstack_visual_indexes))
]
Expand Down Expand Up @@ -422,12 +446,16 @@ def __init__(
super().__init__(
config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
)
self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
self.use_data_parallel = get_global_server_args().mm_enable_dp_encoder
self.audio_tower = Qwen3OmniMoeAudioEncoder(
config.audio_config, self.use_data_parallel
)
self.visual = Qwen3OmniMoeVisionEncoder(
config.vision_config,
quant_config=quant_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
prefix=add_prefix("visual", prefix),
use_data_parallel=self.use_data_parallel,
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
Expand Down
1 change: 1 addition & 0 deletions test/nightly/test_encoder_dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
register_cuda_ci(est_time=500, suite="nightly-4-gpu", nightly=True)

MODELS = [
SimpleNamespace(model="Qwen/Qwen3-Omni-30B-A3B-Instruct", mmmu_accuracy=0.55),
SimpleNamespace(model="Qwen/Qwen2.5-VL-72B-Instruct", mmmu_accuracy=0.55),
SimpleNamespace(model="Qwen/Qwen3-VL-32B-Instruct", mmmu_accuracy=0.55),
SimpleNamespace(model="OpenGVLab/InternVL2_5-8B", mmmu_accuracy=0.52),
Expand Down
Loading