
[Perf] Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS #37244

Open
elvischenv wants to merge 3 commits into vllm-project:main from elvischenv:elvischenv/support-flashinfer-tinygemm

Conversation

@elvischenv
Contributor

@elvischenv commented Mar 17, 2026

Purpose

Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS.

Test Plan & Test Results

nsys

PR:

void flashinfer::trtllm_allreduce_fusion::allreduce_fusion_kernel_oneshot_lamport   5.088 μs
void tinygemm_kernel                                                                3.136 μs
void tensorrt_llm::kernels::quantize_with_block_size                                2.848 μs

main:

void flashinfer::trtllm_allreduce_fusion::allreduce_fusion_kernel_oneshot_lamport   5.344 μs
nvjet_sm100_tst_32x64_64x16_4x1_v_bz_splitK_bias_TNN                                3.904 μs
void cublasLt::splitKreduce_kernel                                                  2.816 μs
void tensorrt_llm::kernels::quantize_with_block_size                                3.360 μs

Kernel perf

GPU: NVIDIA B200

gpt-oss-120b: hidden_size=2880, num_experts=128, bias=True
 batch  F.linear(us)  tinygemm(us)   speedup
------------------------------------------
     1      0.0057        0.0034     1.66x
     2      0.0059        0.0034     1.72x
     4      0.0059        0.0034     1.72x
     8      0.0059        0.0034     1.72x
    16      0.0061        0.0034     1.78x
    32      0.0061        0.0034     1.78x
    64      0.0059        0.0036     1.62x
   128      0.0061        0.0036     1.68x
   256      0.0061        0.0059     1.03x
   512      0.0061        0.0104     0.59x

GPU: NVIDIA H100 PCIe

gpt-oss-120b: hidden_size=2880, num_experts=128, bias=True
 batch  F.linear(us)  tinygemm(us)   speedup
------------------------------------------
     1      0.0066        0.0037     1.79x
     2      0.0066        0.0037     1.80x
     4      0.0066        0.0037     1.79x
     8      0.0066        0.0037     1.80x
    16      0.0069        0.0038     1.84x
    32      0.0070        0.0038     1.85x
    64      0.0071        0.0041     1.74x
   128      0.0073        0.0067     1.08x
   256      0.0074        0.0104     0.71x
   512      0.0073        0.0165     0.44x
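The crossover in the tables above (tinygemm wins up to roughly 128 tokens, F.linear wins beyond) can be reproduced with a small timing harness. This is a hypothetical sketch, not the script used for the PR: it times F.linear in float32 on CPU for portability, whereas the reported numbers are bf16 CUDA kernel times from nsys on B200/H100.

```python
import time
import torch
import torch.nn.functional as F

def bench_router_gemm(batch, hidden_size=2880, num_experts=128, iters=100):
    """Time the GPT-OSS router GEMM shape with F.linear.

    Hypothetical harness: float32 on CPU for portability; the PR's tables
    are bf16 CUDA kernel times measured with nsys.
    """
    x = torch.randn(batch, hidden_size)
    weight = torch.randn(num_experts, hidden_size)
    bias = torch.randn(num_experts)
    for _ in range(10):  # warmup
        out = F.linear(x, weight, bias)
    start = time.perf_counter()
    for _ in range(iters):
        out = F.linear(x, weight, bias)
    avg_s = (time.perf_counter() - start) / iters
    return out, avg_s

out, avg_s = bench_router_gemm(batch=16)
print(out.shape, f"{avg_s * 1e6:.2f} us/iter")
```

On GPU one would synchronize around the timed loop (or use CUDA events); wall-clock timing without synchronization measures only launch overhead.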

E2E accuracy

PR:

[{'eval_name': 'gpqa', 'model_name': 'gpt-oss-120b-high_temp1.0_20260316_202021', 'metric': 0.7954545454545454}]

main:

[{'eval_name': 'gpqa', 'model_name': 'gpt-oss-120b-high_temp1.0_20260315_210654', 'metric': 0.7891414141414141}]

E2E perf

PR: about 2% perf gain

============ Serving Benchmark Result ============
Successful requests:                     80
Failed requests:                         0
Maximum request concurrency:             8
Benchmark duration (s):                  29.79
Total input tokens:                      81920
Total generated tokens:                  81920
Request throughput (req/s):              2.69
Output token throughput (tok/s):         2749.91
Peak output token throughput (tok/s):    153.00
Peak concurrent requests:                16.00
Total token throughput (tok/s):          5499.81
---------------Time to First Token----------------
Mean TTFT (ms):                          57.26
Median TTFT (ms):                        65.95
P99 TTFT (ms):                           70.88
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.85
Median TPOT (ms):                        2.85
P99 TPOT (ms):                           2.90
---------------Inter-token Latency----------------
Mean ITL (ms):                           56.15
Median ITL (ms):                         56.96
P99 ITL (ms):                            58.44
==================================================

main:

============ Serving Benchmark Result ============
Successful requests:                     80
Failed requests:                         0
Maximum request concurrency:             8
Benchmark duration (s):                  30.68
Total input tokens:                      81920
Total generated tokens:                  81920
Request throughput (req/s):              2.61
Output token throughput (tok/s):         2670.22
Peak output token throughput (tok/s):    144.00
Peak concurrent requests:                16.00
Total token throughput (tok/s):          5340.45
---------------Time to First Token----------------
Mean TTFT (ms):                          74.49
Median TTFT (ms):                        81.17
P99 TTFT (ms):                           102.37
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          2.93
Median TPOT (ms):                        2.92
P99 TPOT (ms):                           2.98
---------------Inter-token Latency----------------
Mean ITL (ms):                           57.55
Median ITL (ms):                         58.38
P99 ITL (ms):                            59.59
==================================================



@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces support for the Flashinfer tinygemm_bf16 kernel for the MoE router GEMM in GPT-OSS models. This is achieved by creating a new GateLinear layer with a four-tier dispatch mechanism, where the new Flashinfer kernel is the third tier. The changes are well-implemented and include performance benchmarks showing a ~2% gain. I've identified a minor correctness issue regarding the type hint for the optional bias parameter in the new custom op, which could lead to runtime errors if GateLinear is used without a bias. My suggestions address this.

@elvischenv
Contributor Author

cc @robertgshaw2-redhat for viz

elvischenv and others added 3 commits March 18, 2026 01:59
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
@elvischenv force-pushed the elvischenv/support-flashinfer-tinygemm branch from b27e13c to c4da2c7 on March 18, 2026 at 09:52
@elvischenv
Contributor Author

@xyang16 I appreciate your review of my PR, and have incorporated some of your suggestions, e.g. benchmarking the kernel perf (updated in the PR description) and adding a batch-size limitation.

if (
    self.allow_flashinfer_tinygemm_router_gemm
    and x.dtype == torch.bfloat16
    and x.shape[0] <= 128
Contributor

@xyang16 Mar 18, 2026


The x.shape[0] <= 128 check needs to be placed inside the custom op; otherwise tinygemm will never be launched, because the torch.compile integration does not support runtime dispatching on num_tokens.

Contributor Author


I think it was called correctly in my last test, and perf improved.
The existing Tier 1 branch uses the same approach.

# Tier 1: DSV3 specialized kernel
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
    output = ops.dsv3_router_gemm(
        hidden_states=x,
        router_weight=self.weight,
        output_dtype=self.out_dtype,
    )
    return output, None

Contributor

@xyang16 Mar 18, 2026


I profiled your PR with gpt-oss-20b on an H200. I don't see the tinygemm kernel launched.

If I put the check inside the custom op, I can see the tinygemm kernel launched:

void tinygemm_kernel<16, 16, 8, 64, 16, 4, false>(__...         0.00%       0.000us         0.00%       0.000us       0.000us     393.088us         1.51%     393.088us       3.276us           120  

Could you please double check? Thanks!

@mergify

mergify bot commented Mar 18, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @elvischenv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 18, 2026
@nvpohanh
Contributor

@elvischenv could you rebase and fix conflicts? thanks

@nvpohanh
Contributor

Per offline discussion, we think this has been covered by #37205 and we can close this.


Labels

gpt-oss (Related to GPT-OSS models), needs-rebase

Projects

Status: To Triage

3 participants