[Perf] Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS#37244
elvischenv wants to merge 3 commits into vllm-project:main
Conversation
Code Review
This pull request introduces support for the Flashinfer tinygemm_bf16 kernel for the MoE router GEMM in GPT-OSS models. This is achieved by creating a new GateLinear layer with a four-tier dispatch mechanism, where the new Flashinfer kernel is the third tier. The changes are well-implemented and include performance benchmarks showing a ~2% gain. I've identified a minor correctness issue regarding the type hint for the optional bias parameter in the new custom op, which could lead to runtime errors if GateLinear is used without a bias. My suggestions address this.
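The type-hint issue called out above can be illustrated with a plain-Python sketch (hypothetical function and names; the real `GateLinear` operates on torch tensors and registers a custom op): declaring the bias parameter as `Optional` is what keeps a bias-free layer from failing at call time.

```python
from typing import Optional

def router_gemm(x: list[list[float]],
                weight: list[list[float]],
                bias: Optional[list[float]] = None) -> list[list[float]]:
    """Hypothetical sketch of a router GEMM: out = x @ weight^T (+ bias).

    Typing ``bias`` as ``Optional`` mirrors the fix suggested in the
    review: if the parameter were annotated as a plain (non-optional)
    tensor, calling the op without a bias would fail at runtime.
    """
    out = [[sum(a * w for a, w in zip(row, col)) for col in weight]
           for row in x]
    if bias is not None:
        out = [[v + b for v, b in zip(row, bias)] for row in out]
    return out
```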
cc @robertgshaw2-redhat for viz
@xyang16 I appreciate your review of my PR and have incorporated some of your suggestions, e.g. benchmarking the kernel perf (updated in the PR description) and adding the batch size limitation.
```python
if (
    self.allow_flashinfer_tinygemm_router_gemm
    and x.dtype == torch.bfloat16
    and x.shape[0] <= 128
```
The `x.shape[0] <= 128` check needs to be placed inside the custom op; otherwise tinygemm will never be launched, because the torch.compile integration does not support runtime dispatching on num_tokens.
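A minimal sketch of this point (hypothetical function and names; the real code wraps torch tensors and a registered custom op): torch.compile bakes Python-level branches into the traced graph, so a num_tokens check outside the op is resolved once at trace time, while a check inside the op body is re-evaluated on every call.

```python
def router_gemm_op(num_tokens: int, dtype: str) -> str:
    # Hypothetical dispatch placed *inside* the op body: because the
    # decision happens at call time, a compiled graph that captures
    # this op can still pick tinygemm for small bf16 batches.
    if dtype == "bfloat16" and num_tokens <= 128:
        return "flashinfer_tinygemm_bf16"
    return "default_router_gemm"
```

If the same condition lived outside the op, the compiled graph would be specialized on whichever branch was taken at trace time, and the tinygemm path might never be reached.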
I think it was called correctly in my last test, and I got improved perf.
The existing Tier 1 branch also uses this approach.
vllm/vllm/model_executor/layers/fused_moe/router/gate_linear.py
Lines 97 to 104 in 99267c2
I profiled your PR with gpt-oss-20b on H200, and I don't see the tinygemm kernel launched.
If I put the check inside the custom op, I can see the tinygemm kernel launched:

```
void tinygemm_kernel<16, 16, 8, 64, 16, 4, false>(__... 0.00% 0.000us 0.00% 0.000us 0.000us 393.088us 1.51% 393.088us 3.276us 120
```

Could you please double-check? Thanks!
This pull request has merge conflicts that must be resolved before it can be merged.
@elvischenv could you rebase and fix conflicts? Thanks
Per offline discussion, we think this has been covered by #37205 and we can close this. |
Purpose
Support Flashinfer trtllm tinygemm_bf16 router gemm for GPT-OSS.
Test Plan & Test Result
nsys
PR:
main:
Kernel perf
GPU: NVIDIA B200
GPU: NVIDIA H100 PCIe
E2E accuracy
PR:
main:
E2E perf
PR: about 2% perf gain
main:
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.