[Kernel] Add gpt-oss Router GEMM kernel#37205
Conversation
Code Review
This pull request introduces an optimized GEMM kernel for the gpt-oss router, which demonstrates performance improvements for small batch sizes. The integration is well-supported by new unit tests and benchmarks. My review focuses on enhancing the robustness and error handling of the new CUDA kernel code. I've identified the use of exit() in a library context and assert() for error checking, which could lead to silent failures in release builds or abrupt process termination. I have suggested replacing these with PyTorch's standard error-checking mechanisms (TORCH_CHECK) and C++ exceptions to ensure proper error reporting.
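The error-handling point can be illustrated with a minimal sketch. The macro below is a self-contained stand-in for PyTorch's `TORCH_CHECK` (which in real extension code throws a `c10::Error` that surfaces in Python as a `RuntimeError`); the validation function, its parameters, and the 128-token limit are hypothetical, for illustration only.

```cpp
#include <sstream>
#include <stdexcept>
#include <string>

// Stand-in for PyTorch's TORCH_CHECK macro. Real TORCH_CHECK throws a
// c10::Error that surfaces in Python as a RuntimeError; a
// std::runtime_error keeps this sketch self-contained. Unlike exit() it
// does not kill the host process, and unlike assert() it is not
// compiled out in release builds (NDEBUG).
#define CHECK_ARG(cond, msg)                                  \
  do {                                                        \
    if (!(cond)) {                                            \
      std::ostringstream oss;                                 \
      oss << "Check failed: " << #cond << ": " << (msg);      \
      throw std::runtime_error(oss.str());                    \
    }                                                         \
  } while (0)

// Hypothetical host-side validation for the router GEMM wrapper; the
// 128-token limit mirrors the dispatch threshold discussed in this thread.
void validate_router_gemm_args(long num_tokens, long hidden, long experts) {
  CHECK_ARG(num_tokens > 0, "num_tokens must be positive");
  CHECK_ARG(num_tokens <= 128, "kernel supports at most 128 tokens");
  CHECK_ARG(hidden > 0 && experts > 0, "invalid problem shape");
}
```

With this pattern, a bad shape raises a catchable exception at the Python boundary instead of silently passing in release builds or terminating the whole server process.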
Hi @xyang16, the pre-commit checks have failed. Please run:

    uv pip install pre-commit>=4.5.1
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
    - # Tier 2: cuBLAS bf16→fp32
    + # Tier 2: gpt-oss specialized kernel
      if self.allow_gpt_oss_router_gemm:
          output = torch.ops.vllm.gpt_oss_router_gemm(x, self.weight, self.bias)
          return output, None
Shouldn't we skip this case if x.shape[0] > 128 so it can fall back to other implementations, such as cuBLAS?
Yes, ideally the check should be:

    if self.allow_gpt_oss_router_gemm and x.shape[0] <= 128:
        output = ops.gpt_oss_router_gemm(x, self.weight, self.bias)
        return output, None

But I found that with the x.shape[0] <= 128 check at the call site as above, the custom router GEMM is never launched, because the torch.compile integration does not support runtime dispatching on num_tokens. So I have to put the x.shape[0] <= 128 check inside the custom op, similar to https://github.com/vllm-project/vllm/blob/v0.18.0rc0/vllm/model_executor/models/deepseek_v2.py#L735-L755
Please let me know if you have any good suggestions. Thanks!
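A minimal sketch of this workaround: the batch-size check lives inside the op body, so torch.compile traces a single opaque call site and the branch resolves at runtime on the actual num_tokens. The names here (`router_gemm`, the `addmm` stand-in for the custom kernel) are illustrative, not vLLM's actual implementation.

```python
import torch

MAX_ROUTER_TOKENS = 128  # threshold from the discussion above


def router_gemm(x: torch.Tensor, weight: torch.Tensor,
                bias: torch.Tensor) -> torch.Tensor:
    # The num_tokens check is inside the op, not at the call site, so the
    # compiled graph sees one call and dispatch happens at runtime.
    if x.shape[0] <= MAX_ROUTER_TOKENS:
        # stand-in for the specialized gpt-oss router kernel
        return torch.addmm(bias, x, weight.t())
    # fallback for larger batches (e.g. a plain cuBLAS-backed GEMM)
    return torch.nn.functional.linear(x, weight, bias)
```

Both branches compute the same linear layer here; in the real op they would be the custom kernel and the generic fallback respectively.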
Oh I see, that is a fair tradeoff for now
Thanks! And since the cuBLAS ops.router_gemm_bf16_fp32 doesn't support bias, it's basically the same as before.
@xyang16 the lora test failure looks related
@mgoin This has been fixed in main by #37181. I have rebased this PR with main. Thanks!
I manually kicked off the gpqa-eval-gpt-oss tests and they are all green; merging.
Purpose
This PR adds an optimized Router GEMM kernel for gpt-oss.
1% - 2% output token throughput improvement at batch size 1.
Test Plan
Added unit test.
Test Result
Unit test passed.
Micro bench
The gpt_oss_router_gemm kernel has better throughput at low batch sizes.
Benchmark
Main:
concurrency=1
concurrency=16
PR:
concurrency=1
concurrency=16
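The micro-bench pattern can be sketched as below, on CPU for self-containment (the actual benchmark runs on GPU, where per-kernel timing would use `torch.cuda.Event` and synchronization). The shapes, iteration counts, and the `addmm` stand-in are illustrative.

```python
import time

import torch


def bench_ms(fn, iters: int = 50, warmup: int = 5) -> float:
    """Return the mean latency of fn() in milliseconds."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3


# Batch size 1, matching the headline number above; shapes illustrative.
x = torch.randn(1, 64)
w = torch.randn(32, 64)
b = torch.randn(32)
ms = bench_ms(lambda: torch.addmm(b, x, w.t()))  # stand-in for the router GEMM
```

Warmup iterations matter here: the first few calls pay one-time costs (allocator, kernel caching) that would otherwise skew the mean.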
Accuracy Testing
Main:
PR: