[jit_kernel] Add fused_qknorm_rope JIT kernel #19059
Conversation
Summary of Changes (Gemini Code Assist): This pull request integrates a highly optimized, self-contained JIT CUDA kernel that performs fused RMSNorm and Rotary Positional Embeddings on query and key tensors. It aims to improve the performance and efficiency of attention mechanisms by consolidating multiple operations into a single specialized kernel, while reducing reliance on heavier external dependencies.
Activity
Code Review
The pull request introduces a new JIT kernel for fused QK RMSNorm + RoPE operations, along with benchmarks and correctness tests. The implementation looks solid, and the tests cover various configurations, including partial rotary and YaRN scaling. The code is well-structured and follows good practices for CUDA kernel development and Python integration. The use of static_assert for compile-time checks and RuntimeCheck for runtime validations is good for robustness. The benchmark script provides a clear comparison between the JIT and AOT implementations.
I think the performance comparison doesn't make sense: the JIT path is only faster from the second run onward, so we should compare the statistics from that round.
/tag-and-rerun-ci
@Johnsonms do you know what the root cause of the regression is?
Force-pushed 3f14cd1 to f68da77
We also need a Python wrapper for this kernel, so we can use it to replace the original one.
Force-pushed 0920dc7 to f274334
Force-pushed f274334 to c3b4efe
Force-pushed 5795543 to 2f79445
Please fix lint.
Overall LGTM.
…en3_moe, and optimise kernel
- Add Python wrapper and integrate fused_qk_norm_rope into qwen3_moe
- Pass head_dim and is_neox as JIT_HEAD_DIM/JIT_INTERLEAVE defines, compiling one kernel instantiation instead of six to reduce JIT compile time
- Replace packed_as_uint with device::AlignedVector for cleaner aligned loads/stores
- Replace raw CUDA intrinsics with device::cast from sgl_kernel/type.cuh
- Replace the local warp_reduce_sum with device::warp::reduce_sum from sgl_kernel/warp.cuh
- Add bench_compiletime_qknorm_rope.py to measure JIT compile time before/after
- Use can_use_fused_qk_norm_rope() at model init in qwen3_moe to gate the fused kernel, so unsupported configs fall back gracefully in the Python layer instead of hitting the CUDA static_assert at compile time
- Reformat the CASES list in bench_compiletime_qknorm_rope for readability
…_ceil
- Remove the -DJIT_HEAD_DIM and -DJIT_INTERLEAVE compiler macros; dispatch to the correct template instantiation at runtime via if/else in the host function
- Replace the custom div_up with host::div_ceil from sgl_kernel/utils.h
- Compile a single module for all head_dim/interleave combinations
Force-pushed 5bd4ee8 to b683e5b
Force-pushed b683e5b to 4c8d356
Let's wait for the CI.
fused_qk_norm_rope is now a template function parameterised on HEAD_DIM and INTERLEAVE. Python passes the concrete types via cuda_wrappers (e.g. fused_qk_norm_rope<128, false>), so TVM-FFI exports exactly one instantiation per (head_dim, is_neox) config at compile time. This removes the #ifdef JIT_HEAD_DIM / JIT_INTERLEAVE macros, the runtime if-else dispatch chain, and the head_dim / is_neox parameters from the TVM-FFI call site. Also removes the outdated compile-time benchmark script and simplifies the test to use the full parameter range in CI.
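The selection described above, where the Python side picks exactly one C++ template instantiation per (head_dim, is_neox) configuration, can be sketched as follows. This is a hypothetical illustration: the actual symbol-naming convention used by cuda_wrappers and TVM-FFI is an assumption, not taken from the PR.

```python
# Hypothetical sketch: map a (head_dim, is_neox) config to the name of
# the single C++ template instantiation exported via TVM-FFI. The
# naming scheme here is assumed, not the PR's actual convention.

SUPPORTED_HEAD_DIMS = (64, 128, 256)  # the head dims tested in this PR

def instantiation_symbol(head_dim: int, is_neox: bool) -> str:
    """Build the template-instantiation name, e.g. fused_qk_norm_rope<128, false>."""
    if head_dim not in SUPPORTED_HEAD_DIMS:
        raise ValueError(f"unsupported head_dim {head_dim}")
    # NeoX layout rotates half-dim pairs, i.e. the non-interleaved variant.
    interleave = "false" if is_neox else "true"
    return f"fused_qk_norm_rope<{head_dim}, {interleave}>"
```

Gating on SUPPORTED_HEAD_DIMS mirrors the PR's approach of falling back gracefully in the Python layer rather than failing at CUDA compile time.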
cd3628c to
87fd214
Compare
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>



Motivation
Part of tracking issue #17865 — migrate sgl-kernel AOT kernels to the lightweight python/sglang/jit_kernel/ system.
This PR ports sgl-kernel/csrc/moe/fused_qknorm_rope_kernel.cu to a JIT kernel, matching the existing
sgl_kernel.fused_qk_norm_rope call signature so it can serve as a drop-in replacement.
Key Design Decisions
The original file has no CUTLASS or FlashInfer dependencies — only standard CUDA intrinsics (cuda_bf16.h,
__shfl_xor_sync, __sincosf, etc.). All utilities (packed_as_uint, warp_reduce_sum, div_up, compute_freq_yarn) are inlined
directly in the .cuh file without any heavyweight third-party headers.
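The inlined warp_reduce_sum mentioned above follows the standard XOR-butterfly pattern built on __shfl_xor_sync. A pure-Python simulation of a 32-lane warp (for illustration only; the real code runs per-thread on the GPU):

```python
# Pure-Python simulation of the XOR-butterfly warp reduction that
# warp_reduce_sum implements with __shfl_xor_sync on a 32-lane warp.
# Each list element stands for one lane's register value.

WARP_SIZE = 32

def warp_reduce_sum(lanes):
    """After log2(32) = 5 butterfly steps, every lane holds the sum of all 32 values."""
    assert len(lanes) == WARP_SIZE
    vals = list(lanes)
    offset = WARP_SIZE // 2
    while offset > 0:
        # __shfl_xor_sync(mask, v, offset): lane i reads lane i ^ offset.
        # The list comprehension reads the old values, mimicking the
        # simultaneous exchange across the warp.
        vals = [vals[i] + vals[i ^ offset] for i in range(WARP_SIZE)]
        offset //= 2
    return vals
```

Because every lane ends with the full sum, no separate broadcast step is needed after the reduction.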
One correctness fix is included: the original NeoX active_mask uses (1u << rotary_lanes) - 1, which is undefined behavior
in C++ when rotary_lanes == 32 (the common full-rotary case). The JIT kernel replaces this with 0xffffffffu >> (32 -
rotary_lanes).
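The two mask expressions agree for every valid lane count; the difference is only that the original form hits undefined behavior in C++ at the boundary. This can be checked in Python, where integer shifts are arbitrary-precision and the C++ UB cannot occur:

```python
# Demonstrates the active-mask fix: in C++, (1u << 32) is undefined
# behavior, so the original (1u << rotary_lanes) - 1 breaks for the
# full-rotary case rotary_lanes == 32. The replacement
# 0xffffffffu >> (32 - rotary_lanes) is well-defined for 1..32.
# Python integers are unbounded, so both forms can be compared safely.

def mask_original(rotary_lanes: int) -> int:
    return (1 << rotary_lanes) - 1          # UB in C++ at 32; fine in Python

def mask_fixed(rotary_lanes: int) -> int:
    return 0xFFFFFFFF >> (32 - rotary_lanes)
```

For rotary_lanes == 32 the fixed form shifts by zero and yields the full mask 0xFFFFFFFF, exactly the value the original expression was intended to produce.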
Modification
- python/sglang/jit_kernel/csrc/elementwise/fused_qknorm_rope.cuh: the fused kernel (RMSNorm → RoPE (interleave or NeoX, YaRN-aware) → store), covering head_dim ∈ {64, 128, 256} × interleave ∈ {true, false}
- python/sglang/jit_kernel/fused_qknorm_rope.py: Python wrapper
- python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py: correctness tests (default rotary_dim, optional AOT cross-validation)
- python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py: JIT vs AOT benchmark
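For reference, the computation the fused kernel performs per head vector is RMSNorm followed by RoPE. A minimal pure-Python sketch of the textbook form (the kernel's exact epsilon handling, dtypes, YaRN scaling, and interleaved variant are omitted; this is not the kernel source):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def rope_neox(x, pos, rotary_dim, theta=10000.0):
    """NeoX-style RoPE: rotate pairs (x[i], x[i + rotary_dim // 2])."""
    half = rotary_dim // 2
    out = list(x)
    for i in range(half):
        freq = pos * theta ** (-2.0 * i / rotary_dim)
        c, s = math.cos(freq), math.sin(freq)
        a, b = x[i], x[i + half]
        out[i] = a * c - b * s
        out[i + half] = a * s + b * c
    return out  # dims beyond rotary_dim pass through unchanged (partial rotary)
```

Since each pair is a 2-D rotation, RoPE preserves the per-pair (and hence total) vector norm, which is a handy invariant for correctness tests.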
Accuracy Tests
python -m pytest python/sglang/jit_kernel/tests/test_fused_qknorm_rope.py -v
34 passed in 2.67s
Correctness: 34/34 passed; JIT and AOT are bit-identical across all tested configs (max_err=0).
Benchmarking and Profiling
python python/sglang/jit_kernel/benchmark/bench_fused_qknorm_rope.py
Performance: JIT matches AOT within noise (<0.5%) across all 15 configurations (1–4096 tokens × head_dim 64/128/256). No meaningful difference.
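A review comment above notes that JIT timings are only meaningful from the second run onward, since the first call pays one-time compilation cost. A generic warmup-aware timing harness sketch (this is not the bench script's actual structure):

```python
import time

def bench(fn, warmup=10, iters=100):
    """Time fn, discarding warmup runs so one-time JIT compilation
    does not skew the statistics; returns the median timed run in seconds."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - t0)
    samples.sort()
    return samples[len(samples) // 2]
```

When timing CUDA kernels specifically, the timed region must also synchronize the device (e.g. torch.cuda.synchronize() before reading the clock), since kernel launches are asynchronous.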
Checklist
Review Process
/tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci