Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions vllm/model_executor/layers/fused_moe/router/gate_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,54 @@
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op

if has_flashinfer():

def flashinfer_tinygemm_router_gemm_impl(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    """Real implementation of the tinygemm router-GEMM custom op.

    Allocates an uninitialized bf16 output of shape
    (x.shape[0], weight.shape[0]) on x's device and lets FlashInfer's
    ``tinygemm_bf16`` kernel fill it in place (bias applied by the
    kernel when provided).
    """
    # Imported lazily so this module stays importable when FlashInfer
    # is absent (the surrounding registration is already gated on it).
    from flashinfer.gemm.routergemm import tinygemm_bf16

    out = x.new_empty((x.shape[0], weight.shape[0]), dtype=torch.bfloat16)
    tinygemm_bf16(x, weight, out, bias)
    return out

def flashinfer_tinygemm_router_gemm_fake(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None,
) -> torch.Tensor:
return torch.empty(
x.shape[0],
weight.shape[0],
dtype=torch.bfloat16,
device=x.device,
)

# Register the impl/fake pair as a torch custom op (exposed as
# torch.ops.vllm.flashinfer_tinygemm_router_gemm) so torch.compile
# can trace through it using the fake implementation.
direct_register_custom_op(
    op_name="flashinfer_tinygemm_router_gemm",
    op_func=flashinfer_tinygemm_router_gemm_impl,
    fake_impl=flashinfer_tinygemm_router_gemm_fake,
)


@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
"""MoE gate linear layer with three-tier GEMM dispatch:
"""MoE gate linear layer with four-tier GEMM dispatch:

1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
3. F.linear via ReplicatedLinear (ultimate fallback)
1. DSV3 specialized kernel (SM90+, batch<=16, supported dims, no bias)
2. cuBLAS bf16xbf16→fp32 (SM90+ + bf16 + fp32 out_dtype, no bias)
3. Flashinfer tinygemm_bf16 kernel (SM90+, aligned dims, supports bias)
4. F.linear via ReplicatedLinear (ultimate fallback)

The ``out_dtype`` attribute is mutable and can be set after init
(e.g. when the required dtype depends on the expert quantization
Expand All @@ -35,12 +74,11 @@ def __init__(
force_fp32_compute: bool = False,
prefix: str = "",
):
is_hopper_or_blackwell = current_platform.is_device_capability(
(9, 0)
) or current_platform.is_device_capability_family(100)
can_use_specialized_kernels = (
current_platform.is_cuda() and is_hopper_or_blackwell and not bias
is_hopper_or_blackwell = current_platform.is_cuda() and (
current_platform.is_device_capability((9, 0))
or current_platform.is_device_capability_family(100)
)
can_use_specialized_kernels = is_hopper_or_blackwell and not bias

# If fp32 compute is required and no specialized kernel is available,
# store weights in fp32 so Tier 3 computes in fp32 natively.
Expand Down Expand Up @@ -72,6 +110,16 @@ def __init__(
and self.out_dtype == torch.float32
)

# Flashinfer tinygemm_bf16 (SM90+, aligned dims, supports bias)
self.allow_flashinfer_tinygemm_router_gemm = (
is_hopper_or_blackwell
and has_flashinfer()
and self.weight.dtype == torch.bfloat16
and self.out_dtype in [None, torch.bfloat16]
and input_size % 64 == 0
and output_size % 16 == 0
)

def set_out_dtype(self, out_dtype: torch.dtype) -> None:
"""Set output dtype for the router logits after init.

Expand All @@ -89,6 +137,10 @@ def set_out_dtype(self, out_dtype: torch.dtype) -> None:
):
self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16

# tinygemm outputs bf16 — disable if fp32 output is now required
if out_dtype == torch.float32:
self.allow_flashinfer_tinygemm_router_gemm = False

def forward(
self, x: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
Expand All @@ -108,7 +160,18 @@ def forward(
output = ops.router_gemm_bf16_fp32(x, self.weight)
return output, None

# Tier 3: F.linear (ReplicatedLinear)
# Tier 3: Flashinfer tinygemm_bf16
if (
self.allow_flashinfer_tinygemm_router_gemm
and x.dtype == torch.bfloat16
and x.shape[0] <= 128
Copy link
Copy Markdown
Contributor

@xyang16 xyang16 Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x.shape[0] <= 128 check needs to be put inside the custom op. Otherwise tinygemm will never be launched. Because torch.compile integration does not support runtime dispatching on num_tokens.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was called correctly in my last test, and performance improved.
The existing Tier 1 branch also uses this approach.

# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
output = ops.dsv3_router_gemm(
hidden_states=x,
router_weight=self.weight,
output_dtype=self.out_dtype,
)
return output, None

Copy link
Copy Markdown
Contributor

@xyang16 xyang16 Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I profiled your PR with gpt-oss-20b on H200. I don't see tinygemm kernel launched.

If I put the check inside the custom op, I can see tinygemm kernel launched:

void tinygemm_kernel<16, 16, 8, 64, 16, 4, false>(__...         0.00%       0.000us         0.00%       0.000us       0.000us     393.088us         1.51%     393.088us       3.276us           120  

Could you please double check? Thanks!

):
output = torch.ops.vllm.flashinfer_tinygemm_router_gemm(
x, self.weight, self.bias
)
return output, None

# Tier 4: F.linear (ReplicatedLinear)
if self.out_dtype is not None and x.dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
output, output_bias = super().forward(x)
Expand Down
10 changes: 4 additions & 6 deletions vllm/model_executor/models/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
tensor_model_parallel_all_gather,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand Down Expand Up @@ -175,14 +174,13 @@ def __init__(
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
self.router = ReplicatedLinear(
self.router = GateLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=f"{prefix}.router",
return_bias=False,
)

assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
num_experts=config.num_local_experts,
Expand All @@ -209,7 +207,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self, x[:, : self.hidden_size], self.router.weight, self.router.bias
)
else:
g = self.router(x)
g, _ = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]

if self.is_sequence_parallel:
Expand Down
Loading