vllm/model_executor/models/qwen3_moe.py (72 changes: 56 additions & 16 deletions)
@@ -29,6 +29,7 @@
 from typing import Any

 import torch
+import torch.nn.functional as F
 from torch import nn

 from vllm.attention.layer import Attention
@@ -42,7 +43,7 @@
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
@@ -86,6 +87,7 @@ def __init__(
         hidden_act: str,
         quant_config: QuantizationConfig | None = None,
         reduce_results: bool = True,
+        expert_gate: torch.nn.Linear | None = None,
Review comment (low severity): Incorrect type hint causes wrong tensor indexing

The expert_gate parameter is typed as torch.nn.Linear | None but is always passed a ReplicatedLinear. The code calls self.expert_gate(x)[0] to extract the output from the tuple returned by ReplicatedLinear.forward; torch.nn.Linear.forward returns a tensor directly, not a tuple, so [0] would index the first dimension of the tensor rather than extracting the output. The current code works because only a ReplicatedLinear is ever passed, but the annotation is misleading: using a torch.nn.Linear as documented would produce silently incorrect results.

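To make the failure mode concrete, here is a minimal sketch of the two return conventions (TupleLinear is a hypothetical stand-in for ReplicatedLinear, whose forward returns an (output, bias) tuple; it is not the actual vLLM class):

import torch

class TupleLinear(torch.nn.Linear):
    """Stand-in for vLLM's ReplicatedLinear: forward returns (output, bias)."""

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]:
        return super().forward(x), None

x = torch.randn(4, 8)  # (num_tokens, hidden_size)

tuple_gate = TupleLinear(8, 1, bias=False)
out = tuple_gate(x)[0]   # unpacks the tuple: shape (4, 1), one gate logit per token
print(out.shape)         # torch.Size([4, 1])

plain_gate = torch.nn.Linear(8, 1, bias=False)
row = plain_gate(x)[0]   # silently indexes token 0 instead: shape (1,)
print(row.shape)         # torch.Size([1])

Annotating the parameter as ReplicatedLinear | None would let a type checker catch the mismatch.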

         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -109,12 +111,17 @@ def __init__(
                 f"Unsupported activation: {hidden_act}. Only silu is supported for now."
             )
         self.act_fn = SiluAndMul()
+        self.expert_gate = expert_gate

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+        out = self.act_fn(gate_up)
+        out, _ = self.down_proj(out)
+
+        if self.expert_gate is not None:
+            out = F.sigmoid(self.expert_gate(x)[0]) * out
+
+        return out


 class Qwen3MoeSparseMoeBlock(nn.Module):
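
For context, the change above gates the shared expert's output with a per-token sigmoid computed from the block input. A minimal numeric sketch of that combination, using plain torch modules as stand-ins for the vLLM parallel layers (hypothetical sizes):

import torch
import torch.nn.functional as F

num_tokens, hidden_size = 4, 8
x = torch.randn(num_tokens, hidden_size)

# Stand-ins for the real modules; the actual MLP is a SwiGLU built from
# MergedColumnParallelLinear and RowParallelLinear.
shared_expert = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 16), torch.nn.SiLU(), torch.nn.Linear(16, hidden_size)
)
shared_expert_gate = torch.nn.Linear(hidden_size, 1, bias=False)

gate = F.sigmoid(shared_expert_gate(x))  # (num_tokens, 1), each value in (0, 1)
out = gate * shared_expert(x)            # broadcast over the hidden dimension
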
@@ -159,12 +166,46 @@ def __init__(
             self.physical_expert_start + self.n_local_physical_experts
         )

-        self.experts = FusedMoE(
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        shared_expert_intermediate_size = getattr(
+            config, "shared_expert_intermediate_size", 0
+        )
+        if shared_expert_intermediate_size > 0:
+            self.shared_expert_gate = ReplicatedLinear(
+                config.hidden_size,
+                1,
+                bias=False,
+                quant_config=None,
+                prefix=f"{prefix}.shared_expert_gate",
+            )
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=shared_expert_intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                expert_gate=self.shared_expert_gate,
+                prefix=f"{prefix}.shared_expert",
+            )
+        else:
+            self.shared_expert_gate = None
+            self.shared_expert = None
+
+        self.experts = SharedFusedMoE(
+            shared_experts=self.shared_expert,
+            gate=self.gate,
             num_experts=self.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=True,
+            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=f"{prefix}.experts",
@@ -173,14 +214,6 @@
             is_sequence_parallel=self.is_sequence_parallel,
         )

-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate",
-        )
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         assert hidden_states.dim() <= 2, (
             "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs"
@@ -194,15 +227,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(
+        shared_out, fused_out = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
+        final_hidden_states = (
+            shared_out + fused_out if shared_out is not None else fused_out
+        )

         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0
             )
             final_hidden_states = final_hidden_states[:num_tokens]
+        elif self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+                final_hidden_states
+            )

         # return to 1d if input is 1d
         return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states
@@ -467,7 +507,7 @@ def forward(
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        return FusedMoE.make_expert_params_mapping(
+        return SharedFusedMoE.make_expert_params_mapping(
             self,
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",