
[Feature][JIT Kernel] Fused TP QK norm For Minimax#20673

Merged
BBuf merged 15 commits into sgl-project:main from DarkSharpness:misc_qknorm_ar on Apr 13, 2026

Conversation

@DarkSharpness (Collaborator) commented Mar 16, 2026

Motivation

NVIDIA/TensorRT-LLM#12163

Adapted from the TensorRT-LLM kernels; special thanks to @jmydurant. We mainly optimize the memory-access pattern and reuse SGLang's custom all-reduce v2.

Should be merged after #19880
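For intuition: under tensor parallelism each rank holds only a shard of the dimension being normalized, so the RMS statistic must be combined across ranks, and that combine step is the all-reduce this kernel fuses with the normalization. A single-process, pure-Python sketch of the underlying math, assuming the norm's statistics span all TP shards (hypothetical names, not the kernel API):

```python
import math
import random

def tp_rmsnorm_reference(shards, eps=1e-6):
    """Simulate a tensor-parallel RMSNorm whose statistics span all shards.

    Each TP rank holds one shard of the normalized dimension; the partial
    sums of squares are combined (the all-reduce step) and every rank then
    normalizes its own shard with the global statistic. Names here are
    illustrative only.
    """
    partial_sq = [sum(v * v for v in s) for s in shards]  # per-rank partial
    total_sq = sum(partial_sq)                            # "all-reduce"
    dim = sum(len(s) for s in shards)
    inv_rms = 1.0 / math.sqrt(total_sq / dim + eps)
    return [[v * inv_rms for v in s] for s in shards]     # local normalize

# Compare against an unsharded RMSNorm over the full vector (TP=4).
random.seed(0)
x = [random.gauss(0, 1) for _ in range(1024)]
shards = [x[i:i + 256] for i in range(0, 1024, 256)]
fused = [v for s in tp_rmsnorm_reference(shards) for v in s]
inv = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + 1e-6)
ref = [v * inv for v in x]
assert all(abs(a - b) < 1e-9 for a, b in zip(fused, ref))
```

The actual kernel does the same reduction on-device, overlapping it with the custom all-reduce rather than issuing a separate communication call.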

Modifications

Accuracy Tests

Benchmarking and Profiling

```shell
python -m sglang.launch_server \
    --model-path MiniMaxAI/MiniMax-M2.5 \
    --tp-size 4 \
    --tool-call-parser minimax-m2 \
    --reasoning-parser minimax-append-think \
    --host 0.0.0.0 \
    --trust-remote-code

python -m sglang.test.send_one
```

Decode performance:

Before: 150 tps; after: 157 tps (both configurations already using the JIT custom all-reduce)

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance optimization for tensor-parallel operations by implementing a JIT-compiled fused QK norm kernel. This new kernel, adapted from TensorRT-LLM, aims to enhance memory access patterns and leverage the custom all-reduce v2 for efficient distributed computation. The changes include the addition of new CUDA kernels, Python bindings, and dedicated benchmarks and tests, culminating in its integration into the MiniMaxM2 model architecture.

Highlights

  • Fused QK Norm Kernel: Introduced a new JIT-compiled kernel for fused tensor-parallel Query-Key (QK) normalization, adapted from TensorRT-LLM, to optimize memory access and computation.
  • JIT Custom All-Reduce Integration: Implemented a new Python module (all_reduce.py) to expose the JIT custom all-reduce and fused QK norm functionalities, along with their C++ CUDA kernel definitions.
  • Benchmarking and Testing: Added comprehensive benchmarking and correctness test suites for both the JIT custom all-reduce and the new fused QK norm to validate performance and accuracy.
  • Refactored All-Reduce Logic: Refactored the existing custom all-reduce (v1) to utilize a new utility function for NVLink and P2P capability checks, and introduced an opt-in mechanism for the JIT-compiled v2 implementation.
  • Model Integration: Integrated the fused QK norm into the MiniMaxM2 model, replacing the naive implementation for improved performance in tensor-parallel setups.
  • Infrastructure Updates: Updated .clang-format rules and introduced new C++ CUDA headers for distributed communication primitives and FFI tensor utilities to support the new kernels.


Changelog
  • python/sglang/jit_kernel/.clang-format
    • Updated include category regex to be more general.
  • python/sglang/jit_kernel/all_reduce.py
    • Added Python module for JIT custom all-reduce and fused QK norm functionalities.
  • python/sglang/jit_kernel/benchmark/bench_custom_all_reduce.py
    • Added benchmark script for JIT custom all-reduce (v2) against NCCL and AOT custom all-reduce (v1).
  • python/sglang/jit_kernel/benchmark/bench_tp_qknorm.py
    • Added benchmark script for the new fused parallel QK norm.
  • python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_base.cuh
    • Added C++ header for the custom all-reduce base class and FFI registration.
  • python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh
    • Added C++ CUDA kernel for pull-based custom all-reduce.
  • python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh
    • Added C++ CUDA kernel for push-based custom all-reduce.
  • python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
    • Added C++ CUDA kernel for the fused parallel QK norm across heads.
  • python/sglang/jit_kernel/csrc/gemm/marlin/marlin.cuh
    • Removed an unused host::div_ceil import.
  • python/sglang/jit_kernel/include/sgl_kernel/distributed/common.cuh
    • Added C++ CUDA header defining distributed communication primitives like Semaphores and controllers.
  • python/sglang/jit_kernel/include/sgl_kernel/distributed/custom_all_reduce.cuh
    • Added C++ CUDA header for the custom all-reduce base class, IPC handle utilities, and reduction implementation.
  • python/sglang/jit_kernel/include/sgl_kernel/ffi.h
    • Added C++ header providing FFI tensor creation and manipulation utilities.
  • python/sglang/jit_kernel/include/sgl_kernel/utils.cuh
    • Added a div_ceil utility function for integer division.
  • python/sglang/jit_kernel/include/sgl_kernel/vec.cuh
    • Updated AlignedVector load and store methods to accept void* pointers for greater flexibility.
  • python/sglang/jit_kernel/include/sgl_kernel/warp.cuh
    • Updated the reduce_sum template to allow specifying the number of threads for warp-level reduction.
  • python/sglang/jit_kernel/tests/test_custom_all_reduce.py
    • Added correctness tests for the JIT custom all-reduce (v2) kernel.
  • python/sglang/jit_kernel/tests/test_tp_qknorm.py
    • Added correctness tests for the fused parallel QK norm.
  • python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
    • Refactored custom all-reduce (v1) initialization and added an opt-in flag for the JIT-compiled v2.
  • python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
    • Moved and consolidated NVLink and P2P capability check logic into a new utility function.
  • python/sglang/srt/distributed/device_communicators/custom_all_reduce_v2.py
    • Added new Python module implementing the JIT-compiled custom all-reduce v2.
  • python/sglang/srt/models/minimax_m2.py
    • Integrated the new fused QK norm into the MiniMaxM2 model's attention mechanism, replacing the previous naive implementation.
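Among the utility additions above, the `div_ceil` helper in `include/sgl_kernel/utils.cuh` is the standard ceiling-division idiom used to size thread grids (how many blocks of size `b` cover `a` elements). Expressed in Python for illustration:

```python
def div_ceil(a: int, b: int) -> int:
    """Ceiling integer division: the number of blocks of size b
    needed to cover a elements, as used when sizing CUDA grids."""
    return (a + b - 1) // b

# e.g. covering a q_dim of 6144 with 256-thread blocks:
assert div_ceil(6144, 256) == 24
assert div_ceil(6145, 256) == 25
```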

@DarkSharpness (Collaborator, Author) commented Mar 16, 2026

Performance result (q_dim = 6144, k_dim = 1024, TP=4):

H200:

| q_dim | k_dim | batch | fused_us | baseline_us |
|------:|------:|------:|---------:|------------:|
| 6144 | 1024 | 1 | 2.6 | 5.1 |
| 6144 | 1024 | 2 | 2.7 | 5.1 |
| 6144 | 1024 | 4 | 2.7 | 5.1 |
| 6144 | 1024 | 8 | 2.8 | 5.2 |
| 6144 | 1024 | 16 | 2.9 | 5.2 |
| 6144 | 1024 | 32 | 2.9 | 5.3 |
| 6144 | 1024 | 64 | 2.9 | 5.4 |
| 6144 | 1024 | 128 | 2.9 | 5.6 |
| 6144 | 1024 | 256 | 3.1 | 5.9 |
| 6144 | 1024 | 512 | 3.5 | 6.6 |
| 6144 | 1024 | 1024 | 4.3 | 7.6 |
| 6144 | 1024 | 2048 | 8.0 | 11.2 |
| 6144 | 1024 | 4096 | 15.0 | 16.0 |
| 6144 | 1024 | 8192 | 28.4 | 28.5 |
| 6144 | 1024 | 16384 | 52.0 | 52.7 |

B200:

| q_dim | k_dim | batch | fused_us | baseline_us |
|------:|------:|------:|---------:|------------:|
| 6144 | 1024 | 1 | 4.1 | 6.2 |
| 6144 | 1024 | 2 | 4.3 | 6.4 |
| 6144 | 1024 | 4 | 4.4 | 6.2 |
| 6144 | 1024 | 8 | 4.4 | 6.5 |
| 6144 | 1024 | 16 | 4.4 | 6.5 |
| 6144 | 1024 | 32 | 4.5 | 6.6 |
| 6144 | 1024 | 64 | 4.5 | 6.9 |
| 6144 | 1024 | 128 | 4.6 | 6.8 |
| 6144 | 1024 | 256 | 4.7 | 6.8 |
| 6144 | 1024 | 512 | 5.0 | 7.5 |
| 6144 | 1024 | 1024 | 6.4 | 8.5 |
| 6144 | 1024 | 2048 | 8.5 | 10.1 |
| 6144 | 1024 | 4096 | 14.5 | 14.6 |
| 6144 | 1024 | 8192 | 23.8 | 20.6 |
| 6144 | 1024 | 16384 | 42.0 | 35.1 |
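Latency numbers like those above are typically collected by warming up and then averaging many launches; the actual scripts live under python/sglang/jit_kernel/benchmark/ and use CUDA-event timing, but the warmup-then-average structure can be sketched on CPU (names here are illustrative, not the benchmark's API):

```python
import time

def bench_us(fn, warmup=10, iters=100):
    """Average wall-clock latency of fn in microseconds.

    The real benchmark scripts time GPU kernels with CUDA events and
    device synchronization; this CPU-only sketch just illustrates the
    warmup-then-average pattern behind a fused_us / baseline_us table.
    """
    for _ in range(warmup):          # warm caches / JIT before timing
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e6

us = bench_us(lambda: sum(range(1000)))
assert us > 0
```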

@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a significant performance optimization by adding a JIT-compiled fused kernel for tensor-parallel QK normalization, adapted from TensorRT-LLM. It also brings in a new, more flexible JIT-based custom all-reduce framework (v2) that the fused kernel leverages. The changes are extensive, including new C++ CUDA kernels, Python wrappers, comprehensive benchmarks, and correctness tests. The refactoring of the existing custom all-reduce infrastructure to support this new implementation is also well-done. I've found one critical issue regarding the mathematical correctness of the RMSNorm calculation within the new fused kernel, which I've detailed in a specific comment. Once that is addressed, this will be an excellent contribution.

Comment thread python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
Comment thread python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
DarkSharpness changed the title from "[Feature][JIT Kernel] Fused TP QK norm" to "[Feature][JIT Kernel] Fused TP QK norm For Minimax" on Mar 18, 2026
DarkSharpness marked this pull request as ready for review on March 20, 2026 10:27
DarkSharpness reopened this on Mar 20, 2026
@DarkSharpness (Collaborator, Author) commented

/tag-and-rerun-ci

DarkSharpness force-pushed the misc_qknorm_ar branch 2 times, most recently from 7262bd2 to 7aca6f0 on March 26, 2026 04:59
Comment thread python/sglang/jit_kernel/benchmark/bench_tp_qknorm.py Outdated
self._world_size = get_tensor_model_parallel_world_size()
self._eps = q_norm.variance_epsilon
self._cpu_group = get_tp_group().cpu_group
use_fused_norm = get_bool_env_var("SGLANG_USE_FUSED_PARALLEL_QKNORM")
A Collaborator commented:

Should we add this environment variable to the docs?

A Collaborator commented:

Can this be a server arg instead of an env var?
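The SGLANG_USE_FUSED_PARALLEL_QKNORM gate discussed in this thread follows the usual boolean-env-var pattern. A minimal stand-in for sglang's `get_bool_env_var` helper (a sketch; the real helper lives in sglang and may accept different truthy spellings):

```python
import os

def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Minimal stand-in for sglang's helper: treat common truthy
    strings as True, anything else as False."""
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

# Opt in to the fused kernel for this process:
os.environ["SGLANG_USE_FUSED_PARALLEL_QKNORM"] = "1"
assert get_bool_env_var("SGLANG_USE_FUSED_PARALLEL_QKNORM")
```

A server argument (as suggested above) would surface the same switch through `--help` and config files, at the cost of threading it through the server-args plumbing.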

Comment thread python/sglang/jit_kernel/csrc/distributed/tp_qknorm.cuh
Comment thread python/sglang/srt/models/minimax_m2.py Outdated
@DarkSharpness (Collaborator, Author) commented Apr 3, 2026

/rerun-failed-ci try again try again

DarkSharpness force-pushed the misc_qknorm_ar branch 2 times, most recently from f1baa10 to a6d0557 on April 7, 2026 04:48
@BBuf (Collaborator) commented Apr 13, 2026

BBuf merged commit 314d6ec into sgl-project:main on Apr 13, 2026
284 of 318 checks passed
pyc96 pushed a commit to pyc96/sglang that referenced this pull request Apr 14, 2026
Co-authored-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>
@nvpohanh (Collaborator) commented

cc @trevor-m

DarkSharpness deleted the misc_qknorm_ar branch on April 14, 2026 11:52
yhyang201 pushed a commit to yhyang201/sglang that referenced this pull request Apr 22, 2026
Co-authored-by: Mingyang Jiang <13463932+jmydurant@users.noreply.github.com>

4 participants