diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 8e49ccea5fd4..567c031938f5 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -29,6 +29,7 @@
 from typing import Any
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from vllm.attention.layer import Attention
@@ -42,7 +43,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -86,6 +87,7 @@ def __init__(
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         reduce_results: bool = True,
+        expert_gate: torch.nn.Linear | None = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -109,12 +111,17 @@ def __init__(
                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
+        self.expert_gate = expert_gate
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+        out = self.act_fn(gate_up)
+        out, _ = self.down_proj(out)
+
+        if self.expert_gate is not None:
+            out = F.sigmoid(self.expert_gate(x)[0]) * out
+
+        return out
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -159,12 +166,46 @@ def __init__(
             self.physical_expert_start + self.n_local_physical_experts
         )
 
-        self.experts = FusedMoE(
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        shared_expert_intermediate_size = getattr(
+            config, "shared_expert_intermediate_size", 0
+        )
+        if shared_expert_intermediate_size > 0:
+            self.shared_expert_gate = ReplicatedLinear(
+                config.hidden_size,
+                1,
+                bias=False,
+                quant_config=None,
+                prefix=f"{prefix}.shared_expert_gate",
+            )
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                expert_gate=self.shared_expert_gate,
+                prefix=f"{prefix}.shared_expert",
+            )
+        else:
+            self.shared_expert_gate = None
+            self.shared_expert = None
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
+            gate=self.gate,
             num_experts=self.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=True,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
@@ -173,14 +214,6 @@ def __init__(
             is_sequence_parallel=self.is_sequence_parallel,
         )
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate",
-        )
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert hidden_states.dim() <= 2, (
             "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
@@ -194,15 +227,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
+        shared_out, fused_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states = (
+            shared_out + fused_out if shared_out is not None else fused_out
+        )
 
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0
             )
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states
+            )
 
         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
@@ -467,7 +507,7 @@ def forward(
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",