[Kernel] Optimization of the mm_k operator. #28280
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces a significant optimization to the mm_k Triton kernel by intelligently handling masking for the not EVEN_K case. By categorizing memory access patterns into fully in-bounds, fully out-of-bounds, and partially out-of-bounds, it effectively reduces redundant masking and tl.dot() operations, which is a great improvement. The logic appears sound and the performance gains are evident from the benchmarks. I've added one suggestion for a minor further optimization in the EVEN_K path to remove a redundant bounds check.
```python
# K is divisible by BLOCK_K, no masking ever needed
# But skip if entire block is out of range
if iter_k < K:
    tiled_a = tl.load(a_ptr)
    tiled_b = tl.load(b_ptr)
    if CAST_TYPE:
        tiled_a = tiled_a.to(b_dtype)
    accumulator += tl.dot(tiled_a, tiled_b)
```
The check `if iter_k < K:` appears to be redundant when `EVEN_K` is true. The `EVEN_K` constant expression implies that `K` is a multiple of `STEP_K` (where `STEP_K` is `BLOCK_K * SPLIT_K`). Given the loop `for k in range(tl.cdiv(K, STEP_K))` and the calculation of `iter_k`, all memory accesses are guaranteed to be within the bounds of `K`. Removing this unnecessary branch could yield a small performance improvement in this hot loop.
```python
# K is divisible by BLOCK_K, no masking ever needed.
# When EVEN_K is true, all loads are guaranteed to be in-bounds.
tiled_a = tl.load(a_ptr)
tiled_b = tl.load(b_ptr)
if CAST_TYPE:
    tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
```
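The reviewer's bounds argument can be checked with a quick host-side sketch. This is plain Python, not Triton, and the exact `iter_k` formula (`k * STEP_K + split_id * BLOCK_K`, with `STEP_K = BLOCK_K * SPLIT_K`) is an assumption made for illustration, modeled on the split-K loop described above:

```python
# Host-side check: when K is a multiple of STEP_K = BLOCK_K * SPLIT_K,
# every block a split-K program touches lies fully inside [0, K),
# so the `if iter_k < K` guard can never be False.
# BLOCK_K, SPLIT_K and the iter_k formula are illustrative assumptions.

def cdiv(a, b):
    """Ceiling division, mirroring tl.cdiv."""
    return -(-a // b)

def max_touched_index(K, BLOCK_K, SPLIT_K):
    STEP_K = BLOCK_K * SPLIT_K
    hi = 0
    for split_id in range(SPLIT_K):          # one program per split
        for k in range(cdiv(K, STEP_K)):     # the kernel's K loop
            iter_k = k * STEP_K + split_id * BLOCK_K
            assert iter_k < K                # the guard is always taken
            hi = max(hi, iter_k + BLOCK_K)   # last index this block loads
    return hi

# EVEN_K holds: K % (BLOCK_K * SPLIT_K) == 0.
assert max_touched_index(K=4096, BLOCK_K=32, SPLIT_K=4) == 4096
# No block ever crosses K, so the branch is dead code under EVEN_K.
```

Under these assumptions the last block ends exactly at `K`, never past it, which is why the branch can be dropped without changing results.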
use 4090: Add an additional set of experiments on the RTX 4090 GPU. Benchmark duration ↓5.2%, output token throughput ↑5.5% (main vs. this PR).
use 3090: Experiments on the RTX 3090 GPU are complete. Output token throughput increases by approximately 8%, consistent with the intuition that this change is friendlier to older GPU architectures (main vs. this PR).
Force-pushed 21cd263 to efb5d82
use H800 (rebenchmarked after rebase): After resolving rebase conflicts, reran the benchmark on the H800.
jeejeelee left a comment:
LGTM, thank you for the contribution.
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Purpose
The most significant change lies in the handling of the not `EVEN_K` scenario. In the main branch implementation, each loop iteration in this case incurs additional masking operations and `tl.dot()` calls. However, my analysis reveals that masking is only necessary for partially out-of-bounds accesses. The other two scenarios are:

- fully in-bounds blocks, which can perform an unmasked load and `tl.dot()` operation;
- fully out-of-bounds blocks, which can skip the loads and `tl.dot()` operations entirely.

By reducing redundant masking and `tl.dot()` operations, this modification improves computational speed. The optimization is particularly effective for GPUs with less advanced architectures. Performance results on the Metax C500 and H800 GPUs are as follows:

use C500 (main vs. this PR)

use H800 (main vs. this PR)
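The three-way classification described above can be illustrated with a small host-side sketch. This is plain Python with hypothetical names, not the kernel itself — the real code operates on Triton blocks and pointers, but the case analysis over the `K` bound is the same:

```python
# For a loop over K in steps of BLOCK_K, classify each block against the
# bound K: only "partial" blocks need a load mask, "inside" blocks can use
# unmasked loads, and "outside" blocks can be skipped entirely.
# Function and category names are illustrative, not the kernel's.

def classify_blocks(K, BLOCK_K, num_blocks):
    """Return the masking category for each K-block a program visits."""
    cats = []
    for b in range(num_blocks):
        lo, hi = b * BLOCK_K, (b + 1) * BLOCK_K
        if hi <= K:
            cats.append("inside")    # unmasked load + tl.dot
        elif lo >= K:
            cats.append("outside")   # skip load and tl.dot
        else:
            cats.append("partial")   # masked load, zero-fill the tail
    return cats

# K = 100, BLOCK_K = 32: four blocks cover [0, 128).
cats = classify_blocks(K=100, BLOCK_K=32, num_blocks=4)
assert cats == ["inside", "inside", "inside", "partial"]
# At most one block per program straddles K, so the masking cost
# is paid once instead of on every iteration.
```

This is the intuition behind the speedup: the main branch paid the masking cost on every iteration of the not `EVEN_K` loop, while this PR pays it only for the single straddling block.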
Test plan
Test result
All test cases passed.
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.