[Kernel] Use pre-allocated output buffer for triton kernel fused_experts#29219
jeejeelee merged 4 commits into vllm-project:main
Conversation
Code Review
This pull request introduces an optimization for the Triton-based fused MoE kernel by using pre-allocated output and intermediate buffers, which should reduce memory allocation overhead and improve performance. The changes involve modifying triton_kernel_fused_experts to accept these buffers and updating the call sites. Additionally, moe_problem_size is correctly overridden in OAITritonExperts to match the Triton kernel's weight layout expectations. My review identifies a critical bug in shape unpacking that could lead to a crash when handling 3D input tensors. The rest of the changes appear correct and consistent with the goal of the pull request.
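A minimal sketch of the weight-layout mismatch behind the `moe_problem_size` override (shapes and variable names assumed for illustration, not taken from the vLLM source):

```python
import numpy as np

E, K, N = 8, 128, 256
# The triton kernels store w1 as (E, K, N), i.e. N is the third dimension.
w1 = np.zeros((E, K, N))

# The superclass convention reads N from w1.shape[1], which here is K.
n_base_class = w1.shape[1]   # would pick up K (128) for this layout
# The overridden moe_problem_size reads w1.shape[2] instead.
n_override = w1.shape[2]     # 256

print(n_base_class, n_override)  # 128 256
```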
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
💡 Codex Review
Here are some automated review suggestions for this pull request.
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
Force-pushed from c520f20 to faea043
Signed-off-by: Xin Yang <xyangx@amazon.com>
```
     y=output_tensor,
 )
-return intermediate_cache3
+return output_tensor.view(M, K)
```
isn't output tensor already [M, K] ?
matmul_ogs adds a batch dim to the output, so the shape is [1, M, K]
nvm - I see it is resized to [1, M, K]. 👍
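To illustrate the point above (a sketch with assumed shapes, not vLLM code): matmul_ogs leaves a leading batch dim of 1 on the output, so the [1, M, K] buffer is viewed back as [M, K] before returning.

```python
import numpy as np

M, K = 4, 8
# Stand-in for the pre-allocated output buffer after matmul_ogs has
# written into it; the kernel leaves a leading batch dim of 1.
output_tensor = np.zeros((1, M, K))

# numpy's reshape plays the role of torch's .view(M, K) here: it drops
# the batch dim without copying, so callers get the expected 2D result.
result = output_tensor.reshape(M, K)
print(result.shape)  # (4, 8)
```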
```
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
intermediate_cache13: torch.Tensor | None = None,
intermediate_cache2: torch.Tensor | None = None,
```
Better to remove the intermediate_cache2 arg if it is not used, and rename intermediate_cache13 -> intermediate_cache.
```
-intermediate_cache3 = matmul_ogs(
-    intermediate_cache1,
+matmul_ogs(
+    intermediate_cache13.view(M * topk, N // 2),
```
Does this not require a batch_dim in the view?
No, because triton matmul_ogs doesn't support batch_dim with scatter.
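A sketch of the shape handling described in this thread (dimensions assumed for illustration): because the scatter path works on 2D inputs, the intermediate buffer is viewed as a flat (M * topk, N // 2) matrix rather than a batched 3D tensor.

```python
import numpy as np

M, topk, N = 4, 2, 16
# Hypothetical intermediate activations, one row per (token, expert) pair.
intermediate_cache13 = np.zeros((M * topk, N))

# The gated half has width N // 2; the kernel consumes it as a flat 2D
# matrix with no batch dim, matching matmul_ogs's scatter constraint.
flat = intermediate_cache13[:, : N // 2].reshape(M * topk, N // 2)
print(flat.shape)  # (8, 8)
```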
varun-sundar-rabindranath
left a comment
LGTM. Nice optimization! Thanks @xyang16
Signed-off-by: Xin Yang <xyangx@amazon.com>
…rts (vllm-project#29219) Signed-off-by: Xin Yang <xyangx@amazon.com>
…rts (vllm-project#29219) Signed-off-by: Xin Yang <xyangx@amazon.com> Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
Purpose
This PR uses a pre-allocated output buffer for the triton kernel matmul_ogs.
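The pre-allocation pattern can be sketched as follows (hypothetical names, not the actual vLLM API): the caller allocates the output once and the wrapper writes into it, avoiding a fresh allocation on every call.

```python
import numpy as np

def fused_matmul(a, b, output=None):
    """Write a @ b into `output` if given, else allocate (illustrative)."""
    if output is None:
        output = np.empty((a.shape[0], b.shape[1]))
    np.matmul(a, b, out=output)  # result lands in the provided buffer
    return output

# Allocate once, reuse across calls -- the point of this change.
buf = np.empty((4, 8))
a, b = np.ones((4, 6)), np.ones((6, 8))
result = fused_matmul(a, b, output=buf)
print(result is buf)  # True
```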
It also overrides the moe_problem_size() function in OAITritonExperts, because the superclass moe_problem_size expects N to be the second dimension of w1 (see here), but the triton kernels expect N to be the third dimension of w1. Without the override, N is incorrectly assigned the value of K for triton.
Test Plan
Test Result
Unit test passed
Accuracy Testing
Benchmark
Baseline:
PR:
cc @varun-sundar-rabindranath