[Qwen3-Next] Add cutedsl decode/mtp kernel with transposed ssm_state and prefill gluon kernel for blackwell. #17981

Jon-WZQ wants to merge 14 commits into sgl-project:main
Conversation
Summary of Changes

Hello @Jon-WZQ, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces significant performance optimizations for the Qwen3-Next model, particularly for its decode and multi-token prediction phases on Blackwell hardware. By leveraging custom CuTeDSL kernels and an optimized memory layout for the SSM state, the changes aim to maximize hardware utilization and reduce inference latency. The modifications are designed to be seamlessly integrated and activated via configuration, ensuring both efficiency and compatibility.
Code Review
This pull request introduces new CuTeDSL kernels for Qwen3-Next models on Blackwell architecture, aiming to improve performance by using a transposed state layout for better memory access patterns. The changes are comprehensive, including the new kernel implementation, extensive tests for precision and performance, and integration into the model execution backend.
My review identified a couple of issues: a typo in the test reporting script that would cause a NameError, and an incorrect double transposition of the SSM state in the attention backend which would result in an incorrect memory layout for the new kernel. Both are high-severity issues that should be addressed. Otherwise, the new kernel implementation and its integration appear well-structured and follow good practices for performance-critical code.
```python
if self.use_cutedsl_transpose:
    recurrent_state = ssm_states.transpose(-2, -1)
else:
    recurrent_state = ssm_states
```
The ssm_states tensor is already prepared with the correct transposed memory layout in memory_pool.py when use_cutedsl_transpose is true. Transposing it again here with ssm_states.transpose(-2, -1) will revert it to a row-major layout, which is incorrect for the transposed kernel. This if/else block should be removed and recurrent_state should be directly assigned ssm_states.
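To illustrate the double-transpose concern: transposing a buffer whose storage is already column-major over the last two dims restores row-major strides, undoing the layout the transposed kernel expects. A minimal numpy sketch (illustrative shapes, not the actual SSM state dimensions):

```python
import numpy as np

# Pretend this is the SSM state already allocated with a transposed
# (column-major over the last two dims) layout, as the pool would prepare it.
state = np.zeros((4, 8), dtype=np.float32).T  # shape (8, 4), column-major storage

# Strides confirm the transposed layout: the last dim carries the large stride.
assert state.strides == (4, 32)

# Transposing again reverts to plain row-major - the bug flagged in the review.
reverted = state.T
assert reverted.strides == (32, 4)
```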
Suggested change:

```python
recurrent_state = ssm_states
```

```python
    else fused_sigmoid_gating_delta_rule_update
)
self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
rank0_log(f"CuTe DSL GDN decode enabled: use_cutedsl: {use_cutedsl}, use_cutedsl_transpose: {self.use_cutedsl_transpose}")
```
Should these two environment variables be checked for mutual exclusivity?
Yes, we have verified end-to-end precision across three configurations: the default kernel, SGLANG_USE_CUTEDSL_GDN_DECODE=1, and SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE=1.
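If mutual exclusivity were to be enforced rather than just verified, a guard could look like the sketch below. The two environment variable names come from this PR; the helper function itself is hypothetical, not code from the patch.

```python
import os

def check_gdn_decode_flags() -> None:
    """Illustrative guard: refuse to enable both CuTeDSL decode variants at once."""
    use_cutedsl = os.environ.get("SGLANG_USE_CUTEDSL_GDN_DECODE", "0") == "1"
    use_transpose = os.environ.get("SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE", "0") == "1"
    if use_cutedsl and use_transpose:
        raise ValueError(
            "SGLANG_USE_CUTEDSL_GDN_DECODE and "
            "SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE are mutually exclusive"
        )
```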
```python
    if use_cutedsl
    else fused_sigmoid_gating_delta_rule_update
)
self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
```
Suggested change:

```diff
-self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
+use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
```
We will use the env SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE in forward_extend and forward_decode.
```python
if a.dim() == 2:
    a = a.unsqueeze(0)
...
q_ = from_dlpack(q.detach(), assumed_align=16)
```
Should we add an address align check?
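Such a check could be a simple modulus test on the buffer's base address before calling `from_dlpack(..., assumed_align=16)`. A runnable sketch using numpy as a stand-in for the real tensors (the helper name is hypothetical):

```python
import numpy as np

def check_alignment(addr: int, align: int = 16) -> bool:
    """Illustrative helper: True if `addr` satisfies the requested alignment."""
    return addr % align == 0

# In the kernel wrapper this would guard from_dlpack(q.detach(), assumed_align=16):
# the base address of the buffer must actually be 16-byte aligned.
x = np.empty(64, dtype=np.float32)
if not check_alignment(x.ctypes.data, 16):
    raise ValueError("tensor base address violates assumed_align=16")
```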
Can any Qwen3-Next end-to-end accuracy results be reported?
/tag-and-rerun-ci
…formance. Co-authored-by: dusan <dusan.du1006@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Force-pushed 26281f3 to f3534f7 (compare)
/rerun-failed-ci
@BBuf Can you help take a look at the CI tests? We have checked all the CI outputs, and we think the remaining failures are unrelated to our patch.
```python
HV = mV.shape[2]
blocks_per_head = mK.shape[-1] // BV
...
assert T // CACHE_STEPS == B, "batch * CACHE_STEPS must be equal to T"
```
If cache_steps is not provided, it defaults to T, so this assert reduces to T // T == B, i.e. B must be 1. That means bs > 1 will trip the assert by default. Since this kernel is decode-only, it might be worth reflecting that in the function name, or at least adding a comment.
OK, we will add some comments.
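To make the defaulting behavior discussed above concrete, here is a commented sketch of the logic (function and parameter names are illustrative, not from the patch):

```python
from typing import Optional

def resolve_cache_steps(T: int, B: int, cache_steps: Optional[int] = None) -> int:
    """Decode-only helper: when cache_steps is omitted it defaults to T,
    so the shape check T // cache_steps == B only passes for B == 1."""
    if cache_steps is None:
        cache_steps = T  # default: whole length is one step -> only valid for B == 1
    assert T // cache_steps == B, "batch * cache_steps must equal T"
    return cache_steps

# Explicit cache_steps supports B > 1: T = B * cache_steps.
assert resolve_cache_steps(T=8, B=4, cache_steps=2) == 2
# Omitted cache_steps only passes when B == 1.
assert resolve_cache_steps(T=8, B=1, cache_steps=None) == 8
```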
```python
ga = cute.zipped_divide(ma, (1, 1, 1))
gdt_bias = mdt_bias
...
B = mQ.shape[0] if cu_seqlens is None else cu_seqlens.shape[0] - 1
```
[nit] In some places this file uses `if cu_seqlens is None`, while at L947 it uses `if cu_seqlens != None:`. Keeping the style unified would be good (`is None` is the idiomatic form).
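For reference, the varlen convention behind that line: a cumulative-sequence-length tensor of length B + 1 encodes B packed sequences, so the batch size is `cu_seqlens.shape[0] - 1`. A small numpy sketch with illustrative lengths:

```python
import numpy as np

# Three variable-length sequences of lengths 3, 4, and 5 packed into one dim.
cu_seqlens = np.array([0, 3, 7, 12], dtype=np.int32)

B = cu_seqlens.shape[0] - 1      # number of sequences in the packed batch
lengths = np.diff(cu_seqlens)    # per-sequence lengths recovered from the offsets

assert B == 3
assert lengths.tolist() == [3, 4, 5]
```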
- benchmark_gdn_transpose_vs_flashinfer.py: compares SGLang PR sgl-project#17981 (cutedsl transpose kernel) vs FlashInfer PR sgl-project#2498 (gdn_kernels)
  - T=1: sigmoid decode kernel vs gated_delta_rule (both bf16 state)
  - T>1: MTP kernel vs gated_delta_rule_mtp (both fp32 state)
- Correctness verified for T=1 and T>1 (g pre-computed from A_log/a/dt_bias)
- README_GDN_FLASHINFER_VS_SGLANG.md: benchmark results on B200

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
In the latest commit, we tuned y_threads to 4. Here are the latest results.

MTP Performance Test Results (BF16, Draft Tokens 3)
script: test_cutedsl_gdn.py
Besides, we also compared with #19150 (impressive work!).
script: benchmark_gdn_transpose_vs_flashinfer.py
… --linear-attn-backend. 2. Use ssm_state in shape [V, K]. 3. Rename kernel.
Motivation
For the Qwen3-Next model, we observed that the current prefill/decode kernels cannot fully utilize the hardware efficiency of Blackwell. Thus, we implemented Gluon-based prefill kernels and CuTeDSL-based decode kernels for better performance. Here is a blog post about this PR: https://zhuanlan.zhihu.com/p/2003887397411258684
Modifications
Accuracy Tests
We use `python python/sglang/jit_kernel/tests/test_cutedsl_gdn.py` for decode kernel accuracy and performance tests.

Decode Precision Test Results
MTP Decode Precision Test Results
Benchmarking and Profiling
Performance Test Results (decode mode)
head_q=head_k=16, head_v=32, dim=128
BF16 Results
FP32 Results
MTP Performance Test Results (Draft Tokens 3)
BF16
F32
Qwen3-Next-80B-A3B-Instruct-FP8 E2E test
```shell
SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE=1 python -m sglang.launch_server --model-path Qwen3-Next-80B-A3B-Instruct-FP8 --tp-size 2 --speculative-num-steps=2 --speculative-eagle-topk=1 --speculative-num-draft-tokens=3 --speculative-draft-model-path Qwen3-Next-80B-A3B-Instruct-FP8 --speculative-algorithm NEXTN
```

accuracy check
python benchmark/gsm8k/bench_sglang.py --num-questions 128
Accuracy: 0.961
Invalid: 0.000
Latency: 10.909 s
Output throughput: 1834.008 token/s
baseline vs cutedsl_gdn_transpose, In:Out=1:1500
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci