Fuse MLA RoPE + FP8 Quantization + KV Cache Write into Single CUDA Kernel #12503
base: main
Conversation
Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant performance enhancement by fusing three critical operations (MLA RoPE application, FP8 quantization, and KV-cache writing) into a single CUDA kernel. This fusion aims to minimize kernel launch overhead and global memory access, which are bottlenecks in the FP8 decode path. The changes include the implementation of the new fused kernel, its integration into the existing attention backend with conditional activation for modern GPU architectures, and comprehensive benchmarking and testing to validate both performance gains and accuracy preservation.

Highlights
Code Review
This pull request introduces a fused CUDA kernel for MLA RoPE, FP8 quantization, and KV cache writing, which is a significant performance optimization. The changes are well-structured, including a new standalone build for the kernel, benchmarks, and correctness tests. The Python-side logic correctly integrates the new kernel, with a fallback to the old path, ensuring safety and compatibility. The CUDA kernel itself is well-written with vectorization for performance and a scalar fallback. I have a couple of suggestions for improving code maintainability in the setup script and the CUDA kernel. Overall, this is a high-quality contribution.
| "../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu", | ||
| ], | ||
| include_dirs=[ | ||
| "../sgl-kernel/include", |
The sources and include_dirs paths for the CUDAExtension use relative paths with `..`. This makes the build script dependent on the current working directory from which setup.py is invoked. To make it more robust, you could construct absolute paths based on the location of the setup.py file itself.
| "../sgl-kernel/csrc/elementwise/mla_rope_fp8_kv_fused.cu", | |
| ], | |
| include_dirs=[ | |
| "../sgl-kernel/include", | |
| os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "csrc", "elementwise", "mla_rope_fp8_kv_fused.cu"), | |
| ], | |
| include_dirs=[ | |
| os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "sgl-kernel", "include"), |
return (uint32_t)a0 | ((uint32_t)a1 << 8) | ((uint32_t)a2 << 16) | ((uint32_t)a3 << 24);
}

__device__ inline void rope_rotate(float& xr, float& xi, float c, float s, bool) {
The bool parameter in the rope_rotate function signature is unnamed and unused within the function. It seems to be a remnant from a previous implementation. For better code clarity and maintainability, it's best to remove it from the function definition and all its call sites within this file.
__device__ inline void rope_rotate(float& xr, float& xi, float c, float s) {
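For context, the rotation itself is the standard RoPE complex rotation; below is a minimal sketch of what the shortened signature presumably computes, with illustrative variable names rather than the PR's actual code:

```cuda
// Illustrative only: rotate the (xr, xi) pair by the angle whose cosine/sine are (c, s),
// i.e. compute (xr + i*xi) * (c + i*s). The unused trailing bool is simply dropped.
__device__ inline void rope_rotate(float& xr, float& xi, float c, float s) {
  const float rotated_r = xr * c - xi * s;
  const float rotated_i = xr * s + xi * c;
  xr = rotated_r;
  xi = rotated_i;
}

// Call sites lose the trailing flag accordingly, e.g.:
//   rope_rotate(q_re, q_im, cos_val, sin_val);
```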
Motivation
Fuse MLA RoPE + FP8 quantization + KV-cache write into a single CUDA kernel to reduce kernel launches and global-memory traffic during decode. This primarily targets MLA’s FP8 path where we previously (1) applied RoPE, (2) quantized Q/K to FP8, and (3) wrote K into KV cache in separate steps.
Modifications
- Add fused CUDA kernel `mla_rope_fp8_kv_fused.cu` that:
  - applies RoPE to the rope parts,
  - quantizes {Q_nope, Q_rope, K_nope, K_rope} to FP8,
  - emits Q (packed) and writes K directly into the KV cache when cache locations are provided.
  (A minimal sketch of the fused per-token loop is included after this list.)
- Wire a fused path in `python/sglang/srt/layers/attention/trtllm_mla_backend.py` that calls the fused op when available.
- Add tests & microbench:
  - `test/srt/test_mla_fp8_fused.py` for correctness,
  - `benchmark/kernels/bench_flashmla_fused_kv.py` for microbenchmarking.
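For reviewers who want the shape of the fused work without reading the full kernel, here is a minimal per-token sketch. It is illustrative only: the parameter names, tensor layouts (packed Q output, interleaved cos/sin, KV-cache slot indexing), and scale convention are assumptions made for this sketch, not the actual signatures in `mla_rope_fp8_kv_fused.cu`.

```cuda
#include <cstdint>
#include <cuda_fp8.h>

// Illustrative sketch only: one block per token; each thread handles one output
// element, applies RoPE to the rope half, quantizes to FP8, and writes K straight
// into its KV-cache slot, so no separate RoPE / quantize / cache-write launches.
__global__ void mla_rope_fp8_kv_fused_sketch(
    const float* __restrict__ q_nope,     // [tokens, nope_dim]
    const float* __restrict__ q_rope,     // [tokens, rope_dim]
    const float* __restrict__ k_nope,     // [tokens, nope_dim]
    const float* __restrict__ k_rope,     // [tokens, rope_dim]
    const float* __restrict__ cos_sin,    // [tokens, rope_dim], interleaved (cos, sin) per pair
    const int64_t* __restrict__ kv_loc,   // cache slot per token; < 0 means "skip cache write"
    __nv_fp8_e4m3* __restrict__ q_out,    // packed [tokens, nope_dim + rope_dim]
    __nv_fp8_e4m3* __restrict__ kv_cache, // [slots, nope_dim + rope_dim]
    float q_scale, float k_scale,         // quantization scales (convention assumed)
    int nope_dim, int rope_dim) {
  const int token = blockIdx.x;
  const int out_dim = nope_dim + rope_dim;

  for (int d = threadIdx.x; d < out_dim; d += blockDim.x) {
    float qv, kv;
    if (d < nope_dim) {
      // Non-rope part: passes through untouched (only quantized below).
      qv = q_nope[token * nope_dim + d];
      kv = k_nope[token * nope_dim + d];
    } else {
      // Rope part: elements form (real, imag) pairs sharing one (cos, sin).
      const int r = d - nope_dim;
      const int pair = r / 2;
      const float c = cos_sin[token * rope_dim + 2 * pair];
      const float s = cos_sin[token * rope_dim + 2 * pair + 1];
      const float q_re = q_rope[token * rope_dim + 2 * pair];
      const float q_im = q_rope[token * rope_dim + 2 * pair + 1];
      const float k_re = k_rope[token * rope_dim + 2 * pair];
      const float k_im = k_rope[token * rope_dim + 2 * pair + 1];
      qv = (r % 2 == 0) ? (q_re * c - q_im * s) : (q_re * s + q_im * c);
      kv = (r % 2 == 0) ? (k_re * c - k_im * s) : (k_re * s + k_im * c);
    }
    // Quantize once; Q is emitted packed, K goes directly into the cache slot.
    q_out[token * out_dim + d] = __nv_fp8_e4m3(qv * q_scale);
    const int64_t slot = kv_loc[token];
    if (slot >= 0) {
      kv_cache[slot * out_dim + d] = __nv_fp8_e4m3(kv * k_scale);
    }
  }
}
```

The actual kernel additionally vectorizes these loads and stores and keeps a scalar fallback, as noted in the review above.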
Accuracy Tests
Benchmarking and Profiling
fused (3 µs):

main (4+ µs & intermediate ops):
Send_one:
main:
this PR:
Checklist