
[jit_kernel] Add fused_qknorm_rope JIT kernel #19059

Merged
BBuf merged 17 commits into sgl-project:main from Johnsonms:fused-qknorm-rope-jit
Mar 27, 2026

Conversation

@Johnsonms (Contributor) commented Feb 20, 2026

Motivation

Part of tracking issue #17865 — migrate sgl-kernel AOT kernels to the lightweight python/sglang/jit_kernel/ system.

This PR ports sgl-kernel/csrc/moe/fused_qknorm_rope_kernel.cu to a JIT kernel, matching the existing
sgl_kernel.fused_qk_norm_rope call signature so it can serve as a drop-in replacement.

Key Design Decisions

The original file has no CUTLASS or FlashInfer dependencies — only standard CUDA intrinsics (cuda_bf16.h,
__shfl_xor_sync, __sincosf, etc.). All utilities (packed_as_uint, warp_reduce_sum, div_up, compute_freq_yarn) are inlined
directly in the .cuh file without any heavyweight third-party headers.

One correctness fix is included: the original NeoX active_mask uses (1u << rotary_lanes) - 1, which is undefined behavior in C++ when rotary_lanes == 32 (the common full-rotary case). The JIT kernel replaces this with 0xffffffffu >> (32 - rotary_lanes).
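The replacement can be sanity-checked outside CUDA. A minimal Python sketch (Python integers are arbitrary-precision, so both mask forms are safe to evaluate here, unlike the left-shift form in C++):

```python
# Check that 0xffffffffu >> (32 - n) and (1u << n) - 1 agree for every
# lane count n in 1..32. In C++ the left-shift form is undefined at
# n == 32 (shift count equals the operand width); in Python both
# expressions are well defined, so the equivalence can be verified.

def safe_mask(n: int) -> int:
    """Low-n-bit mask via the UB-free form used by the JIT kernel."""
    return 0xFFFFFFFF >> (32 - n)

def shift_mask(n: int) -> int:
    """Low-n-bit mask via the original form, UB in C++ at n == 32."""
    return ((1 << n) - 1) & 0xFFFFFFFF

for n in range(1, 33):
    assert safe_mask(n) == shift_mask(n)

print(hex(safe_mask(32)))  # 0xffffffff: the common full-rotary case
```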

Modification

python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh

  • fusedQKNormRopeKernel<head_dim, interleave> — one warp per (token, head): packed bf16 loads → RMSNorm (warp-reduce) →
    RoPE (interleave or NeoX, YaRN-aware) → store
  • fused_qk_norm_rope host launcher: runtime validation via RuntimeCheck, dispatches to 6 instantiations: head_dim ∈ {64,
    128, 256} × interleave ∈ {true, false}
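As a rough illustration of the per-head pipeline the bullet above describes, here is a scalar pure-Python sketch of RMSNorm followed by NeoX-style RoPE (the function names, eps value, and weight handling are illustrative, not the kernel's actual code):

```python
import math

def rmsnorm(head, weight, eps=1e-6):
    """Per-head RMSNorm: x_i * w_i / sqrt(mean(x^2) + eps)."""
    rms = math.sqrt(sum(x * x for x in head) / len(head) + eps)
    return [x * w / rms for x, w in zip(head, weight)]

def rope_neox(head, pos, rotary_dim, base=10000.0):
    """NeoX-style RoPE: rotate pairs (i, i + rotary_dim // 2);
    lanes beyond rotary_dim pass through untouched (partial rotary)."""
    out = list(head)
    half = rotary_dim // 2
    for i in range(half):
        freq = base ** (-2.0 * i / rotary_dim)
        c, s = math.cos(pos * freq), math.sin(pos * freq)
        x1, x2 = head[i], head[i + half]
        out[i] = x1 * c - x2 * s
        out[i + half] = x1 * s + x2 * c
    return out

head_dim = 8
head = [0.5 * (i + 1) for i in range(head_dim)]
normed = rmsnorm(head, [1.0] * head_dim)
rotated = rope_neox(normed, pos=3, rotary_dim=head_dim)
```

The kernel does the same math per warp lane, with the mean-of-squares computed by a warp reduction instead of the serial sum above.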

python/sglang/jit_kernel/fused_qknorm_rope.py

  • cache_once/load_jit wrapper; single JIT module (runtime dispatch, no per-head-dim caching)
  • fused_qk_norm_rope_out registered as a custom op (mutates qkv in-place)
  • fused_qk_norm_rope(..., rotary_dim=None) public API; rotary_dim=None defaults to head_dim
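A minimal sketch of the rotary_dim defaulting described above (the placeholder kernel call and the validation rules are assumptions for illustration, not the real module's code):

```python
# Illustrative sketch of the public API's rotary_dim defaulting; the
# placeholder below stands in for the in-place JIT custom op, and the
# validation shown is an assumption, not the real module's checks.

def _fused_qk_norm_rope_out(qkv, rotary_dim):
    return qkv  # the real custom op mutates qkv in place on the GPU

def fused_qk_norm_rope(qkv, head_dim, rotary_dim=None):
    if rotary_dim is None:
        rotary_dim = head_dim  # rotary_dim=None defaults to head_dim
    if not (0 < rotary_dim <= head_dim) or rotary_dim % 2 != 0:
        raise ValueError(f"invalid rotary_dim={rotary_dim}")
    return _fused_qk_norm_rope_out(qkv, rotary_dim)

full = fused_qk_norm_rope([0.0] * 128, head_dim=128)
part = fused_qk_norm_rope([0.0] * 128, head_dim=128, rotary_dim=64)
```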

python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py

  • 34 correctness tests: vs PyTorch reference across head_dim × num_tokens × is_neox, partial rotary, YaRN scaling,
    default rotary_dim, optional AOT cross-validation

python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py

  • JIT vs AOT throughput comparison using triton.testing.perf_report, CI-friendly reduced ranges

Accuracy Tests

34 passed in 2.67s
Correctness: 34/34 passed ✓
JIT and AOT are bit-identical across all tested configs (max_err=0).

python -m pytest python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py -v

Benchmarking and Profiling

python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py


Performance: JIT matches AOT within noise (<0.5%) across all 15 configurations (1–4096 tokens × head_dim 64/128/256). No meaningful difference.

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist (bot)

Summary of Changes

Hello @Johnsonms, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a highly optimized, self-contained JIT CUDA kernel for performing fused RMSNorm and Rotary Positional Embeddings on query and key tensors. This enhancement aims to improve the performance and efficiency of attention mechanisms within the framework by consolidating multiple operations into a single, specialized kernel, while also reducing reliance on external, heavier dependencies.

Highlights

  • New JIT Kernel for Fused QK-Norm RoPE: A new JIT-compiled CUDA kernel, fused_qknorm_rope, has been introduced, porting the functionality previously found in sgl-kernel/csrc/moe/fused_qknorm_rope_kernel.cu.
  • Optimized Fused Operation: The kernel efficiently fuses per-head RMSNorm and Rotary Positional Embeddings (RoPE) for Query (Q) and Key (K) tensors in a single warp-level CUDA pass, operating in-place on a packed QKV tensor.
  • Reduced External Dependencies: The implementation is self-contained within a .cuh file, eliminating dependencies on external libraries like CUTLASS or FlashInfer.
  • Flexible RoPE Support: The kernel supports various configurations including head_dim values of 64, 128, and 256, both interleave (GPT-J) and NeoX RoPE styles, YaRN frequency scaling, and partial rotary embeddings (rotary_dim <= head_dim).
  • Bug Fix for NeoX RoPE: A potential undefined behavior in the NeoX active_mask computation when rotary_lanes == 32 has been addressed and fixed.


Changelog
  • python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py
    • Added a new benchmark script to evaluate the performance of the JIT fused_qk_norm_rope against its AOT sgl_kernel counterpart.
    • Included configurations for various num_tokens and head_dim values to cover typical LLM scenarios.
    • Implemented a quick correctness diff check between the JIT and AOT implementations.
  • python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh
    • Introduced a new CUDA header file containing the fusedQKNormRopeKernel for fused QK-Norm and RoPE.
    • Implemented warp-level sum reduction and YaRN-aware frequency computation.
    • Provided support for both interleave (GPT-J) and NeoX style RoPE, including a fix for NeoX active_mask UB.
    • Included a host-side TVM-FFI entry point for the kernel, with comprehensive runtime checks for input tensors and parameters.
  • python/sglang/jit_kernel/fused_qknorm_rope.py
    • Created a new Python module to wrap the fused_qk_norm_rope JIT CUDA kernel.
    • Defined fused_qk_norm_rope_out as a custom operation that directly calls the CUDA kernel.
    • Provided a user-friendly fused_qk_norm_rope function that handles default rotary_dim behavior and matches the sgl_kernel signature.
  • python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py
    • Added a new test suite for the fused_qknorm_rope JIT kernel.
    • Included correctness tests against a pure-PyTorch reference for various head_dim, num_tokens, and is_neox configurations.
    • Verified partial rotary functionality and YaRN scaling behavior.
    • Confirmed that rotary_dim=None correctly defaults to head_dim.
    • Implemented cross-validation tests against the AOT sgl_kernel implementation when available.

@gemini-code-assist (bot) left a comment

Code Review

The pull request introduces a new JIT kernel for fused QK RMSNorm + RoPE operations, along with benchmarks and correctness tests. The implementation looks solid, and the tests cover various configurations, including partial rotary and YaRN scaling. The code is well-structured and follows good practices for CUDA kernel development and Python integration. The use of static_assert for compile-time checks and RuntimeCheck for runtime validations is good for robustness. The benchmark script provides a clear comparison between the JIT and AOT implementations.

@yuan-luo (Collaborator)

I think the performance comparison doesn't make sense, because the JIT is only faster from the second run onward. We may need to compare statistics from that run.
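The usual way to keep first-run JIT compilation out of such a comparison is to warm up before timing. A generic sketch (not the benchmark script's actual code):

```python
import time

def bench(fn, warmup=10, iters=100):
    """Average fn's runtime after discarding warmup iterations, so
    one-time costs (JIT compilation, caches) don't skew the result."""
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

avg_s = bench(lambda: sum(range(1000)))
```

For CUDA kernels the timed region would additionally need device synchronization (or CUDA events); utilities like triton.testing.do_bench perform their own warmup and synchronization internally.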

@yuan-luo (Collaborator)

/tag-and-rerun-ci

@DarkSharpness (Collaborator)

@Johnsonms do you know what's the root cause of the regression?

@Johnsonms (Author) commented Feb 20, 2026

@Johnsonms do you know what's the root cause of the regression?

Fixed two issues; the regression is gone. Updated in the PR description. Thanks @yuan-luo @DarkSharpness:

  1. Issue: ~2x perf gap
     Root cause: AOT was compiled with --use_fast_math; JIT wasn't. powf() in compute_freq_yarn was the bottleneck.
     Fix: Added extra_cuda_cflags=["--use_fast_math"] to load_jit.

  2. Issue: False MISMATCH (head_dim=256, is_neox=False)
     Root cause: atol=1e-3 in calculate_diff() was tighter than bfloat16 precision (one ULP near 1.0 is ~7.8e-3). The JIT output was actually closer than AOT to the float32 reference.
     Fix: Raised atol from 1e-3 to 1e-2 in calculate_diff(). With --use_fast_math, both now produce bit-identical outputs (max_err=0).
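The tolerance reasoning can be checked directly: bfloat16 stores 7 mantissa bits, so the spacing of representable values just above 1.0 is 2**-7:

```python
# bfloat16: 1 sign bit, 8 exponent bits, 7 stored mantissa bits.
# The spacing of representable values in [1, 2) is therefore 2**-7,
# so near 1.0 an absolute tolerance of 1e-3 is below one ULP while
# 1e-2 covers it.

BF16_MANTISSA_BITS = 7
bf16_eps = 2.0 ** -BF16_MANTISSA_BITS

print(bf16_eps)  # 0.0078125, i.e. the ~7.8e-3 quoted above
```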

@Johnsonms force-pushed the fused-qknorm-rope-jit branch from 3f14cd1 to f68da77 on February 28, 2026 06:06
@yuan-luo self-requested a review on February 28, 2026 06:37
@yuan-luo (Collaborator)

We also need a python wrapper for this kernel, then we can utilize this kernel and replace the original one.

@Johnsonms force-pushed the fused-qknorm-rope-jit branch 6 times, most recently from 0920dc7 to f274334 on February 28, 2026 08:01
@Johnsonms force-pushed the fused-qknorm-rope-jit branch from f274334 to c3b4efe on February 28, 2026 08:06
@Johnsonms (Author)

We also need a python wrapper for this kernel, then we can utilize this kernel and replace the original one.

Done, checked and confirmed

export SGLANG_TORCH_PROFILER_DIR=/scratch/johnson/profile/
python -m sglang.launch_server \
    --model-path Qwen/Qwen3-30B-A3B \
    --enable-fused-qk-norm-rope \
    --tp-size 4

@Johnsonms force-pushed the fused-qknorm-rope-jit branch from 5795543 to 2f79445 on March 1, 2026 00:15
@Johnsonms requested a review from DarkSharpness on March 1, 2026 00:21
@yuan-luo (Collaborator) commented Mar 2, 2026

Please fix lint.

@yuan-luo (Collaborator) commented Mar 2, 2026

Overall LGTM.

…en3_moe, and optimise kernel

- Add Python wrapper and integrate fused_qk_norm_rope into qwen3_moe
- Pass head_dim and is_neox as JIT_HEAD_DIM/JIT_INTERLEAVE defines, compiling
  1 kernel instantiation instead of 6 to reduce JIT compile time
- Replace packed_as_uint with device::AlignedVector for cleaner aligned ld/st
- Replace raw CUDA intrinsics with device::cast from sgl_kernel/type.cuh
- Replace local warp_reduce_sum with device::warp::reduce_sum from sgl_kernel/warp.cuh
- Add bench_compiletime_qknorm_rope.py to measure JIT compile time before/after
- Use can_use_fused_qk_norm_rope() at model init in qwen3_moe to gate
  the fused kernel, so unsupported configs fall back gracefully in the
  Python layer instead of hitting the CUDA static_assert at compile time
- Reformat CASES list in bench_compiletime_qknorm_rope for readability
…_ceil

- Remove -DJIT_HEAD_DIM and -DJIT_INTERLEAVE compiler macros; dispatch to
  the correct template instantiation at runtime via if/else in the host function
- Replace custom div_up with host::div_ceil from sgl_kernel/utils.h
- Compile a single module for all head_dim/interleave combinations
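For reference, the ceiling-division helper these commits consolidate on computes ⌈a/b⌉, typically for grid sizing; a minimal sketch of the usual formulation (the real helper lives in sgl_kernel/utils.h):

```python
def div_ceil(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return (a + b - 1) // b

# Typical use: how many 32-lane warps cover 100 tokens?
print(div_ceil(100, 32))  # 4
```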
@Johnsonms force-pushed the fused-qknorm-rope-jit branch from 5bd4ee8 to b683e5b on March 7, 2026 01:04
@Johnsonms force-pushed the fused-qknorm-rope-jit branch from b683e5b to 4c8d356 on March 7, 2026 01:19
@yuan-luo (Collaborator) commented Mar 7, 2026

Let's wait for the CI.

fused_qk_norm_rope is now a template function parameterised on HEAD_DIM
and INTERLEAVE. Python passes the concrete types via cuda_wrappers
(e.g. fused_qk_norm_rope<128, false>), so TVM-FFI exports exactly one
instantiation per (head_dim, is_neox) config at compile time.

This removes the #ifdef JIT_HEAD_DIM / JIT_INTERLEAVE macros, the
runtime if-else dispatch chain, and the head_dim / is_neox parameters
from the TVM-FFI call site. Also removes the outdated compile-time
benchmark script and simplifies the test to use the full parameter
range in CI.
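On the Python side, selecting one compiled instantiation per config could look like the following sketch. The naming scheme is hypothetical, not the actual TVM-FFI export names:

```python
# Hypothetical mapping from a (head_dim, is_neox) config to the name
# of a single exported template instantiation, mirroring the
# one-instantiation-per-config design described above.

SUPPORTED_HEAD_DIMS = (64, 128, 256)

def instantiation_name(head_dim: int, is_neox: bool) -> str:
    if head_dim not in SUPPORTED_HEAD_DIMS:
        raise ValueError(f"unsupported head_dim={head_dim}")
    # NeoX style is the non-interleaved layout, so INTERLEAVE=false.
    interleave = "false" if is_neox else "true"
    return f"fused_qk_norm_rope_{head_dim}_{interleave}"

print(instantiation_name(128, is_neox=True))  # fused_qk_norm_rope_128_false
```

Gating on the supported set in Python keeps unsupported configs out of the CUDA static_assert path, matching the graceful-fallback approach from the earlier commit.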
@Johnsonms force-pushed the fused-qknorm-rope-jit branch from cd3628c to 87fd214 on March 10, 2026 03:56
@Johnsonms requested a review from DarkSharpness on March 10, 2026 04:03
@BBuf merged commit c531be4 into sgl-project:main on Mar 27, 2026
74 of 106 checks passed
satyamk7054 pushed a commit to satyamk7054/sglang that referenced this pull request Apr 3, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
JustinTong0323 pushed a commit to JustinTong0323/sglang that referenced this pull request Apr 7, 2026
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>