Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions vllm/model_executor/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if self.experts.is_internal_router:
# In this case, the gate/router runs inside the FusedMoE class
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=hidden_states
)
else:
# Actually this will be dead code, since we always pass gate into
# FusedMoE in the current implementation. But we keep this code
# here for clarity and future flexibility.
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)

if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
Expand Down
Loading