
[Kernel] Add gpt-oss Router GEMM kernel#37205

Merged
vllm-bot merged 12 commits into vllm-project:main from xyang16:router_gemm on Mar 18, 2026

Conversation


@xyang16 xyang16 commented Mar 16, 2026

Purpose

This PR adds an optimized Router GEMM kernel for gpt-oss.

It gives a 1–2% output token throughput improvement at batch size 1.
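For context, the router GEMM in a gpt-oss MoE layer is a small matrix multiply mapping each token's hidden state to per-expert logits. A minimal numpy reference of the operation the fused kernel optimizes (shapes are illustrative; the hidden size of 2880 and 32 experts are assumed from gpt-oss-20b, and fp32 here stands in for the kernel's bf16-in/fp32-out path):

```python
import numpy as np

def router_gemm_reference(x, weight, bias):
    # [num_tokens, hidden] @ [num_experts, hidden].T + bias -> router logits
    return (x.astype(np.float32) @ weight.astype(np.float32).T
            + bias.astype(np.float32))

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 2880), dtype=np.float32)   # num_tokens x hidden
w = rng.standard_normal((32, 2880), dtype=np.float32)  # num_experts x hidden
b = rng.standard_normal(32, dtype=np.float32)
logits = router_gemm_reference(x, w, b)
```

At decode-time batch sizes this matmul is tiny, so it is launch-latency bound rather than compute bound — which is why a specialized kernel can beat a general GEMM library here.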

Test Plan

Added unit test.

pytest -s -v tests/kernels/moe/test_router_gemm.py

Test Result

Unit test passed.

Micro bench

python3 benchmarks/kernels/benchmark_router_gemm.py --model openai/gpt-oss-20b --max-batch-size 8192
openai/gpt-oss-20b router gemm throughput:
    batch_size  PyTorch (TFLOPs)  vLLM (TFLOPs)
0          1.0          0.031421       0.089570
1          2.0          0.064126       0.178294
2          4.0          0.128316       0.354806
3          8.0          0.255723       0.707328
4         16.0          0.489696       1.371641
5         32.0          0.968925       2.821921
6         64.0          1.911852       5.565304
7        128.0          3.720499      11.084153
8        256.0          7.380550      20.414985
9        512.0         14.636453      32.505290
10      1024.0         29.070115      33.019194
11      2048.0         56.395810      33.173499
12      4096.0        120.938349      34.309334
13      8192.0        152.726427      32.962142
python3 benchmarks/kernels/benchmark_router_gemm.py --model openai/gpt-oss-120b --max-batch-size 8192
openai/gpt-oss-120b router gemm throughput:
    batch_size  PyTorch (TFLOPs)  vLLM (TFLOPs)
0          1.0          0.123305       0.355313
1          2.0          0.254834       0.705630
2          4.0          0.505317       1.404891
3          8.0          1.004569       2.794942
4         16.0          1.918626       5.537719
5         32.0          3.781321      10.981122
6         64.0          7.482728      19.906920
7        128.0         14.560503      32.026677
8        256.0         28.719212      32.120986
9        512.0         56.351540      32.404140
10      1024.0        112.437718      32.985149
11      2048.0        201.149820      35.144443
12      4096.0        402.626180      37.998635
13      8192.0        452.585227      36.484008

The gpt_oss_router_gemm kernel has better throughput than the PyTorch baseline at low batch sizes — up to 1024 for gpt-oss-20b and up to 256 for gpt-oss-120b in the tables above — while the baseline wins at larger batches.
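The TFLOPs columns presumably use the standard 2·M·N·K FLOP count for a GEMM divided by measured time; a hedged sketch of that bookkeeping (the actual timing logic in benchmark_router_gemm.py is not shown here):

```python
def gemm_tflops(num_tokens: int, num_experts: int, hidden: int,
                elapsed_s: float) -> float:
    """FLOPs for C = A @ B^T (+ bias): 2*M*N*K multiply-adds, as TFLOP/s."""
    flops = 2.0 * num_tokens * num_experts * hidden
    return flops / elapsed_s / 1e12

# e.g. batch 128, 32 experts, hidden 2880, at a hypothetical ~2 us kernel time
tflops = gemm_tflops(128, 32, 2880, 2e-6)
```

Note that under this formula the absolute TFLOPs are small even at the crossover points — the router GEMM is a sliver of total model FLOPs, so the end-to-end gain is the 1–2% reported above, not the 2–3x kernel-level speedup.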

Benchmark

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16 \
  --no-enable-prefix-caching
vllm bench serve \
        --model openai/gpt-oss-20b \
        --dataset-name sharegpt \
        --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
        --sharegpt-output-len 300 \
        --num-prompts ${num_prompts} \
        --max-concurrency ${concurrency} \
        --num-warmups 50 \
        --ignore-eos \
        --temperature 0

Main:

concurrency=1

============ Serving Benchmark Result ============
Successful requests:                     60        
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  87.73     
Total input tokens:                      15599     
Total generated tokens:                  18000     
Request throughput (req/s):              0.68      
Output token throughput (tok/s):         205.19    
Peak output token throughput (tok/s):    219.00    
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          383.00    
---------------Time to First Token----------------
Mean TTFT (ms):                          32.70     
Median TTFT (ms):                        28.32     
P99 TTFT (ms):                           80.66     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.78      
Median TPOT (ms):                        4.75      
P99 TPOT (ms):                           5.17      
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.78      
Median ITL (ms):                         4.74      
P99 ITL (ms):                            5.52      
==================================================

concurrency=16

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  444.34    
Total input tokens:                      217301    
Total generated tokens:                  288000    
Request throughput (req/s):              2.16      
Output token throughput (tok/s):         648.16    
Peak output token throughput (tok/s):    768.00    
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          1137.21   
---------------Time to First Token----------------
Mean TTFT (ms):                          190.85    
Median TTFT (ms):                        152.53    
P99 TTFT (ms):                           827.11    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.12     
Median TPOT (ms):                        23.99     
P99 TPOT (ms):                           26.65     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.12     
Median ITL (ms):                         23.72     
P99 ITL (ms):                            36.63     
==================================================

PR:

concurrency=1

============ Serving Benchmark Result ============
Successful requests:                     60        
Failed requests:                         0         
Maximum request concurrency:             1         
Benchmark duration (s):                  86.24     
Total input tokens:                      15599     
Total generated tokens:                  18000     
Request throughput (req/s):              0.70      
Output token throughput (tok/s):         208.73    
Peak output token throughput (tok/s):    222.00    
Peak concurrent requests:                2.00      
Total token throughput (tok/s):          389.62    
---------------Time to First Token----------------
Mean TTFT (ms):                          32.81     
Median TTFT (ms):                        28.40     
P99 TTFT (ms):                           80.30     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          4.70      
Median TPOT (ms):                        4.66      
P99 TPOT (ms):                           5.08      
---------------Inter-token Latency----------------
Mean ITL (ms):                           4.70      
Median ITL (ms):                         4.66      
P99 ITL (ms):                            5.44      
==================================================

concurrency=16

============ Serving Benchmark Result ============
Successful requests:                     960       
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  442.22    
Total input tokens:                      217301    
Total generated tokens:                  288000    
Request throughput (req/s):              2.17      
Output token throughput (tok/s):         651.26    
Peak output token throughput (tok/s):    784.00    
Peak concurrent requests:                32.00     
Total token throughput (tok/s):          1142.64   
---------------Time to First Token----------------
Mean TTFT (ms):                          174.88    
Median TTFT (ms):                        147.44    
P99 TTFT (ms):                           829.79    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.06     
Median TPOT (ms):                        24.02     
P99 TPOT (ms):                           26.06     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.06     
Median ITL (ms):                         23.67     
P99 ITL (ms):                            42.65     
==================================================

Accuracy Testing

OPENAI_API_KEY=EMPTY python3 -m gpt_oss.evals --model openai/gpt-oss-20b --eval gpqa --n-threads 200 --reasoning-effort low

Main:

Writing report to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459.html
{'chars': np.float64(66.44823232323232), 'chars:std': np.float64(235.44711891411228), 'score': np.float64(0.5561868686868687), 'score:std': np.float64(0.49683300593576163)}
Writing results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459.json
Writing all results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459_allresults.json
[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_045459', 'metric': 0.5561868686868687}]

PR:

Writing report to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258.html
{'chars': np.float64(73.21843434343434), 'chars:std': np.float64(258.6924049276393), 'score': np.float64(0.5662878787878788), 'score:std': np.float64(0.49558643759268023)}
Writing results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258.json
Writing all results to /tmp/gpqa___opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258_allresults.json
[{'eval_name': 'gpqa', 'model_name': '__opt__dlami__nvme__models__gpt-oss-20b-low_temp1.0_20260316_040258', 'metric': 0.5662878787878788}]

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimized GEMM kernel for the gpt-oss router, which demonstrates performance improvements for small batch sizes. The integration is well-supported by new unit tests and benchmarks. My review focuses on enhancing the robustness and error handling of the new CUDA kernel code. I've identified the use of exit() in a library context and assert() for error checking, which could lead to silent failures in release builds or abrupt process termination. I have suggested replacing these with PyTorch's standard error-checking mechanisms (TORCH_CHECK) and C++ exceptions to ensure proper error reporting.


mergify bot commented Mar 16, 2026

Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@xyang16 xyang16 force-pushed the router_gemm branch 6 times, most recently from ff7ef6e to e75fa6a Compare March 17, 2026 01:13
Comment on lines -106 to +123
-        # Tier 2: cuBLAS bf16→fp32
+        # Tier 2: gpt-oss specialized kernel
+        if self.allow_gpt_oss_router_gemm:
+            output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
+            return output, None

@mgoin mgoin Mar 17, 2026


Shouldn't we skip this case if x.shape[0] > 128 so it could fall to other implementations? Such as cublas


@xyang16 xyang16 Mar 17, 2026


Yes, ideally the check should be:

        if self.allow_gpt_oss_router_gemm and x.shape[0] <= 128:
            output = ops.gpt_oss_router_gemm(x, self.weight, self.bias)
            return output, None

But I found that with the x.shape[0] <= 128 check above, the custom router gemm is never launched, because the torch.compile integration does not support runtime dispatching on num_tokens. So I have to put the x.shape[0] <= 128 check inside the custom op, similar to https://github.com/vllm-project/vllm/blob/v0.18.0rc0/vllm/model_executor/models/deepseek_v2.py#L735-L755

Please let me know if you have any good suggestions. Thanks!
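The workaround described above — moving the num_tokens check inside the op so the compiled graph keeps a single shape-agnostic call site — can be sketched roughly like this (names and the fallback path are hypothetical stand-ins, not the PR's actual implementation):

```python
import torch

MAX_FUSED_TOKENS = 128  # threshold discussed above

def router_gemm_dispatch(x: torch.Tensor, weight: torch.Tensor,
                         bias: torch.Tensor) -> torch.Tensor:
    # torch.compile traces one call site; the num_tokens branch below is
    # resolved at run time inside the op, not baked in at trace time.
    if x.shape[0] <= MAX_FUSED_TOKENS:
        # small-batch path: stands in for the fused gpt_oss_router_gemm kernel
        return torch.addmm(bias.float(), x.float(), weight.float().t())
    # large-batch path: general matmul fallback (e.g. cuBLAS)
    return torch.nn.functional.linear(x.float(), weight.float(), bias.float())

x_small = torch.randn(4, 16)
x_large = torch.randn(256, 16)
w = torch.randn(8, 16)
b = torch.randn(8)
small_out = router_gemm_dispatch(x_small, w, b)
large_out = router_gemm_dispatch(x_large, w, b)
```

Both branches compute the same logits, so the dispatch only changes which kernel runs, never the result — which is what makes hiding it inside the op safe for compilation.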

Member

Oh I see, that is a fair tradeoff for now

Contributor Author

Thanks! And since the cuBLAS ops.router_gemm_bf16_fp32 doesn't support bias, it's basically the same as before.

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 17, 2026

mgoin commented Mar 17, 2026

@xyang16 the lora test failure looks related

[2026-03-17T20:24:09Z] (EngineCore pid=9845)     super().__init__(
[2026-03-17T20:24:09Z] (EngineCore pid=9845)   File "/usr/local/lib/python3.12/dist-packages/vllm/lora/model_manager.py", line 114, in __init__
[2026-03-17T20:24:09Z] (EngineCore pid=9845)     self._create_lora_modules()
[2026-03-17T20:24:09Z] (EngineCore pid=9845)   File "/usr/local/lib/python3.12/dist-packages/vllm/lora/model_manager.py", line 410, in _create_lora_modules
[2026-03-17T20:24:09Z] (EngineCore pid=9845)     self.register_module(module_name, new_module)
[2026-03-17T20:24:09Z] (EngineCore pid=9845)   File "/usr/local/lib/python3.12/dist-packages/vllm/lora/model_manager.py", line 417, in register_module
[2026-03-17T20:24:09Z] (EngineCore pid=9845)     assert isinstance(module, BaseLayerWithLoRA), (
[2026-03-17T20:24:09Z] (EngineCore pid=9845)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[2026-03-17T20:24:09Z] (EngineCore pid=9845) AssertionError: Module model.layers.0.mlp.router must be a BaseLayerWithLoRA instance, got <class 'vllm.model_executor.layers.fused_moe.router.gate_linear.GateLinear'>


xyang16 commented Mar 17, 2026

> @xyang16 the lora test failure looks related

@mgoin This has been fixed in main by #37181. I have rebased this PR with main. Thanks!

xyang16 added 2 commits March 17, 2026 17:02
Signed-off-by: Xin Yang <xyangx@amazon.com>
This reverts commit ac52c1d.

Signed-off-by: Xin Yang <xyangx@amazon.com>

xyang16 commented Mar 18, 2026

> It seems we might be able to get this gemm directly from flashinfer? #37244

@mgoin I tried the flashinfer tinygemm, but it breaks test_batch_invariance.py. If I use the ported gpt_oss_router_gemm.cu kernel it works fine though.


mgoin commented Mar 18, 2026

I manually kicked off the gpqa-eval-gpt-oss tests and all green, merging

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 18, 2026
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Mar 18, 2026
@vllm-bot vllm-bot merged commit b1169d7 into vllm-project:main Mar 18, 2026
132 of 134 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
@xyang16 xyang16 deleted the router_gemm branch March 19, 2026 22:42
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: EricccYang <yangyang4991@gmail.com>
xyang16 added a commit to xyang16/vllm that referenced this pull request Apr 2, 2026
xyang16 added a commit to xyang16/vllm that referenced this pull request Apr 2, 2026
This reverts commit b1169d7.

Signed-off-by: Xin Yang <xyangx@amazon.com>
vllm-bot pushed a commit that referenced this pull request Apr 2, 2026
Signed-off-by: Xin Yang <xyangx@amazon.com>
bingshuailiu pushed a commit to bingshuailiu/vllm that referenced this pull request Apr 2, 2026
…vllm-project#38778)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: bsliu <1187291748@qq.com>
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Apr 3, 2026

Labels

ci/build gpt-oss Related to GPT-OSS models nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants