[sgl-kernel] support custom fp8 flashmla kernel #13087

Fridge003 merged 6 commits into sgl-project:main
Conversation
Summary of Changes

Hello @FlamingoPg, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces support for custom FP8 FlashMLA kernels within the 'sgl-kernel' project. The primary goal is to leverage 8-bit floating-point precision for attention, which can bring significant memory savings and potential performance improvements on compatible hardware. The changes integrate a specialized FlashMLA branch, extend the C++ and Python APIs to handle FP8 data types, and add unit tests to verify the correctness of the new FP8 kernel.
Code Review
This pull request adds support for a custom fp8 flashmla kernel by updating the FlashMLA dependency, adding new C++/CUDA source files, and creating the corresponding Python bindings. The changes look mostly correct, but I've identified a critical issue in the Python wrapper flash_mla_with_kvcache where new parameters are used without being added to the function signature, which will cause a runtime error. Additionally, there are a couple of issues in the tests: one existing test seems to be broken by the changes, and the new test for the fp8 kernel doesn't correctly exercise the new code path for metadata generation. I've also noted a minor copy-paste error in a C++ header comment that could cause confusion.
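For illustration, a minimal sketch of the fix this points at: any new fp8 argument referenced inside the wrapper body must also be declared in the function signature. The parameter name `is_fp8_kvcache` and the surrounding signature details are assumptions here, not the exact code in this PR.

```python
# Sketch only: the signature and the `is_fp8_kvcache` name are assumptions,
# not the actual sgl-kernel binding code.
from typing import Optional

import torch


def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,  # hypothetical flag: must be declared here,
                                   # not only referenced in the body below
):
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    # Referencing `is_fp8_kvcache` without the parameter above would raise a
    # NameError at call time -- the runtime error the review flags.
    if is_fp8_kvcache:
        ...  # dispatch to the fp8 kernel binding
    else:
        ...  # dispatch to the existing bf16/fp16 kernel binding
```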
```cpp
const at::Tensor& kcache,  // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
                           // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
```
This comment appears to be a copy-paste from the fwd_kvcache_mla function declaration. Since fwd_kvcache_mla_fp8 is specific to FP8, the conditional parts of the comment are confusing and unnecessary. Please simplify the comment to describe only the FP8 layout.
```diff
- const at::Tensor& kcache, // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or
-                           // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
+ const at::Tensor& kcache, // num_blocks x num_heads_k x (page_block_size*656)
```
LGTM
Hi, could I know why we're putting FlashMLA into the sglang kernel, and where I can see the related plans?
Hi @HanHan009527, we didn't put FlashMLA directly into sgl-kernel; you can see this in my stack PR. Our integration approach is the same as vLLM's. The key reason is that compiling this kernel ourselves helps us maintain a stable sgl-kernel wheel: FlashMLA still uses pybind, which brings torch/CUDA/Python version constraints during integration.
Do you have any concerns about this PR, or are there any other integration-related questions I can answer for you?
Got it, thanks.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
sgl-kernel fixed.
The custom fp8 flashmla kernel is only covered by the sgl-kernel test, so as long as that passes it should be OK.
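As a concrete illustration of that coverage, a hedged sketch of the usual test pattern: run the fp8 kernel and compare it against a higher-precision reference under a loose tolerance. The import path and the `is_fp8_kvcache` flag are assumptions, not the PR's actual test code.

```python
# Hypothetical accuracy check for the fp8 flashmla path; the import path and
# the fp8 flag name are assumptions, not the test added in this PR.
import torch
from sgl_kernel.flash_mla import flash_mla_with_kvcache  # assumed import path


def check_fp8_against_bf16(q, k_cache_fp8, k_cache_bf16, block_table,
                           cache_seqlens, head_dim_v, metadata, num_splits):
    out_fp8, _ = flash_mla_with_kvcache(
        q, k_cache_fp8, block_table, cache_seqlens, head_dim_v,
        metadata, num_splits, causal=True, is_fp8_kvcache=True,  # hypothetical flag
    )
    out_ref, _ = flash_mla_with_kvcache(
        q, k_cache_bf16, block_table, cache_seqlens, head_dim_v,
        metadata, num_splits, causal=True,
    )
    # fp8 quantization error warrants looser tolerances than a bf16-vs-bf16 check
    torch.testing.assert_close(out_fp8, out_ref, atol=1e-1, rtol=1e-2)
```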
Motivation
Add the custom fp8 flashmla kernel. It is already used in sglang, but sgl-kernel does not support it yet.
Stack PR: sgl-project/FlashMLA#1
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist