diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 5fc9fa073180..7c228cc9083d 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -28,7 +28,6 @@
 )
 from vllm.model_executor.layers.utils import (
     dispatch_unquantized_gemm,
-    is_layer_moe_router_gate,
 )
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
@@ -257,11 +256,7 @@ def apply(
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if (
-            vllm_is_batch_invariant()
-            and current_platform.is_cuda_alike()
-            and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
-        ):
+        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
             return linear_batch_invariant(x, layer.weight, bias)
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
 
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index fd7050861678..ce13048d1e8f 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -23,7 +23,11 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
@@ -165,7 +169,14 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.experts_per_token = config.num_experts_per_tok
         self.world_size = dist.get_world_size() if dist.is_initialized() else 1
-        self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
+        self.router = ReplicatedLinear(
+            config.hidden_size,
+            config.num_local_experts,
+            bias=True,
+            quant_config=None,
+            prefix=f"{prefix}.router",
+            return_bias=False,
+        )
         assert config.intermediate_size % self.world_size == 0
         self.experts = FusedMoE(
             num_experts=config.num_local_experts,
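
Note on the `linear.py` change: `UnquantizedLinearMethod.apply` previously routed through the batch-invariant GEMM only for MoE router gates; it now does so for every unquantized linear layer whenever batch-invariant mode is enabled on a CUDA-alike platform, so the `is_layer_moe_router_gate` check (and its import) can be dropped. As a quick illustration of what the batch-invariant path buys — a standalone PyTorch sketch, not the vLLM kernel, and whether the bits actually differ depends on the GPU and BLAS version:

```python
import torch

# A generic GEMM may choose different tilings / split-K reduction orders for
# different batch shapes, so the *same* input row can produce results that
# differ in the low-order bits. A batch-invariant kernel fixes the reduction
# order so the comparison below is always True.
if torch.cuda.is_available():
    torch.manual_seed(0)
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)
    rest = torch.randn(127, 4096, device="cuda", dtype=torch.bfloat16)

    out_alone = x @ w.t()                              # the row as a batch of 1
    out_in_batch = (torch.cat([x, rest]) @ w.t())[:1]  # same row in a batch of 128

    print("bitwise equal:", torch.equal(out_alone, out_in_batch))
```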
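
Note on the `gpt_oss.py` change: swapping the router from `torch.nn.Linear` to `ReplicatedLinear` keeps the gate weights replicated (unsharded) across tensor-parallel ranks while routing them through vLLM's linear-layer machinery, which is what lets the batch-invariant path above cover the router too. `return_bias=False` matters because vLLM linear layers by default return an `(output, bias)` tuple rather than a bare tensor; disabling it makes the layer a drop-in replacement at existing `torch.nn.Linear` call sites. A toy stand-in for that contract (plain PyTorch, not the actual vLLM class; shapes are arbitrary):

```python
import torch


class ToyReplicatedLinear(torch.nn.Module):
    """Toy stand-in: the full weight lives on every rank (no TP sharding)."""

    def __init__(self, in_features: int, out_features: int,
                 bias: bool = True, return_bias: bool = True) -> None:
        super().__init__()
        self.inner = torch.nn.Linear(in_features, out_features, bias=bias)
        self.return_bias = return_bias

    def forward(self, x: torch.Tensor):
        out = self.inner(x)
        if self.return_bias:
            # vLLM's default shape of result: a tuple carrying an optional
            # deferred bias (None here because the bias was already added).
            return out, None
        return out  # drop-in behavior matching torch.nn.Linear


router = ToyReplicatedLinear(512, 16, bias=True, return_bias=False)
logits = router(torch.randn(4, 512))  # a single tensor, as callers expect
assert isinstance(logits, torch.Tensor)
```

A side effect of the swap is that the router's weights now load through vLLM's parameter machinery like the model's other linear layers, instead of as a raw `torch.nn.Linear` parameter.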