From 2fed0ffb8e016a8cd86ead99fbd61950ff757d4d Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Thu, 23 Apr 2026 09:01:06 +0800
Subject: [PATCH] fix Qwen3 MoE calling the gate twice

Signed-off-by: Kunshang Ji
---
 vllm/model_executor/models/qwen3_moe.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 6f080d07795e..520126718fdc 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -231,11 +231,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if self.experts.is_internal_router:
+            # The gate/router runs inside the FusedMoE class, so pass the
+            # hidden states through and let it compute the logits itself.
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # Currently dead code, since we always pass the gate into
+            # FusedMoE; kept for clarity and future flexibility.
+            router_logits, _ = self.gate(hidden_states)
+            final_hidden_states = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )

         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
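
For context, here is a minimal, self-contained sketch of the contract this patch relies on. It is not vLLM's actual FusedMoE: the class name SketchFusedMoE, its constructor arguments, and the naive per-token dispatch loop are all invented for illustration. Only two things are taken from the patch itself: the is_internal_router flag, and the convention of passing hidden_states through the router_logits parameter when routing happens inside the MoE module, so the gate is invoked exactly once.

    from typing import Optional

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SketchFusedMoE(nn.Module):
        """Illustrative stand-in for a fused MoE layer (not vLLM's)."""

        def __init__(self, num_experts: int, hidden_size: int,
                     top_k: int = 2, gate: Optional[nn.Module] = None):
            super().__init__()
            # If a gate module is handed in, routing runs inside this
            # module, and is_internal_router tells callers to skip their
            # own gate call.
            self.gate = gate
            self.is_internal_router = gate is not None
            self.top_k = top_k
            self.experts = nn.ModuleList(
                nn.Linear(hidden_size, hidden_size)
                for _ in range(num_experts)
            )

        def forward(self, hidden_states: torch.Tensor,
                    router_logits: torch.Tensor) -> torch.Tensor:
            if self.is_internal_router:
                # router_logits actually holds hidden_states here; compute
                # the logits once, which is the double-call the patch fixes.
                router_logits = self.gate(router_logits)
            weights = F.softmax(router_logits, dim=-1)
            topk_w, topk_idx = weights.topk(self.top_k, dim=-1)
            out = torch.zeros_like(hidden_states)
            # Naive per-token dispatch; real kernels fuse this step.
            for tok in range(hidden_states.size(0)):
                for w, idx in zip(topk_w[tok], topk_idx[tok]):
                    out[tok] += w * self.experts[int(idx)](hidden_states[tok])
            return out


    # Usage mirroring the patched Qwen3 MoE forward():
    hidden = torch.randn(4, 8)
    gate = nn.Linear(8, 4, bias=False)
    moe = SketchFusedMoE(num_experts=4, hidden_size=8, gate=gate)
    if moe.is_internal_router:
        out = moe(hidden_states=hidden, router_logits=hidden)
    else:
        out = moe(hidden_states=hidden, router_logits=gate(hidden))
    print(out.shape)  # torch.Size([4, 8])

Reusing the router_logits slot to carry hidden_states keeps the call signature identical in both modes, so callers only need the one flag check rather than two code paths with different signatures.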