
[jit-kernel] Add CuTe DSL GDN Decode Kernel#15631

Merged
merrymercy merged 39 commits into sgl-project:main from liz-badada:cutedsl_gdn
Jan 18, 2026

Conversation

@liz-badada (Collaborator) commented Dec 22, 2025

Co-authors: @zhou9402, @HongliMi, @xutizhou

Motivation

Modifications

Add CuTe DSL GDN Decode Kernel

  • `hybrid_linear_attn_backend.py` - add a CuTe DSL decode branch (enabled via `SGLANG_USE_CUTEDSL_GDN_DECODE=1`)
  • `python/sglang/jit_kernel/cutedsl_gdn.py` - CuTe DSL kernel implementing GDN decode
  • `python/sglang/jit_kernel/tests/test_cutedsl_gdn.py` - precision and performance tests
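The environment-variable toggle described above amounts to a simple dispatch. A minimal sketch (the function and return names here are illustrative, not the actual `hybrid_linear_attn_backend.py` internals):

```python
import os

# Hypothetical sketch of the decode-kernel dispatch controlled by the
# SGLANG_USE_CUTEDSL_GDN_DECODE toggle; names are illustrative and do not
# mirror the real backend code.
def select_gdn_decode_kernel() -> str:
    """Pick the GDN decode kernel backend based on the environment toggle."""
    if os.environ.get("SGLANG_USE_CUTEDSL_GDN_DECODE") == "1":
        return "cutedsl"  # new CuTe DSL kernel
    return "triton"  # existing Triton reference kernel
```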

Requires nvidia-cutlass-dsl >= 4.3.0.
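The version floor could be enforced at import time with a check along these lines. The dotted-version comparison below is only a sketch; production code would typically use `packaging.version.parse` instead:

```python
# Illustrative check for the nvidia-cutlass-dsl >= 4.3.0 requirement.
# Compares dotted version strings component-wise as integers; this is a
# sketch, not how sglang actually validates the dependency.
def meets_min_version(installed: str, minimum: str = "4.3.0") -> bool:
    """Return True if `installed` satisfies the minimum dotted version."""
    as_tuple = lambda v: tuple(int(part) for part in v.split("."))
    return as_tuple(installed) >= as_tuple(minimum)
```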

Note: this is an SSM-state transpose-free version; a variant with better performance (which pre-transposes the SSM state just once) will follow from @HongliMi.

Accuracy Tests

pytest path/to/python/sglang/jit_kernel/tests/test_cutedsl_gdn.py -v -s

SGLANG_USE_CUTEDSL_GDN_DECODE=1 python -m sglang.bench_one_batch --model-path Qwen/Qwen3-Next-80B-A3B-Instruct \
    --tp 2 --batch 1 --input-len 128 --output-len 64 --cuda-graph-bs 1 --correctness-test

Benchmarking and Profiling

SGLANG_USE_CUTEDSL_GDN_DECODE=1 python3 -m sglang.bench_one_batch_server --model-path Qwen/Qwen3-Next-80B-A3B-Instruct \
    --batch-size 128 --input-len 128 1024 2048 4096 8192 --output-len 1024 --disable-radix-cache \
    --tp 2 --cuda-graph-bs 128 --show-report

E2E (H200, TP=2, Output lens: [1024], E2E output throughput speedup: 4.6% - 5.2%)

Triton

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 25.87 | 56064.6 | 5568.82 | n/a | 22.99 | 0.03 | 0.2 | n/a |
| 128 | 2048 | 26.74 | 56155.4 | 5938.23 | n/a | 21.56 | 0.03 | 0.19 | n/a |
| 128 | 4096 | 31.56 | 55581.9 | 5924.45 | n/a | 21.61 | 0.03 | 0.19 | n/a |
| 128 | 8192 | 42.45 | 53994 | 5691.37 | n/a | 22.49 | 0.03 | 0.2 | n/a |

CuTe DSL

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 24.85 | 55840.9 | 5825.61 | n/a | 21.97 | 0.03 | 0.19 | n/a |
| 128 | 2048 | 25.73 | 56161.4 | 6223.17 | n/a | 20.57 | 0.03 | 0.18 | n/a |
| 128 | 4096 | 30.54 | 55649.4 | 6206.95 | n/a | 20.62 | 0.03 | 0.18 | n/a |
| 128 | 8192 | 41.26 | 54103.8 | 5989.78 | n/a | 21.37 | 0.03 | 0.19 | n/a |
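The quoted 4.6% - 5.2% H200 speedup can be reproduced from the output-throughput columns of the two tables above:

```python
# Output-throughput speedup of CuTe DSL over Triton on H200 (TP=2),
# values taken from the tables above (input lens 1024/2048/4096/8192).
triton = [5568.82, 5938.23, 5924.45, 5691.37]
cutedsl = [5825.61, 6223.17, 6206.95, 5989.78]
speedups = [c / t - 1.0 for c, t in zip(cutedsl, triton)]
print([f"{s:.1%}" for s in speedups])  # spans roughly 4.6% to 5.2%
```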

E2E (B200, TP=2, Output lens: [1024], E2E output throughput speedup: 2.6% - 3.4%)

Triton

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 20.11 | 84113.8 | 7064.7 | n/a | 18.12 | 0.04 | 0.31 | n/a |
| 128 | 2048 | 20.99 | 83395.6 | 7344.68 | n/a | 17.43 | 0.04 | 0.3 | n/a |
| 128 | 4096 | 24.57 | 80808.5 | 7249.77 | n/a | 17.66 | 0.04 | 0.31 | n/a |
| 128 | 8192 | 33.29 | 76240.3 | 6709.32 | n/a | 19.08 | 0.04 | 0.33 | n/a |

CuTe DSL

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 19.51 | 83999.9 | 7301.8 | n/a | 17.53 | 0.04 | 0.3 | n/a |
| 128 | 2048 | 20.48 | 83453.1 | 7558.23 | n/a | 16.94 | 0.04 | 0.29 | n/a |
| 128 | 4096 | 24.1 | 81450.2 | 7420.56 | n/a | 17.25 | 0.04 | 0.3 | n/a |
| 128 | 8192 | 32.79 | 76249.7 | 6884.49 | n/a | 18.59 | 0.04 | 0.32 | n/a |

E2E (H20, TP=2, Output lens: [1024], E2E output throughput speedup: 1.7% - 2.5%)

Triton

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 35.74 | 22943.5 | 4365.37 | n/a | 29.32 | 0.07 | 0.25 | n/a |
| 128 | 2048 | 40.33 | 22939.3 | 4534.73 | n/a | 28.23 | 0.07 | 0.25 | n/a |
| 128 | 4096 | 52.35 | 22592.1 | 4497.54 | n/a | 28.46 | 0.07 | 0.25 | n/a |
| 128 | 8192 | 79.48 | 21679.9 | 4212.54 | n/a | 30.39 | 0.07 | 0.26 | n/a |

CuTe DSL

| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | cache hit rate |
|---|---|---|---|---|---|---|---|---|---|
| 128 | 1024 | 35.24 | 22992.2 | 4437.69 | n/a | 28.84 | 0.07 | 0.25 | n/a |
| 128 | 2048 | 39.62 | 22955.2 | 4647.63 | n/a | 27.54 | 0.07 | 0.24 | n/a |
| 128 | 4096 | 51.86 | 22568 | 4578.15 | n/a | 27.96 | 0.07 | 0.24 | n/a |
| 128 | 8192 | 78.83 | 21665.1 | 4307.04 | n/a | 29.72 | 0.07 | 0.26 | n/a |

Checklist

@github-actions bot added the dependencies (Pull requests that update a dependency file) and sgl-kernel labels Dec 22, 2025
@gemini-code-assist (Contributor)

Summary of Changes

Hello @liz-badada, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new, highly optimized CuTe DSL kernel for the Fused Sigmoid Gating Delta Rule (GDN) decode operation, providing an alternative to the existing Triton-based implementation. The primary goal is to enhance performance for GDN decode by leveraging advanced GPU programming features like CP.ASYNC and optimized memory access patterns. The new kernel is integrated with a toggle for easy activation and is thoroughly tested for both accuracy and speed against the current reference.

Highlights

  • New CuTe DSL GDN Decode Kernel: Introduced a high-performance CuTe DSL implementation for the Fused Sigmoid Gating Delta Rule (GDN) decode operation, designed for efficiency on NVIDIA GPUs.
  • Conditional Kernel Selection: The new CuTe DSL kernel is integrated into the hybrid_linear_attn_backend.py and can be enabled via the SGLANG_USE_CUTEDSL_GDN_DECODE=1 environment variable, allowing users to switch between the new CuTe DSL kernel and the existing Triton kernel.
  • Optimized Kernel Design: The CuTe DSL kernel leverages advanced GPU programming techniques such as CP.ASYNC for efficient global-to-shared memory transfers, bank-conflict-free access patterns with swizzle layouts, and automatic batch size selection (small vs. big batch kernels) for optimal performance.
  • Comprehensive Testing: New tests have been added to verify both the precision and performance of the CuTe DSL GDN kernel against the Triton reference implementation, covering different batch sizes (B=16 and B=128).
  • Dependency Update: The nvidia-cutlass-dsl dependency has been updated to version >=4.3.0 to support the new CuTe DSL kernel features.


@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new, high-performance CuTe DSL kernel for the GDN decode operation, which can be enabled via an environment variable. The changes are well-structured, including the kernel implementation, its integration into the existing attention backend with lazy loading, and a comprehensive test suite for precision and performance against the existing Triton kernel. My feedback focuses on improving code maintainability by reducing code duplication in the kernel selection logic and enhancing clarity and consistency in the test suite. Overall, this is a solid contribution.

@Fridge003 (Collaborator)

@hebiao064 Can you please take a look?

@liz-badada (Collaborator, Author)

/rerun-failed-ci

several similar comments

@liz-badada (Collaborator, Author)

/tag-and-rerun-ci

@BBuf (Collaborator) commented Jan 17, 2026

@liz-badada (Collaborator, Author)

/rerun-failed-ci

several similar comments

@merrymercy merrymercy merged commit e00b434 into sgl-project:main Jan 18, 2026
542 of 567 checks passed
DotSlash-A pushed a commit to DotSlash-A/sglang that referenced this pull request Jan 19, 2026
* fix(ci): recover from corrupted MMMU parquet cache (sgl-project#17256)

* [diffusion] feat: support default 4-step inference for Flux2-Klein distilled models (sgl-project#17225)

Signed-off-by: Lancer <maruixiang6688@gmail.com>

* Add runner utilization report workflow (sgl-project#17234)

* cli: support sglang version (sgl-project#17250)

* Use swa radix cache and memory pool for gpt-oss model (sgl-project#17261)

* [VLM][Reland] Refactor load_mm_data to improve performance (sgl-project#16152)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

* [Tiny] Improve docs (sgl-project#17264)

* [diffusion] fix: set guidance_scale default to None (sgl-project#17182)

* Tiny fix comment typo (sgl-project#17287)

* [SPEC_V2] Enable cudagraph draft_extend for trtllm_mla_backend and Acclen Fix for DP under cudagraph mode (sgl-project#16974)

* Add kl test for swa radix cache (sgl-project#17281)

* fix: Handle multiple named chat templates in HuggingFace tokenizers (sgl-project#17236)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>

* Move radix cache related tests (sgl-project#17295)

* [Refactor] Add `-fp4-gemm-backend` to replace `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` (sgl-project#16534)

Co-authored-by: Vincent Zhong <207368749+vincentzed@users.noreply.github.com>

* [Bugfix] Fix PD accuracy when MTP is not configured on the prefill node (sgl-project#17212)

Co-authored-by: Shangming Cai <csmthu@gmail.com>

* [Diffusion] Apply jit qk_norm to flux1 (sgl-project#17296)

* [Refactor] Split out deepseek v2 weight loader function into mixin (sgl-project#16649)

* [NPU]Support GPT-OSS for NPU (sgl-project#14197)

* [jit-kernel] Add CuTe DSL GDN Decode Kernel (sgl-project#15631)

Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>

* [GLM 4.7] Add RTX 6000 Pro aka sm120 (sgl-project#17235)

Co-authored-by: root <root@ubuntu-nvidia.localdomain>

* Update CODEOWNERS for multimodal_gen (sgl-project#17308)

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>

* [Feature] overlap LoRA weight loading with compute (sgl-project#15512)

* [PD] Optimize MHA models pp util calculation logic (sgl-project#17306)

* [Minor] Correct sglang version when installing from source (sgl-project#17315)

* Use dsv3 optimized routing `fused_topk_deepseek` instead of `moe_fused_gate` (sgl-project#15347)

* [DeepSeek v3.2] Opt MTP decode cuda batch sizes and nsa implementation (sgl-project#16961)

* Update code sync scripts (sgl-project#17319)

* [Auto Sync] Update tokenizer_manager.py (20260119) (sgl-project#17317)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* support new qwen3_coder_detector (sgl-project#16744)

Co-authored-by: liugaoji.lgj <liugaoji.lgj@alibaba-inc.com>

* Fix kernel selection in biased_grouped_topk_gpu (sgl-project#17325)

* KV Cache Events with Attention DP bug fix (sgl-project#16030) (sgl-project#16412)

* [Perf] fuse q, k norm for Flux2Attention (sgl-project#17241)

Co-authored-by: Minglei Zhu <zminglei@linkedin.com>

* [CI] Add partition to stage-b-test-large-1-gpu (11->12) (sgl-project#17245)

* fix(ci): rate limit and permission errors in trace publishing (sgl-project#17238)

* Revert "[Perf] fuse q, k norm for Flux2Attention (sgl-project#17241)" (sgl-project#17332)

* Migrate performance, accuracy, and quantization tests to CI registry (sgl-project#17177)

Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>

* Inclusion of nvfp4 blockscale in EPLB Rebalance (sgl-project#17158)

* [Refactor] Set `fp4-gemm-backend=auto` on SM100 and rename `fp4-gemm-backend` with `flashinfer_` prefix (sgl-project#17309)

* [Diffusion] Apply qknorm to flux2 and apply lightx2v rms_norm_one_pass kernel(without residual) (sgl-project#17305)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Fix v32 continue_final_message not work (sgl-project#16567)

* Evict swa kv cache during decoding (sgl-project#17220)

* [RadixTree][1/N Refactor]: Support unified match_prefix params (sgl-project#17142)

Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>

* [AMD CI] Migrate and Add More Testcases (sgl-project#17116)

Co-authored-by: yctseng0211 <yctseng@amd.com>

* [AMD] CI - add partitions for stage-b-test-small-1-gpu-amd (sgl-project#17345)

* Restore deepseek_v2.py to main's code, except the utils

* Ran `pre-commit`

---------

Signed-off-by: Lancer <maruixiang6688@gmail.com>
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Hudson Xing <1277646412@qq.com>
Co-authored-by: Lancer <402430575@qq.com>
Co-authored-by: Alison Shao <54658187+alisonshao@users.noreply.github.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
Co-authored-by: Yuan Luo <yuan.luo@hotmail.com>
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Mohammad Miadh Angkad <mangkad.bsdsba2027@aim.edu>
Co-authored-by: Changyi Yang <112288487+ChangyiYang@users.noreply.github.com>
Co-authored-by: YAMY <74099316+YAMY1234@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: b8zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Vincent Zhong <207368749+vincentzed@users.noreply.github.com>
Co-authored-by: Ch3ngY1 <91232537+Ch3ngY1@users.noreply.github.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Jerry Ji <jerryjilol@gmail.com>
Co-authored-by: Todobe <43903496+Todobe@users.noreply.github.com>
Co-authored-by: Jinyan Chen <93358689+liz-badada@users.noreply.github.com>
Co-authored-by: Jinyan Chen <jinyanc@nvidia.com>
Co-authored-by: Koushik Dutta <koush@koushikdutta.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Glen Liu <62917497+glenliu21@users.noreply.github.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: Lee Nau <lnau@nvidia.com>
Co-authored-by: Yongfei Xu <xuyongfei.xyf@antgroup.com>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Gaoji Liu <34803073+attack204@users.noreply.github.com>
Co-authored-by: liugaoji.lgj <liugaoji.lgj@alibaba-inc.com>
Co-authored-by: yudian0504 <138860534+yudian0504@users.noreply.github.com>
Co-authored-by: Kartik Ramesh <kartikx2000@gmail.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
Co-authored-by: Minglei Zhu <zminglei@linkedin.com>
Co-authored-by: Kangyan-Zhou <zky314343421@gmail.com>
Co-authored-by: Shu Wang <shuw@nvidia.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
Co-authored-by: zhangheng <hzh0425@apache.org>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Bingxu Chen <Bingxu.Chen@amd.com>
Co-authored-by: yctseng0211 <yctseng@amd.com>
@BBuf (Collaborator) commented Mar 12, 2026


Labels

dependencies (Pull requests that update a dependency file), run-ci, sgl-kernel

Projects

None yet


9 participants