Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
LayerCommunicator,
Expand All @@ -50,6 +52,7 @@
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
Expand Down Expand Up @@ -82,6 +85,8 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
Expand All @@ -90,6 +95,8 @@ def __init__(
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.down_proj = RowParallelLinear(
intermediate_size,
Expand All @@ -98,6 +105,8 @@ def __init__(
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
if hidden_act != "silu":
raise ValueError(
Expand Down Expand Up @@ -146,7 +155,8 @@ def __init__(
self.experts = get_moe_impl_class(quant_config)(
layer_id=self.layer_id,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts,
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
Expand All @@ -168,11 +178,31 @@ def __init__(
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_expert", prefix),
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
else {}
),
)
else:
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok
Comment on lines +191 to +197
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code appears to be redundant. The attributes self.ep_size, self.num_experts, and self.top_k are assigned but are not used within the class. The values for num_experts and top_k were already used during the initialization of self.experts and self.topk respectively. If this code is for future use as hinted by the TODO, it should be commented out. Otherwise, it can be removed to improve code clarity.


def get_moe_weights(self):
    """Collect the routed-expert parameter tensors of this MoE block.

    Walks every named parameter of ``self.experts`` and returns the
    underlying ``.data`` tensors, skipping the ``correction_bias``
    parameter (it is routing metadata, not an expert weight).
    """
    weights = []
    for param_name, param in self.experts.named_parameters():
        if param_name == "correction_bias":
            continue
        weights.append(param.data)
    return weights

def _forward_shared_experts(self, hidden_states: torch.Tensor):
shared_output = None
if self.shared_expert is not None:
Expand All @@ -183,6 +213,36 @@ def _forward_shared_experts(self, hidden_states: torch.Tensor):
)
return shared_output

def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
    """Run the MoE layer through the DeepEP all-to-all expert-parallel path.

    When the batch carries tokens, the router gate scores them, the shared
    expert (if any) runs in parallel with top-k selection, and the selected
    experts are dispatched via DeepEP. An empty batch still calls the expert
    layer (collective ops must be entered by every rank) with an empty top-k
    output. The shared-expert output, when present, is accumulated in place.
    """
    shared_output = None
    if hidden_states.shape[0] == 0:
        # No tokens on this rank: produce placeholder top-k tensors so the
        # expert dispatch below still participates in the collective.
        topk_weights, topk_idx, _ = self.topk.empty_topk_output(
            hidden_states.device
        )
    else:
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        shared_output = self._forward_shared_experts(hidden_states)
        topk_weights, topk_idx, _ = self.topk(
            hidden_states,
            router_logits,
            num_token_non_padded=forward_batch.num_token_non_padded,
            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                layer_id=self.layer_id,
            ),
        )

    final_hidden_states = self.experts(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )

    if shared_output is not None:
        # In-place add avoids allocating another (num_tokens, hidden) tensor.
        final_hidden_states.add_(shared_output)

    return final_hidden_states

def _forward_router_experts(self, hidden_states: torch.Tensor):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
Expand Down Expand Up @@ -213,6 +273,9 @@ def forward(
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)

if get_moe_a2a_backend().is_deepep():
return self._forward_deepep(hidden_states, forward_batch)

DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
Expand Down
37 changes: 29 additions & 8 deletions python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
Expand Down Expand Up @@ -46,7 +47,14 @@
sharded_weight_loader,
)
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
from sglang.srt.utils import (
LazyValue,
add_prefix,
is_cuda,
is_npu,
make_layers,
set_weight_attrs,
)

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
Expand Down Expand Up @@ -849,13 +857,14 @@ def forward(
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
layer_id=i,
positions=positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
with get_global_expert_distribution_recorder().with_current_layer(i):
hidden_states, residual = layer(
layer_id=i,
positions=positions,
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)

if not forward_batch.forward_mode.is_idle():
if residual is None:
Expand Down Expand Up @@ -901,6 +910,18 @@ def __init__(
self.lm_head = self.lm_head.float()
self.logits_processor = LogitsProcessor(config)

self._routed_experts_weights_of_layer = LazyValue(
lambda: {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
}
)

@property
def routed_experts_weights_of_layer(self):
    """Resolve and return the lazily built {layer_id: expert weights} map."""
    lazy_weights = self._routed_experts_weights_of_layer
    return lazy_weights.value

@torch.no_grad()
def forward(
self,
Expand Down
Loading