[Feature][JIT Kernel] Fused TP QK norm For Minimax#20673
Conversation
Summary of ChangesHello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a significant performance optimization for tensor-parallel operations by implementing a JIT-compiled fused QK norm kernel. This new kernel, adapted from TensorRT-LLM, aims to enhance memory access patterns and leverage the custom all-reduce v2 for efficient distributed computation. The changes include the addition of new CUDA kernels, Python bindings, and dedicated benchmarks and tests, culminating in its integration into the MiniMaxM2 model architecture. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
|
Performance result (q_dim = 6144, k_dim = 1024, TP=4): H200
B200
|
There was a problem hiding this comment.
Code Review
This pull request introduces a significant performance optimization by adding a JIT-compiled fused kernel for tensor-parallel QK normalization, adapted from TensorRT-LLM. It also brings in a new, more flexible JIT-based custom all-reduce framework (v2) that the fused kernel leverages. The changes are extensive, including new C++ CUDA kernels, Python wrappers, comprehensive benchmarks, and correctness tests. The refactoring of the existing custom all-reduce infrastructure to support this new implementation is also well-done. I've found one critical issue regarding the mathematical correctness of the RMSNorm calculation within the new fused kernel, which I've detailed in a specific comment. Once that is addressed, this will be an excellent contribution.
5b1da52 to
f418327
Compare
|
/tag-and-rerun-ci |
7262bd2 to
7aca6f0
Compare
| self._world_size = get_tensor_model_parallel_world_size() | ||
| self._eps = q_norm.variance_epsilon | ||
| self._cpu_group = get_tp_group().cpu_group | ||
| use_fused_norm = get_bool_env_var("SGLANG_USE_FUSED_PARALLEL_QKNORM") |
There was a problem hiding this comment.
Should we add this environment variable to doc?
There was a problem hiding this comment.
Can this be a server arg instead of an env var?
74ca882 to
7a21611
Compare
|
/rerun-failed-ci try again try again |
f1baa10 to
a6d0557
Compare
Co-authored-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
3c08b19 to
d1efab1
Compare
Co-authored-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
|
cc @trevor-m |
Co-authored-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
Motivation
NVIDIA/TensorRT-LLM#12163
Adapted from trt-llm kernels. Special thanks to @jmydurant. We mainly optimize the memory access and reuse the custom all reduce v2 in SGLang.
Should be merged after #19880
Modifications
Accuracy Tests
Benchmarking and Profiling
Decode performance:
Before: 150 tps; After: 157 tps (both already using JIT custom all reduce)
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci