Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm,
is_layer_moe_router_gate,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
Expand Down Expand Up @@ -257,11 +256,7 @@ def apply(
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if (
vllm_is_batch_invariant()
and current_platform.is_cuda_alike()
and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
):
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

Expand Down
15 changes: 13 additions & 2 deletions vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
Expand Down Expand Up @@ -165,7 +169,14 @@ def __init__(
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
num_experts=config.num_local_experts,
Expand Down