From a930163a6dc5ed956ccf5104c545c3142cb7e15d Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Mon, 16 Mar 2026 17:36:39 -0700 Subject: [PATCH 1/3] support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- .../layers/fused_moe/router/gate_linear.py | 83 ++++++++++++++++--- vllm/model_executor/models/gpt_oss.py | 10 +-- 2 files changed, 76 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 77d8e756026d..7674bcb9d4d2 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -6,15 +6,45 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer +from vllm.utils.torch_utils import direct_register_custom_op + +if has_flashinfer(): + + def flashinfer_tinygemm_router_gemm_impl( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + out: torch.Tensor, + ) -> None: + from flashinfer.gemm.routergemm import tinygemm_bf16 + + tinygemm_bf16(x, weight, out, bias) + + def flashinfer_tinygemm_router_gemm_fake( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + out: torch.Tensor, + ) -> None: + return + + direct_register_custom_op( + op_name="flashinfer_tinygemm_router_gemm", + op_func=flashinfer_tinygemm_router_gemm_impl, + fake_impl=flashinfer_tinygemm_router_gemm_fake, + mutates_args=["out"], + ) @PluggableLayer.register("gate_linear") class GateLinear(ReplicatedLinear): - """MoE gate linear layer with three-tier GEMM dispatch: + """MoE gate linear layer with four-tier GEMM dispatch: - 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) - 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) - 3. F.linear via ReplicatedLinear (ultimate fallback) + 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims, no bias) + 2. cuBLAS bf16xbf16→fp32 (SM90+ + bf16 + fp32 out_dtype, no bias) + 3. Flashinfer tinygemm_bf16 kernel (SM90+, aligned dims, supports bias) + 4. F.linear via ReplicatedLinear (ultimate fallback) The ``out_dtype`` attribute is mutable and can be set after init (e.g. when the required dtype depends on the expert quantization @@ -39,7 +69,7 @@ def __init__( (9, 0) ) or current_platform.is_device_capability_family(100) can_use_specialized_kernels = ( - current_platform.is_cuda() and is_hopper_or_blackwell and not bias + current_platform.is_cuda() and is_hopper_or_blackwell ) # If fp32 compute is required and no specialized kernel is available, @@ -57,21 +87,35 @@ def __init__( ) self.out_dtype = out_dtype - # DSV3 specialized kernel eligibility (SM90+, exact dims) self.allow_specialized_router_gemm = can_use_specialized_kernels + self.allow_specialized_router_gemm_no_bias = ( + can_use_specialized_kernels and not bias + ) + + # DSV3 specialized kernel eligibility (SM90+, exact dims, no bias) self.allow_dsv3_router_gemm = ( - self.allow_specialized_router_gemm + self.allow_specialized_router_gemm_no_bias and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) - # cuBLAS bf16→fp32 eligibility + # cuBLAS bf16→fp32 eligibility (no bias) self.allow_cublas_router_gemm = ( - self.allow_specialized_router_gemm + self.allow_specialized_router_gemm_no_bias and self.weight.dtype == torch.bfloat16 and self.out_dtype == torch.float32 ) + # Flashinfer tinygemm_bf16 (SM90+, aligned dims, supports bias) + self.allow_tinygemm_router_gemm = ( + self.allow_specialized_router_gemm + and self.weight.dtype == torch.bfloat16 + and self.out_dtype != torch.float32 + and input_size % 64 == 0 + and output_size % 16 == 0 + and has_flashinfer() + ) + def set_out_dtype(self, out_dtype: torch.dtype) -> None: """Set output dtype for the router logits after init. @@ -84,11 +128,15 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: if ( not self.allow_cublas_router_gemm - and self.allow_specialized_router_gemm + and self.allow_specialized_router_gemm_no_bias and out_dtype == torch.float32 ): self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 + # tinygemm outputs bf16 — disable if fp32 output is now required + if out_dtype == torch.float32: + self.allow_tinygemm_router_gemm = False + def forward( self, x: torch.Tensor ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: @@ -108,7 +156,20 @@ def forward( output = ops.router_gemm_bf16_fp32(x, self.weight) return output, None - # Tier 3: F.linear (ReplicatedLinear) + # Tier 3: Flashinfer tinygemm_bf16 + if self.allow_tinygemm_router_gemm and x.dtype == torch.bfloat16: + output = torch.empty( + x.shape[0], + self.weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) + torch.ops.vllm.flashinfer_tinygemm_router_gemm( + x, self.weight, self.bias, output + ) + return output, None + + # Tier 4: F.linear (ReplicatedLinear) if self.out_dtype is not None and x.dtype != self.weight.dtype: x = x.to(self.weight.dtype) output, output_bias = super().forward(x) diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index c3111489c0ca..69c9f7e3bb89 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -20,12 +20,11 @@ tensor_model_parallel_all_gather, ) from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( QKVParallelLinear, - ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -175,14 +174,13 @@ def __init__( self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = ReplicatedLinear( + self.router = GateLinear( config.hidden_size, config.num_local_experts, bias=True, - quant_config=None, prefix=f"{prefix}.router", - return_bias=False, ) + assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( num_experts=config.num_local_experts, @@ -209,7 +207,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self, x[:, : self.hidden_size], self.router.weight, self.router.bias ) else: - g = self.router(x) + g, _ = self.router(x) x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: From c174844d658600dc749d89cf06d456f229293ccb Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Tue, 17 Mar 2026 09:33:40 +0800 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/router/gate_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 7674bcb9d4d2..751d9839b45b 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -14,7 +14,7 @@ def flashinfer_tinygemm_router_gemm_impl( x: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor, + bias: torch.Tensor | None, out: torch.Tensor, ) -> None: from flashinfer.gemm.routergemm import tinygemm_bf16 @@ -24,7 +24,7 @@ def flashinfer_tinygemm_router_gemm_impl( def flashinfer_tinygemm_router_gemm_fake( x: torch.Tensor, weight: torch.Tensor, - bias: torch.Tensor, + bias: torch.Tensor | None, out: torch.Tensor, ) -> None: return From c4da2c7e5393dd8e3834bdfe75937b868e002a03 Mon Sep 17 00:00:00 2001 From: elvischenv <219235043+elvischenv@users.noreply.github.com> Date: Wed, 18 Mar 2026 02:26:00 -0700 Subject: [PATCH 3/3] add bs limitation Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> --- .../layers/fused_moe/router/gate_linear.py | 72 ++++++++++--------- 1 file changed, 37 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py index 751d9839b45b..70b3ee52f886 100644 --- a/vllm/model_executor/layers/fused_moe/router/gate_linear.py +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -15,25 +15,34 @@ def flashinfer_tinygemm_router_gemm_impl( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, - out: torch.Tensor, - ) -> None: + ) -> torch.Tensor: from flashinfer.gemm.routergemm import tinygemm_bf16 - tinygemm_bf16(x, weight, out, bias) + output = torch.empty( + x.shape[0], + weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) + tinygemm_bf16(x, weight, output, bias) + return output def flashinfer_tinygemm_router_gemm_fake( x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, - out: torch.Tensor, - ) -> None: - return + ) -> torch.Tensor: + return torch.empty( + x.shape[0], + weight.shape[0], + dtype=torch.bfloat16, + device=x.device, + ) direct_register_custom_op( op_name="flashinfer_tinygemm_router_gemm", op_func=flashinfer_tinygemm_router_gemm_impl, fake_impl=flashinfer_tinygemm_router_gemm_fake, - mutates_args=["out"], ) @@ -65,12 +74,11 @@ def __init__( force_fp32_compute: bool = False, prefix: str = "", ): - is_hopper_or_blackwell = current_platform.is_device_capability( - (9, 0) - ) or current_platform.is_device_capability_family(100) - can_use_specialized_kernels = ( - current_platform.is_cuda() and is_hopper_or_blackwell + is_hopper_or_blackwell = current_platform.is_cuda() and ( + current_platform.is_device_capability((9, 0)) + or current_platform.is_device_capability_family(100) ) + can_use_specialized_kernels = is_hopper_or_blackwell and not bias # If fp32 compute is required and no specialized kernel is available, # store weights in fp32 so Tier 3 computes in fp32 natively. @@ -87,33 +95,29 @@ def __init__( ) self.out_dtype = out_dtype + # DSV3 specialized kernel eligibility (SM90+, exact dims) self.allow_specialized_router_gemm = can_use_specialized_kernels - self.allow_specialized_router_gemm_no_bias = ( - can_use_specialized_kernels and not bias - ) - - # DSV3 specialized kernel eligibility (SM90+, exact dims, no bias) self.allow_dsv3_router_gemm = ( - self.allow_specialized_router_gemm_no_bias + self.allow_specialized_router_gemm and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES ) - # cuBLAS bf16→fp32 eligibility (no bias) + # cuBLAS bf16→fp32 eligibility self.allow_cublas_router_gemm = ( - self.allow_specialized_router_gemm_no_bias + self.allow_specialized_router_gemm and self.weight.dtype == torch.bfloat16 and self.out_dtype == torch.float32 ) # Flashinfer tinygemm_bf16 (SM90+, aligned dims, supports bias) - self.allow_tinygemm_router_gemm = ( - self.allow_specialized_router_gemm + self.allow_flashinfer_tinygemm_router_gemm = ( + is_hopper_or_blackwell + and has_flashinfer() and self.weight.dtype == torch.bfloat16 - and self.out_dtype != torch.float32 + and self.out_dtype in [None, torch.bfloat16] and input_size % 64 == 0 and output_size % 16 == 0 - and has_flashinfer() ) def set_out_dtype(self, out_dtype: torch.dtype) -> None: @@ -128,14 +132,14 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None: if ( not self.allow_cublas_router_gemm - and self.allow_specialized_router_gemm_no_bias + and self.allow_specialized_router_gemm and out_dtype == torch.float32 ): self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 # tinygemm outputs bf16 — disable if fp32 output is now required if out_dtype == torch.float32: - self.allow_tinygemm_router_gemm = False + self.allow_flashinfer_tinygemm_router_gemm = False def forward( self, x: torch.Tensor @@ -157,15 +161,13 @@ def forward( return output, None # Tier 3: Flashinfer tinygemm_bf16 - if self.allow_tinygemm_router_gemm and x.dtype == torch.bfloat16: - output = torch.empty( - x.shape[0], - self.weight.shape[0], - dtype=torch.bfloat16, - device=x.device, - ) - torch.ops.vllm.flashinfer_tinygemm_router_gemm( - x, self.weight, self.bias, output + if ( + self.allow_flashinfer_tinygemm_router_gemm + and x.dtype == torch.bfloat16 + and x.shape[0] <= 128 + ): + output = torch.ops.vllm.flashinfer_tinygemm_router_gemm( + x, self.weight, self.bias ) return output, None