diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..70b3ee52f886 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -6,15 +6,54 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import direct_register_custom_op + +if has_flashinfer(): + + def flashinfer_tinygemm_router_gemm_impl( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + ) -> torch.Tensor: + from flashinfer.gemm.routergemm import tinygemm_bf16 + + output = torch.empty( + x.shape[0], + weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) + tinygemm_bf16(x, weight, output, bias) + return output + + def flashinfer_tinygemm_router_gemm_fake( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + ) -> torch.Tensor: + return torch.empty( + x.shape[0], + weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) + + direct_register_custom_op( + op_name="flashinfer_tinygemm_router_gemm", + op_func=flashinfer_tinygemm_router_gemm_impl, + fake_impl=flashinfer_tinygemm_router_gemm_fake, + ) @PluggableLayer.register("gate_linear") class GateLinear(ReplicatedLinear): - """MoE gate linear layer with three-tier GEMM dispatch: + """MoE gate linear layer with four-tier GEMM dispatch: - 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. F.linear via ReplicatedLinear (ultimate fallback) + 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims, no bias) + 2. cuBLAS bf16xbf16→fp32 (SM90+ + bf16 + fp32 out_dtype, no bias) + 3. Flashinfer tinygemm_bf16 kernel (SM90+, aligned dims, supports bias) + 4. F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -35,12 +74,11 @@ def __init__( force_fp32_compute: bool = False, prefix: str = "", ): - is_hopper_or_blackwell = current_platform.is_device_capability( - (9, 0) - ) or current_platform.is_device_capability_family(100) - can_use_specialized_kernels = ( - current_platform.is_cuda() and is_hopper_or_blackwell and not bias + is_hopper_or_blackwell = current_platform.is_cuda() and ( + current_platform.is_device_capability((9, 0)) + or current_platform.is_device_capability_family(100) ) + can_use_specialized_kernels = is_hopper_or_blackwell and not bias # If fp32 compute is required and no specialized kernel is available, # store weights in fp32 so Tier 3 computes in fp32 natively. @@ -72,6 +110,16 @@ def __init__( and self.out_dtype == torch.float32 ) + # Flashinfer tinygemm_bf16 (SM90+, aligned dims, supports bias) + self.allow_flashinfer_tinygemm_router_gemm = ( + is_hopper_or_blackwell + and has_flashinfer() + and self.weight.dtype == torch.bfloat16 + and self.out_dtype in [None, torch.bfloat16] + and input_size % 64 == 0 + and output_size % 16 == 0 + ) + def set_out_dtype(self, out_dtype: torch.dtype) -> None: """Set output dtype for the router logits after init. @@ -89,6 +137,10 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: ): self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 + # tinygemm outputs bf16 — disable if fp32 output is now required + if out_dtype == torch.float32: + self.allow_flashinfer_tinygemm_router_gemm = False + def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: @@ -108,7 +160,18 @@ def forward( output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 3: Flashinfer tinygemm_bf16 + if ( + self.allow_flashinfer_tinygemm_router_gemm + and x.dtype == torch.bfloat16 + and x.shape[0] <= 128 + ): + output = torch.ops.vllm.flashinfer_tinygemm_router_gemm( + x, self.weight, self.bias + ) + return output, None + + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c3111489c0ca..69c9f7e3bb89 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -20,12 +20,11 @@ tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -175,14 +174,13 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = ReplicatedLinear( + self.router = GateLinear( config.hidden_size, config.num_local_experts, bias=True, - quant_config=None, prefix=f"{prefix}.router", - return_bias=False, ) + assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( num_experts=config.num_local_experts, @@ -209,7 +207,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g = self.router(x) + g, _ = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: