
[Qwen3-Next] Add cutedsl decode/mtp kernel with transposed ssm_state and prefill gluon kernel for blackwell.#17981

Open
Jon-WZQ wants to merge 14 commits into sgl-project:main from Jon-WZQ:cutedsl_decode_transpose

Conversation


@Jon-WZQ Jon-WZQ commented Jan 30, 2026

Motivation

For the Qwen3-Next model, we observed that the current prefill/decode kernels cannot fully utilize the hardware efficiency of Blackwell. We therefore implement Gluon-based prefill kernels and CuTeDSL-based decode kernels for better performance. A blog post about this PR: https://zhuanlan.zhihu.com/p/2003887397411258684

Modifications

  1. For prefill, we rewrite the recompute_w_u_fwd_kernel/chunk_fwd_kernel_o kernels in Gluon with triton>=3.6.0. [Qwen3-Next] Optimize Prefill Kernel, add GDN Gluon kernel and optimize cumsum kernel #17983
  2. For decode, we use CuTeDSL to write the fuse_recurrent kernel for both decode and MTP. In addition, we transpose the shape of mamba_ssm_state from [B,H,K,V] to [B,H,V,K], which makes memory contiguous in the K dim. This is done only in the ssm_state initialization phase and the prefill kernel, so it introduces almost no additional overhead.
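The layout change in item 2 can be sketched in PyTorch (a minimal illustration with made-up shapes, not the PR's actual code):

```python
import torch

# Illustrative shapes: B=batch, H=heads, K=key dim, V=value dim.
B, H, K, V = 2, 16, 128, 128

# Default [B, H, K, V] layout: V is the unit-stride dimension,
# so walking along K jumps V elements per step.
state = torch.zeros(B, H, K, V)
assert state.stride(-1) == 1 and state.stride(-2) == V

# Materializing the transposed [B, H, V, K] layout makes K the
# unit-stride dimension, so a decode kernel that scans along K
# reads contiguous memory.
state_t = state.transpose(-2, -1).contiguous()
assert state_t.shape == (B, H, V, K)
assert state_t.stride(-1) == 1  # K is now contiguous
```

Because the transpose is materialized once (at state initialization and in the prefill kernel), the decode path pays no per-step transpose cost.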

Accuracy Tests

We run `python python/sglang/jit_kernel/tests/test_cutedsl_gdn.py` for the decode kernel accuracy and performance tests.

Decode Precision Test Results

| B(BS) | kernel | dtype | max_diff | mean_diff | fail_rate | check |
|------:|--------|-------|---------:|----------:|----------:|-------|
| 8 | cutedsl_gdn | f32 | 0.00e+00 | 1.00e-09 | 0.00% | ok |
| 8 | cutedsl_gdn_transpose | bf16 | 6.10e-05 | 1.87e-09 | 0.00% | ok |
| 8 | cutedsl_gdn_transpose | f32 | 4.88e-04 | 1.49e-08 | 0.00% | ok |
| 16 | cutedsl_gdn | f32 | 4.88e-04 | 1.26e-08 | 0.00% | ok |
| 16 | cutedsl_gdn_transpose | bf16 | 1.22e-04 | 1.93e-09 | 0.00% | ok |
| 16 | cutedsl_gdn_transpose | f32 | 2.44e-04 | 5.90e-09 | 0.00% | ok |
| 32 | cutedsl_gdn | f32 | 9.77e-04 | 9.43e-09 | 0.00% | ok |
| 32 | cutedsl_gdn_transpose | bf16 | 6.10e-05 | 1.68e-09 | 0.00% | ok |
| 32 | cutedsl_gdn_transpose | f32 | 9.77e-04 | 1.03e-08 | 0.00% | ok |
| 64 | cutedsl_gdn | f32 | 9.77e-04 | 9.51e-09 | 0.00% | ok |
| 64 | cutedsl_gdn_transpose | bf16 | 2.44e-04 | 3.82e-09 | 0.00% | ok |
| 64 | cutedsl_gdn_transpose | f32 | 4.88e-04 | 2.86e-09 | 0.00% | ok |
| 128 | cutedsl_gdn | f32 | 4.88e-04 | 5.82e-09 | 0.00% | ok |
| 128 | cutedsl_gdn_transpose | bf16 | 4.88e-04 | 4.48e-09 | 0.00% | ok |
| 128 | cutedsl_gdn_transpose | f32 | 2.44e-04 | 2.98e-09 | 0.00% | ok |
| 256 | cutedsl_gdn | f32 | 9.77e-04 | 4.75e-09 | 0.00% | ok |
| 256 | cutedsl_gdn_transpose | bf16 | 4.88e-04 | 5.01e-09 | 0.00% | ok |
| 256 | cutedsl_gdn_transpose | f32 | 9.77e-04 | 4.01e-09 | 0.00% | ok |

MTP Decode Precision Test Results

| T(BS) | max_diff | mean_diff | fail_rate | check |
|------:|---------:|----------:|----------:|-------|
| 4 | 3.81e-06 | 1.56e-10 | 0.00% | ok |
| 4 | 1.53e-05 | 6.14e-10 | 0.00% | ok |
| 8 | 4.88e-04 | 9.79e-09 | 0.00% | ok |
| 8 | 9.77e-04 | 1.84e-08 | 0.00% | ok |
| 16 | 4.88e-04 | 5.67e-09 | 0.00% | ok |
| 16 | 4.88e-04 | 7.04e-09 | 0.00% | ok |
| 32 | 1.95e-03 | 1.61e-08 | 0.00% | ok |
| 32 | 1.25e-01 | 3.24e-07 | 0.00% | ok |
| 48 | 9.77e-04 | 1.17e-08 | 0.00% | ok |
| 48 | 3.91e-03 | 1.34e-08 | 0.00% | ok |

Benchmarking and Profiling

Performance Test Results (decode mode)

head_q=head_k=16, head_v=32, dim=128

BF16 Results

| B(BS) | Triton (us) | CuTeDSL (us) | speedup | check |
|------:|------------:|-------------:|--------:|-------|
| 8 | 9.95 ± 0.03 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 8 | 9.95 ± 0.03 | 6.14 ± 0.02 | 1.62x | ok [cutedsl_gdn_transpose, bf16] |
| 16 | 16.92 ± 0.02 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 16 | 16.92 ± 0.02 | 10.23 ± 0.01 | 1.65x | ok [cutedsl_gdn_transpose, bf16] |
| 32 | 31.53 ± 0.02 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 32 | 31.53 ± 0.02 | 19.07 ± 0.01 | 1.65x | ok [cutedsl_gdn_transpose, bf16] |
| 64 | 60.11 ± 0.01 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 64 | 60.11 ± 0.01 | 35.63 ± 0.02 | 1.69x | ok [cutedsl_gdn_transpose, bf16] |
| 128 | 119.16 ± 0.02 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 128 | 119.16 ± 0.02 | 72.71 ± 0.02 | 1.64x | ok [cutedsl_gdn_transpose, bf16] |
| 256 | 234.96 ± 0.03 | 0.00 ± 0.00 | 0.00x | ok [cutedsl_gdn, bf16] |
| 256 | 234.96 ± 0.03 | 143.86 ± 0.04 | 1.63x | ok [cutedsl_gdn_transpose, bf16] |

FP32 Results

| B(BS) | Triton (us) | CuTeDSL (us) | speedup | check |
|------:|------------:|-------------:|--------:|-------|
| 8 | 9.63 ± 0.01 | 9.47 ± 0.01 | 1.02x | ok [cutedsl_gdn, f32] |
| 8 | 9.63 ± 0.01 | 6.54 ± 0.01 | 1.47x | ok [cutedsl_gdn_transpose, f32] |
| 16 | 16.95 ± 0.03 | 16.64 ± 0.02 | 1.02x | ok [cutedsl_gdn, f32] |
| 16 | 16.95 ± 0.03 | 11.03 ± 0.02 | 1.54x | ok [cutedsl_gdn_transpose, f32] |
| 32 | 30.97 ± 0.03 | 26.27 ± 0.02 | 1.18x | ok [cutedsl_gdn, f32] |
| 32 | 30.97 ± 0.03 | 19.95 ± 0.01 | 1.55x | ok [cutedsl_gdn_transpose, f32] |
| 64 | 65.06 ± 0.05 | 54.31 ± 0.04 | 1.20x | ok [cutedsl_gdn, f32] |
| 64 | 65.06 ± 0.05 | 47.35 ± 0.03 | 1.37x | ok [cutedsl_gdn_transpose, f32] |
| 128 | 125.14 ± 0.44 | 102.60 ± 0.44 | 1.22x | ok [cutedsl_gdn, f32] |
| 128 | 125.14 ± 0.44 | 90.43 ± 0.13 | 1.38x | ok [cutedsl_gdn_transpose, f32] |
| 256 | 248.06 ± 2.39 | 200.63 ± 2.21 | 1.24x | ok [cutedsl_gdn, f32] |
| 256 | 248.06 ± 2.39 | 178.92 ± 0.80 | 1.39x | ok [cutedsl_gdn_transpose, f32] |

MTP Performance Test Results (Draft Tokens 3)

BF16

| T(BS) | Triton (us) | CuTeDSL (us) | Speedup |
|------:|------------:|-------------:|--------:|
| 4 | 12.38 ± 0.07 | 9.59 ± 0.09 | 1.29x |
| 8 | 21.82 ± 0.10 | 15.79 ± 0.12 | 1.38x |
| 16 | 37.06 ± 0.10 | 24.69 ± 0.08 | 1.50x |
| 32 | 69.28 ± 0.10 | 44.49 ± 0.09 | 1.56x |
| 48 | 101.32 ± 0.10 | 64.72 ± 0.11 | 1.57x |

F32

| T(BS) | Triton (us) | CuTeDSL (us) | Speedup |
|------:|------------:|-------------:|--------:|
| 4 | 12.92 ± 0.09 | 11.48 ± 0.10 | 1.13x |
| 8 | 22.85 ± 0.13 | 19.57 ± 0.17 | 1.17x |
| 16 | 41.20 ± 0.14 | 33.09 ± 0.13 | 1.24x |
| 32 | 73.65 ± 0.17 | 61.37 ± 0.11 | 1.20x |
| 48 | 108.95 ± 0.16 | 95.19 ± 0.12 | 1.14x |

Qwen3-Next-80B-A3B-Instruct-FP8 E2E test

```shell
SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE=1 python -m sglang.launch_server --model-path Qwen3-Next-80B-A3B-Instruct-FP8 --tp-size 2 --speculative-num-steps=2 --speculative-eagle-topk=1 --speculative-num-draft-tokens=3 --speculative-draft-model-path Qwen3-Next-80B-A3B-Instruct-FP8 --speculative-algorithm NEXTN
```

Accuracy check:

```shell
python benchmark/gsm8k/bench_sglang.py --num-questions 128
```

```
Accuracy: 0.961
Invalid: 0.000
Latency: 10.909 s
Output throughput: 1834.008 token/s
```

baseline vs cutedsl_gdn_transpose, In:Out=1:1500

[two screenshots: baseline vs cutedsl_gdn_transpose throughput comparison]

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Jon-WZQ, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant performance optimizations for the Qwen3-Next model, particularly for its decode and multi-token prediction phases on Blackwell hardware. By leveraging custom CuTeDSL kernels and an optimized memory layout for the SSM state, the changes aim to maximize hardware utilization and reduce inference latency. The modifications are designed to be seamlessly integrated and activated via configuration, ensuring both efficiency and compatibility.

Highlights

  • New CuTeDSL Kernels for Qwen3-Next: Implementation of new CuTeDSL-based kernels for Qwen3-Next model's decode and multi-token prediction (MTP) operations, specifically optimized for Blackwell GPUs.
  • Transposed SSM State Memory Layout: Introduction of a transposed mamba_ssm_state memory layout (from [B,H,K,V] to [B,H,V,K]) to enhance memory contiguity and access efficiency in the K dimension.
  • Performance Improvements: Demonstrated significant performance speedups (up to 1.76x for decode and 1.61x for MTP) while maintaining numerical accuracy, as evidenced by comprehensive precision and performance tests.
  • Conditional Kernel Integration: Integration of these new kernels into the HybridLinearAttentionBackend and MemoryPool components, with conditional activation via environment variables for flexible deployment.


Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces new CuTeDSL kernels for Qwen3-Next models on Blackwell architecture, aiming to improve performance by using a transposed state layout for better memory access patterns. The changes are comprehensive, including the new kernel implementation, extensive tests for precision and performance, and integration into the model execution backend.

My review identified a couple of issues: a typo in the test reporting script that would cause a NameError, and an incorrect double transposition of the SSM state in the attention backend which would result in an incorrect memory layout for the new kernel. Both are high-severity issues that should be addressed. Otherwise, the new kernel implementation and its integration appear well-structured and follow good practices for performance-critical code.

Comment on lines +994 to +997

```python
if self.use_cutedsl_transpose:
    recurrent_state = ssm_states.transpose(-2, -1)
else:
    recurrent_state = ssm_states
```
Contributor


Severity: high

The ssm_states tensor is already prepared with the correct transposed memory layout in memory_pool.py when use_cutedsl_transpose is true. Transposing it again here with ssm_states.transpose(-2, -1) will revert it to a row-major layout, which is incorrect for the transposed kernel. This if/else block should be removed and recurrent_state should be directly assigned ssm_states.

```python
recurrent_state = ssm_states
```
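The double-transpose concern can be seen with a standalone tensor check (hypothetical shapes, not the PR's code):

```python
import torch

# Suppose the memory pool already stores the state materialized in
# [B, H, V, K] layout, i.e. K is the unit-stride dimension.
pool_state = torch.zeros(2, 8, 64, 128)  # stands in for [B, H, V, K]
assert pool_state.stride(-1) == 1

# Transposing it again yields a [B, H, K, V]-shaped view whose last
# dimension is strided, reverting the layout the transposed kernel expects.
double_t = pool_state.transpose(-2, -1)
assert double_t.stride(-1) == 128  # no longer unit-stride
```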

```python
    else fused_sigmoid_gating_delta_rule_update
)
self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
rank0_log(f"CuTe DSL GDN decode enabled: use_cutedsl: {use_cutedsl}, use_cutedsl_transpose: {self.use_cutedsl_transpose}")
```
Collaborator


Should these two environment variables be checked for mutual exclusivity?

Author


Yes, we have verified end-to-end precision across three configurations: the default kernel, SGLANG_USE_CUTEDSL_GDN_DECODE=1, and SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE=1.

```python
    if use_cutedsl
    else fused_sigmoid_gating_delta_rule_update
)
self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
```
Collaborator


Suggested change:

```diff
-self.use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
+use_cutedsl_transpose = Envs.SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE.get()
```

Author


We will use the env SGLANG_USE_CUTEDSL_GDN_DECODE_TRANSPOSE in forward_extend and forward_decode.

```python
if a.dim() == 2:
    a = a.unsqueeze(0)

q_ = from_dlpack(q.detach(), assumed_align=16)
```
Collaborator


Should we add an address alignment check?
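One possible shape of such a guard (a sketch based on my reading that `assumed_align=16` requires a 16-byte-aligned base pointer; the helper name is hypothetical):

```python
import torch

def assert_aligned(t: torch.Tensor, align: int = 16) -> None:
    """Raise early if the tensor's base address is not `align`-byte aligned."""
    addr = t.data_ptr()
    if addr % align != 0:
        raise ValueError(f"base address {addr:#x} is not {align}-byte aligned")

q = torch.zeros(4, 16, 128)
assert_aligned(q)  # fresh allocator blocks are normally well aligned

# A narrow slice can break alignment even though the parent is aligned:
half = torch.zeros(8, dtype=torch.float16)
try:
    assert_aligned(half[1:])  # base address shifted by 2 bytes
except ValueError:
    pass  # caught the misaligned view before handing it to the kernel
```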

Collaborator

BBuf commented Feb 4, 2026

Any Qwen3-Next end2end acc can be reported?

Author

Jon-WZQ commented Feb 4, 2026

> Any Qwen3-Next end2end acc can be reported?

OK, we have added an accuracy check on the gsm8k benchmark.

Collaborator

BBuf commented Feb 4, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Feb 4, 2026
Jon-WZQ and others added 3 commits February 4, 2026 10:50
…formance.

Co-authored-by: dusan <dusan.du1006@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@Jon-WZQ Jon-WZQ force-pushed the cutedsl_decode_transpose branch from 26281f3 to f3534f7 Compare February 4, 2026 10:53
Collaborator

BBuf commented Feb 6, 2026

/rerun-failed-ci

Author

Jon-WZQ commented Feb 10, 2026

@BBuf Can you help take a look at the CI tests? We have checked all the CI outputs and believe the remaining failures are unrelated to our patch.

```python
HV = mV.shape[2]
blocks_per_head = mK.shape[-1] // BV

assert T // CACHE_STEPS == B, "batch * CACHE_STEPS must be equal to T"
```
Collaborator

@yuan-luo yuan-luo Feb 14, 2026


If cache_steps is not provided, it defaults to T, so this assert reduces to T // T == B, i.e. B must be 1. That means the assert fires by default for bs > 1. Since this kernel is decode-only, that should be reflected in the function name, or a comment should be added.

Author

@Jon-WZQ Jon-WZQ Feb 26, 2026


OK, we will add some comments.

```python
ga = cute.zipped_divide(ma, (1, 1, 1))
gdt_bias = mdt_bias

B = mQ.shape[0] if cu_seqlens is None else cu_seqlens.shape[0] - 1
```
Collaborator

@yuan-luo yuan-luo Feb 14, 2026


[nit] Some places use `if cu_seqlens is None`, while L947 in the same file uses `if cu_seqlens != None:`. Unifying the style (preferring `is None`/`is not None`) would be better.
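Beyond consistency, `is None` is also the safer idiom, since `!=` dispatches to `__ne__` and can be overridden; a minimal illustration (the class here is a deliberately pathological example, not from the PR):

```python
class Weird:
    def __ne__(self, other):
        return True  # pathological override hijacks `!= None`

w = Weird()
assert (w != None) is True   # __ne__ decides the answer, not identity
assert (w is None) is False  # identity check is immune to overrides
```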

kaixih added a commit to kaixih/sglang that referenced this pull request Feb 20, 2026
- benchmark_gdn_transpose_vs_flashinfer.py: compares SGLang PR sgl-project#17981
  (cutedsl transpose kernel) vs FlashInfer PR sgl-project#2498 (gdn_kernels)
  - T=1: sigmoid decode kernel vs gated_delta_rule (both bf16 state)
  - T>1: MTP kernel vs gated_delta_rule_mtp (both fp32 state)
  - Correctness verified for T=1 and T>1 (g pre-computed from A_log/a/dt_bias)
- README_GDN_FLASHINFER_VS_SGLANG.md: benchmark results on B200

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author

Jon-WZQ commented Feb 26, 2026

In the latest commit, we tune y_threads to 4. Here are the latest results.

MTP Performance Test Results (BF16, Draft Tokens 3)

script: test_cutedsl_gdn.py

| T(BS) | Triton (us) | CuTeDSL (us) | Speedup |
|------:|------------:|-------------:|--------:|
| 4 | 12.27 ± 0.09 | 9.52 ± 0.07 | 1.29x |
| 8 | 21.63 ± 0.10 | 15.58 ± 0.12 | 1.39x |
| 16 | 36.87 ± 0.12 | 22.24 ± 0.08 | 1.66x |
| 32 | 69.00 ± 0.10 | 38.74 ± 0.10 | 1.78x |
| 48 | 101.42 ± 0.12 | 56.28 ± 0.11 | 1.80x |

We also compared against #19150 (impressive work!).

script: benchmark_gdn_transpose_vs_flashinfer.py

| BS | CuTeDSL (us) | FI (us) #19150 | Speedup |
|---:|-------------:|---------------:|--------:|
| 1 | 6.12 | 7.96 | 1.3x |
| 32 | 98.99 | 114.33 | 1.16x |
| 64 | 189.08 | 209.70 | 1.10x |
| 128 | 375.89 | 408.63 | 1.08x |

cc @hlu1 @yuan-luo @kaixih

… --linear-attn-backend. 2. Use ssm_state in shape [V, K]. 3. Rename kernel.