From 27fe52217a449e1c08b010bdd94c9d985a248b74 Mon Sep 17 00:00:00 2001
From: Fr4nk1in
Date: Thu, 8 May 2025 10:12:04 +0000
Subject: [PATCH 1/3] feat: add dp attention support for Qwen 2/3 MoE models, fixes #6088

This is a prerequisite for expert parallelism (EP).
---
 python/sglang/srt/models/qwen2_moe.py | 241 ++++++++++++++++++++++----
 python/sglang/srt/models/qwen3_moe.py | 237 ++++++++++++++++++++++---
 2 files changed, 423 insertions(+), 55 deletions(-)

diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 7a0cece1735..72c60357d7f 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -16,6 +16,8 @@
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
+from dataclasses import dataclass
+from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -28,6 +30,15 @@
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import (
+    dp_gather_partial,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,7 +46,7 @@
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -82,8 +93,7 @@ def __init__(
         )
         if hidden_act != "silu":
             raise ValueError(
-                f"Unsupported activation: {hidden_act}. "
-                "Only silu is supported for now."
+                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
 
@@ -160,7 +170,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -182,20 +191,23 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -210,6 +222,8 @@ def __init__( self.total_num_kv_heads, bias=qkv_bias, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("qkv_proj", prefix), ) @@ -218,6 +232,8 @@ def __init__( hidden_size, bias=False, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("o_proj", prefix), ) @@ -252,6 +268,19 @@ def forward( return output +class _FFNInputMode(Enum): + # The MLP sublayer requires 1/tp_size tokens as input + SCATTERED = auto() + # The MLP sublayer requires all tokens as input + FULL = auto() + + +@dataclass +class _DecoderLayerInfo: + is_sparse: bool + ffn_input_mode: _FFNInputMode + + class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, @@ -279,14 +308,20 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - # Note: Qwen/Qwen2-57B-A14B-Instruct does not have - # `mlp_only_layers` in the config. - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + self.layer_id = layer_id + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + self.dp_size = get_attention_dp_size() + + self.info = self._compute_info(config, layer_id=layer_id) + previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) + self.input_is_scattered = ( + previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED ) - if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 - ): + self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + + if self.info.is_sparse: self.mlp = Qwen2MoeSparseMoeBlock( config=config, quant_config=quant_config, @@ -305,28 +340,175 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + + @staticmethod + def _compute_info(config: PretrainedConfig, layer_id: int): + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + is_sparse = (layer_id not in mlp_only_layers) and ( + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + ) + # WARN: DeepEP MoE is not supported for Qwen MoE models for now. 
+ ffn_input_mode = ( + _FFNInputMode.SCATTERED + if (global_server_args_dict["enable_deepep_moe"] and is_sparse) + or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + else _FFNInputMode.FULL + ) + return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.info.ffn_input_mode == _FFNInputMode.SCATTERED: + return self.forward_ffn_with_scattered_input( + positions, hidden_states, forward_batch, residual + ) + elif self.info.ffn_input_mode == _FFNInputMode.FULL: + return self.forward_ffn_with_full_input( + positions, hidden_states, forward_batch, residual + ) + else: + raise NotImplementedError + + def forward_ffn_with_full_input( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.shape[0] == 0: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + # Gather + if get_tensor_model_parallel_world_size() > 1 and self.dp_size != 1: + if self.attn_tp_rank == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) + elif hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + # TODO: use reduce-scatter in MLP to avoid this scatter + # Scatter + if self.dp_size != 1: + # important: forward batch.gathered_buffer is used both after scatter and after gather. + # be careful about this! 
+ hidden_states, global_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + dp_scatter(hidden_states, global_hidden_states, forward_batch) + + return hidden_states, residual + + def forward_ffn_with_scattered_input( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.shape[0] == 0: + residual = hidden_states + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + if self.attn_tp_size != 1 and self.input_is_scattered: + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + + # Self Attention hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, ) - # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) + if self.attn_tp_size != 1: + if self.input_is_scattered: + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + else: + if self.attn_tp_rank == 0: + hidden_states += residual + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + residual = hidden_states + if hidden_states.shape[0] != 0: + hidden_states = self.post_attention_layernorm(hidden_states) + else: + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + if not ( + self._enable_moe_dense_fully_dp() + and (not self.info.is_sparse) + and hidden_states.shape[0] == 0 + ): + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + + if self.is_last_layer and self.attn_tp_size != 1: + hidden_states += residual + residual = None + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + return hidden_states, residual @@ -345,6 +527,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + enable_tp=not global_server_args_dict["enable_dp_attention"], prefix=add_prefix("embed_tokens", prefix), ) # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer @@ -379,12 +562,12 @@ def forward( hidden_states, residual = layer( positions, hidden_states, forward_batch, residual ) - hidden_states, _ = self.norm(hidden_states, residual) + if hidden_states.shape[0] != 0: + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states class Qwen2MoeForCausalLM(nn.Module): - fall_back_to_pt_during_load = False def __init__( @@ -414,7 +597,7 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: + ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, 
forward_batch, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 16e8f377c10..689e6ba9b77 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -17,12 +17,15 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +from dataclasses import dataclass +from enum import Enum, auto from functools import partial from typing import Any, Dict, Iterable, Optional, Tuple import torch import torch.nn.functional as F from torch import nn +from transformers.configuration_utils import PretrainedConfig from sglang.srt.distributed import ( get_tensor_model_parallel_rank, @@ -32,6 +35,15 @@ tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.dp_attention import ( + dp_gather_partial, + dp_scatter, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, + tp_all_gather, + tp_reduce_scatter, +) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -39,7 +51,7 @@ ReplicatedLinear, RowParallelLinear, ) -from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -128,20 +140,23 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size - self.tp_size = get_tensor_model_parallel_world_size() + + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() + self.total_num_heads = num_heads - assert self.total_num_heads % self.tp_size == 0 - self.num_heads = self.total_num_heads // self.tp_size + assert self.total_num_heads % attn_tp_size == 0 + self.num_heads = self.total_num_heads // attn_tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= self.tp_size: + if self.total_num_kv_heads >= attn_tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % self.tp_size == 0 + assert self.total_num_kv_heads % attn_tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
- assert self.tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + assert attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size) self.head_dim = head_dim or hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -157,6 +172,8 @@ def __init__( self.total_num_kv_heads, bias=attention_bias, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("qkv_proj", prefix), ) @@ -165,6 +182,8 @@ def __init__( hidden_size, bias=attention_bias, quant_config=quant_config, + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, prefix=add_prefix("o_proj", prefix), ) @@ -206,13 +225,27 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) + if q.shape[0] != 0: + q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output +class _FFNInputMode(Enum): + # The MLP sublayer requires 1/tp_size tokens as input + SCATTERED = auto() + # The MLP sublayer requires all tokens as input + FULL = auto() + + +@dataclass +class _DecoderLayerInfo: + is_sparse: bool + ffn_input_mode: _FFNInputMode + + class Qwen3MoeDecoderLayer(nn.Module): def __init__( self, @@ -246,14 +279,20 @@ def __init__( prefix=add_prefix("self_attn", prefix), ) - # Note: Qwen/Qwen2-57B-A14B-Instruct does not have - # `mlp_only_layers` in the config. - mlp_only_layers = ( - [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + self.layer_id = layer_id + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + self.dp_size = get_attention_dp_size() + + self.info = self._compute_info(config, layer_id=layer_id) + previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) + self.input_is_scattered = ( + previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED ) - if (layer_id not in mlp_only_layers) and ( - config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 - ): + self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + + if self.info.is_sparse: self.mlp = Qwen3MoeSparseMoeBlock( config=config, quant_config=quant_config, @@ -272,28 +311,175 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + @staticmethod + def _enable_moe_dense_fully_dp(): + return global_server_args_dict["moe_dense_tp_size"] == 1 + + @staticmethod + def _compute_info(config: PretrainedConfig, layer_id: int): + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. + mlp_only_layers = ( + [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers + ) + is_sparse = (layer_id not in mlp_only_layers) and ( + config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 + ) + # WARN: DeepEP MoE is not supported for Qwen MoE models for now. 
+ ffn_input_mode = ( + _FFNInputMode.SCATTERED + if (global_server_args_dict["enable_deepep_moe"] and is_sparse) + or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse) + else _FFNInputMode.FULL + ) + return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode) + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, residual: Optional[torch.Tensor], - ) -> torch.Tensor: - # Self Attention - if residual is None: + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.info.ffn_input_mode == _FFNInputMode.SCATTERED: + return self.forward_ffn_with_scattered_input( + positions, hidden_states, forward_batch, residual + ) + elif self.info.ffn_input_mode == _FFNInputMode.FULL: + return self.forward_ffn_with_full_input( + positions, hidden_states, forward_batch, residual + ) + else: + raise NotImplementedError + + def forward_ffn_with_full_input( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.shape[0] == 0: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self Attention + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + # Gather + if get_tensor_model_parallel_world_size() > 1 and self.dp_size != 1: + if self.attn_tp_rank == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) + elif hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + # TODO: use reduce-scatter in MLP to avoid this scatter + # Scatter + if self.dp_size != 1: + # important: forward batch.gathered_buffer is used both after scatter and after gather. + # be careful about this! 
+ hidden_states, global_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + dp_scatter(hidden_states, global_hidden_states, forward_batch) + + return hidden_states, residual + + def forward_ffn_with_scattered_input( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.shape[0] == 0: + residual = hidden_states + else: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + if self.attn_tp_size != 1 and self.input_is_scattered: + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + + # Self Attention hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, ) - # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) - hidden_states = self.mlp(hidden_states) + if self.attn_tp_size != 1: + if self.input_is_scattered: + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + else: + if self.attn_tp_rank == 0: + hidden_states += residual + tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) + hidden_states = tensor_list[self.attn_tp_rank] + tp_reduce_scatter(hidden_states, tensor_list) + residual = hidden_states + if hidden_states.shape[0] != 0: + hidden_states = self.post_attention_layernorm(hidden_states) + else: + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + + if not ( + self._enable_moe_dense_fully_dp() + and (not self.info.is_sparse) + and hidden_states.shape[0] == 0 + ): + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + + if self.is_last_layer and self.attn_tp_size != 1: + hidden_states += residual + residual = None + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], + hidden_states, + ) + tp_all_gather( + list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states + ) + return hidden_states, residual @@ -313,7 +499,6 @@ def __init__( class Qwen3MoeForCausalLM(nn.Module): - fall_back_to_pt_during_load = False def __init__( @@ -343,7 +528,7 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: + ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( input_ids, hidden_states, self.lm_head, forward_batch From bbc0bf2170a30d7b2bda2d881809a6c00be2922a Mon Sep 17 00:00:00 2001 From: "King.Zevin" Date: Mon, 12 May 2025 12:12:38 +0000 Subject: [PATCH 2/3] fix: server hangs when attn_tp_size != 1 --- python/sglang/srt/models/qwen2_moe.py | 1 + python/sglang/srt/models/qwen3_moe.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 72c60357d7f..f7e8de70662 100644 
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -234,6 +234,7 @@ def __init__(
             quant_config=quant_config,
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
+            reduce_results=not global_server_args_dict["enable_dp_attention"],
             prefix=add_prefix("o_proj", prefix),
         )
 
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 689e6ba9b77..ffc90284eb6 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -184,6 +184,7 @@ def __init__(
             quant_config=quant_config,
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
+            reduce_results=not global_server_args_dict["enable_dp_attention"],
             prefix=add_prefix("o_proj", prefix),
         )
 

From 1674d854b2f7b79ca21c7548195b763d1ef90cb5 Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Wed, 14 May 2025 18:03:21 +0000
Subject: [PATCH 3/3] fix some bugs, keep the same format as DeepSeek for the refactor

---
 python/sglang/bench_one_batch.py         |  1 +
 python/sglang/srt/layers/dp_attention.py | 10 ----
 python/sglang/srt/models/qwen2_moe.py    | 69 ++++++++++++++----------
 python/sglang/srt/models/qwen3_moe.py    | 69 +++++++++++++-----------
 4 files changed, 79 insertions(+), 70 deletions(-)

diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index f8c67c8f4e4..1ddb36c4846 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
+        moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
         tp_cpu_group=model_runner.tp_group.cpu_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
index 6de1797ae1b..5fa7ce09234 100644
--- a/python/sglang/srt/layers/dp_attention.py
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -142,16 +142,6 @@ def get_local_attention_dp_size():
     return _LOCAL_ATTN_DP_SIZE
 
 
-def get_local_attention_dp_rank():
-    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
-    return _LOCAL_ATTN_DP_RANK
-
-
-def get_local_attention_dp_size():
-    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
-    return _LOCAL_ATTN_DP_SIZE
-
-
 @contextmanager
 def disable_dp_size():
     """Patch the tp group temporarily until this function ends.
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index f7e8de70662..0855dd8ae83 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -31,13 +31,13 @@ ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, + attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, - get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, - tp_all_gather, - tp_reduce_scatter, + get_local_attention_dp_size, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -234,7 +234,7 @@ def __init__( quant_config=quant_config, tp_rank=attn_tp_rank, tp_size=attn_tp_size, - reduce_results=not global_server_args_dict["enable_dp_attention"], + reduce_results=False, prefix=add_prefix("o_proj", prefix), ) @@ -313,12 +313,13 @@ def __init__( self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - self.dp_size = get_attention_dp_size() + self.local_dp_size = get_local_attention_dp_size() self.info = self._compute_info(config, layer_id=layer_id) previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) self.input_is_scattered = ( - previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED + layer_id > 0 + and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -347,15 +348,13 @@ def _enable_moe_dense_fully_dp(): @staticmethod def _compute_info(config: PretrainedConfig, layer_id: int): - # Note: Qwen/Qwen2-57B-A14B-Instruct does not have - # `mlp_only_layers` in the config. + # WARN: Qwen2MOE has no dense_layer, it is only for compatibility. mlp_only_layers = ( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) is_sparse = (layer_id not in mlp_only_layers) and ( config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 ) - # WARN: DeepEP MoE is not supported for Qwen MoE models for now. 
ffn_input_mode = ( _FFNInputMode.SCATTERED if (global_server_args_dict["enable_deepep_moe"] and is_sparse) @@ -405,16 +404,27 @@ def forward_ffn_with_full_input( forward_batch=forward_batch, ) # Gather - if get_tensor_model_parallel_world_size() > 1 and self.dp_size != 1: - if self.attn_tp_rank == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - hidden_states = self.post_attention_layernorm(hidden_states) + if get_tensor_model_parallel_world_size() > 1: + # all gather and all reduce + if self.local_dp_size != 1: + if self.attn_tp_rank == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + # TODO extract this bugfix + if hidden_states.shape[0] != 0: + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + # TODO extract this bugfix + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) elif hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -425,7 +435,7 @@ def forward_ffn_with_full_input( # TODO: use reduce-scatter in MLP to avoid this scatter # Scatter - if self.dp_size != 1: + if self.local_dp_size != 1: # important: forward batch.gathered_buffer is used both after scatter and after gather. # be careful about this! hidden_states, global_hidden_states = ( @@ -457,22 +467,23 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) if self.attn_tp_size != 1: if self.input_is_scattered: tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) if hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -482,7 +493,7 @@ def forward_ffn_with_scattered_input( hidden_states += residual tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) residual = hidden_states if hidden_states.shape[0] != 0: hidden_states = self.post_attention_layernorm(hidden_states) @@ -506,7 +517,7 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index ffc90284eb6..7f841bf37ff 100644 --- 
a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -36,13 +36,13 @@ ) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather, + attn_tp_reduce_scatter, dp_gather_partial, dp_scatter, - get_attention_dp_size, get_attention_tp_rank, get_attention_tp_size, - tp_all_gather, - tp_reduce_scatter, + get_local_attention_dp_size, ) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -184,7 +184,7 @@ def __init__( quant_config=quant_config, tp_rank=attn_tp_rank, tp_size=attn_tp_size, - reduce_results=not global_server_args_dict["enable_dp_attention"], + reduce_results=False, prefix=add_prefix("o_proj", prefix), ) @@ -226,8 +226,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - if q.shape[0] != 0: - q, k = self._apply_qk_norm(q, k) + q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) @@ -284,12 +283,13 @@ def __init__( self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - self.dp_size = get_attention_dp_size() + self.local_dp_size = get_local_attention_dp_size() self.info = self._compute_info(config, layer_id=layer_id) previous_layer_info = self._compute_info(config, layer_id=layer_id - 1) self.input_is_scattered = ( - previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED + layer_id > 0 + and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED ) self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 @@ -318,15 +318,13 @@ def _enable_moe_dense_fully_dp(): @staticmethod def _compute_info(config: PretrainedConfig, layer_id: int): - # Note: Qwen/Qwen2-57B-A14B-Instruct does not have - # `mlp_only_layers` in the config. + # WARN: Qwen3MOE has no dense_layer, it is only for compatibility. mlp_only_layers = ( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) is_sparse = (layer_id not in mlp_only_layers) and ( config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0 ) - # WARN: DeepEP MoE is not supported for Qwen MoE models for now. 
ffn_input_mode = ( _FFNInputMode.SCATTERED if (global_server_args_dict["enable_deepep_moe"] and is_sparse) @@ -376,16 +374,24 @@ def forward_ffn_with_full_input( forward_batch=forward_batch, ) # Gather - if get_tensor_model_parallel_world_size() > 1 and self.dp_size != 1: - if self.attn_tp_rank == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - hidden_states = self.post_attention_layernorm(hidden_states) + if get_tensor_model_parallel_world_size() > 1: + if self.local_dp_size != 1: + if self.attn_tp_rank == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + # TODO extract this bugfix + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) elif hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -396,7 +402,7 @@ def forward_ffn_with_full_input( # TODO: use reduce-scatter in MLP to avoid this scatter # Scatter - if self.dp_size != 1: + if self.local_dp_size != 1: # important: forward batch.gathered_buffer is used both after scatter and after gather. # be careful about this! hidden_states, global_hidden_states = ( @@ -428,22 +434,23 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states ) # Self Attention - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) if self.attn_tp_size != 1: if self.input_is_scattered: tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) if hidden_states.shape[0] != 0: hidden_states, residual = self.post_attention_layernorm( hidden_states, residual @@ -453,7 +460,7 @@ def forward_ffn_with_scattered_input( hidden_states += residual tensor_list = list(hidden_states.tensor_split(self.attn_tp_size)) hidden_states = tensor_list[self.attn_tp_rank] - tp_reduce_scatter(hidden_states, tensor_list) + attn_tp_reduce_scatter(hidden_states, tensor_list) residual = hidden_states if hidden_states.shape[0] != 0: hidden_states = self.post_attention_layernorm(hidden_states) @@ -477,7 +484,7 @@ def forward_ffn_with_scattered_input( forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]], hidden_states, ) - tp_all_gather( + attn_tp_all_gather( list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states )