Skip to content

[Kernel] Use pre-allocated output buffer for triton kernel fused_experts#29219

Merged
jeejeelee merged 4 commits intovllm-project:mainfrom
xyang16:triton
Nov 26, 2025
Merged

[Kernel] Use pre-allocated output buffer for triton kernel fused_experts#29219
jeejeelee merged 4 commits intovllm-project:mainfrom
xyang16:triton

Conversation

@xyang16
Copy link
Contributor

@xyang16 xyang16 commented Nov 22, 2025

Purpose

This PR is to use pre-allocated output buffer for triton kernel matmal_ogs

  • Fix N by overriding moe_problem_size() function in OAITritonExperts, because the super class moe_problem_size expects N to be the second dimension of w1, see here. But triton kernels expect N to be the third dimension of w1. This will cause N assigned the value of K incorrectly for triton.
  • Allocate intermediate_cache13 (shape [M * topk, N // 2]) to be the output buffer of first matmal_ogs
  • Allocate output (shape [M, K]) to be the output buffer of second matmal_ogs
  • Add batch_dim to output buffer because matmul_ogs expects 3D output, see here.

Test Plan

pytest -s -v tests/kernels/moe/test_gpt_oss_triton_kernels.py 

Test Result

Unit test passed

Accuracy Testing

  • gpt-oss-20b
vllm serve openai/gpt-oss-20b --tensor-parallel-size 1 --max-num-seqs=16 
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105.html
{'chars': np.float64(66.06565656565657), 'chars:std': np.float64(235.89106420986758), 'score': np.float64(0.5681818181818182), 'score:std': np.float64(0.4953294254023493)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20251121_182105', 'metric': 0.5681818181818182}]
  • gpt-oss-20b deepep
VLLM_ALL2ALL_BACKEND="deepep_high_throughput" vllm serve openai/gpt-oss-20b --tensor-parallel-size 1 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437.html
{'chars': np.float64(71.08207070707071), 'chars:std': np.float64(268.30453841690064), 'score': np.float64(0.5707070707070707), 'score:std': np.float64(0.4949752621616814)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20251121_184437', 'metric': 0.5707070707070707}]

Benchmark

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16
vllm bench serve \
  --model openai/gpt-oss-20b \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 60 \
  --ignore-eos

Baseline:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  103.15    
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              9.69      
Output token throughput (tok/s):         1929.46   
Peak output token throughput (tok/s):    2104.00   
Peak concurrent requests:                33.00     
Total Token throughput (tok/s):          4016.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          30.38     
Median TTFT (ms):                        20.28     
P99 TTFT (ms):                           579.15    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.03      
Median TPOT (ms):                        8.01      
P99 TPOT (ms):                           9.32      
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.04      
Median ITL (ms):                         7.66      
P99 ITL (ms):                            17.10     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  96.91     
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              10.32     
Output token throughput (tok/s):         2053.88   
Peak output token throughput (tok/s):    2230.00   
Peak concurrent requests:                35.00     
Total Token throughput (tok/s):          4275.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          19.89     
Median TTFT (ms):                        17.34     
P99 TTFT (ms):                           86.70     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.59      
Median TPOT (ms):                        7.55      
P99 TPOT (ms):                           8.64      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.59      
Median ITL (ms):                         7.37      
P99 ITL (ms):                            16.13     
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

cc @varun-sundar-rabindranath

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces an optimization for the Triton-based fused MoE kernel by using pre-allocated output and intermediate buffers, which should reduce memory allocation overhead and improve performance. The changes involve modifying triton_kernel_fused_experts to accept these buffers and updating the call sites. Additionally, moe_problem_size is correctly overridden in OAITritonExperts to match the Triton kernel's weight layout expectations. My review identifies a critical bug in shape unpacking that could lead to a crash when handling 3D input tensors. The rest of the changes appear correct and consistent with the goal of the pull request.

Copy link

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@ApostaC
Copy link
Collaborator

ApostaC commented Nov 24, 2025

cc @mgoin @pavanimajety

Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
y=output_tensor,
)
return intermediate_cache3
return output_tensor.view(M, K)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't output tensor already [M, K] ?

Copy link
Contributor Author

@xyang16 xyang16 Nov 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

matmal_ogs add batch_dim to output, the shape is [1, M, K]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm - I see it is resized to [1, M , K] . 👍

global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to remove the intermediate_cache2 arg if it is not used. and rename intermedidate_cache13 -> intermediate_cache

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed. Thanks!

intermediate_cache3 = matmul_ogs(
intermediate_cache1,
matmul_ogs(
intermediate_cache13.view(M * topk, N // 2),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this not require batch_dim in the view ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because triton matmul_ogs doesn't support batch_dim with scatter

Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Nice optimization! Thanks @xyang16

Signed-off-by: Xin Yang <xyangx@amazon.com>
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 25, 2025
@github-project-automation github-project-automation bot moved this from To Triage to Ready in gpt-oss Issues & Enhancements Nov 26, 2025
@jeejeelee jeejeelee merged commit 53d7f1f into vllm-project:main Nov 26, 2025
50 checks passed
@xyang16 xyang16 deleted the triton branch November 28, 2025 03:11
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…rts (vllm-project#29219)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

gpt-oss Related to GPT-OSS models ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

5 participants