
[Triton] optimized decode kernels for Qwen3-Next model #2423

Closed
hellozhuo-amd wants to merge 23 commits into main from zhuo/qwen3_triton_gdn

Conversation

@hellozhuo-amd
Contributor

@hellozhuo-amd hellozhuo-amd commented Mar 23, 2026

Motivation

On the Qwen3-Next decode path, vLLM runs several Triton-backed steps back-to-back (causal conv1d state update, QKV layout work, gated delta rule / linear attention). Bringing well-tested kernels into aiter improves reuse on ROCm and keeps a single place for Triton tuning and CI.

What this PR adds

Triton code follows the aiter split: @triton.jit kernels live in aiter/ops/triton/_triton_kernels/, while Python launchers and public APIs live in aiter/ops/triton/.

| Area | Launcher / API | Kernel location |
| --- | --- | --- |
| Gated delta rule (decode) | fused_rearrange_sigmoid_gated_delta_rule in aiter/ops/triton/gated_delta_net/ | _triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py |
| Causal conv1d "update" fast path | causal_conv1d_update_single_token, fused_reshape_causal_conv1d_update_single_token in aiter/ops/triton/causal_conv1d_update_single_token.py | _triton_kernels/causal_conv1d_update_single_token.py (uses PAD_SLOT_ID from _triton_kernels/causal_conv1d.py) |
| RMSNorm + gated + FP8 group quant | fused_rms_gated_fp8_group_quant, get_fp8_min_max_bounds, calc_rows_per_block in aiter/ops/triton/quant/fused_fp8_quant.py | _fused_rms_gated_fp8_group_quant_kernel in _triton_kernels/quant/fused_fp8_quant.py (colocated with other fused FP8 quant kernels) |

Exports are wired through aiter/ops/triton/gated_delta_net/__init__.py and aiter/ops/triton/quant/__init__.py.

About Gated Delta Rule

paper: https
technical blog: https
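
For readers unfamiliar with the recurrence, here is a minimal plain-Python sketch of one decode step of a gated delta rule update. It is not the Triton kernel from this PR: the function name, the argument layout, and the choice to pass the decay as a log-space gate `g` and the write strength as a pre-sigmoid logit are assumptions made for illustration, and q/k normalization and scaling are left out.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def gated_delta_rule_step(state, q, k, v, g, beta_logit):
    """One single-token step (illustrative reference, not the PR's kernel).

    state      : [d_k][d_v] nested list, updated in place
    q, k       : length-d_k lists
    v          : length-d_v list
    g          : log-space decay gate (assumed <= 0, so exp(g) is in (0, 1])
    beta_logit : pre-sigmoid write strength (assumption)
    Returns the length-d_v output for this token.
    """
    d_k, d_v = len(k), len(v)
    decay = math.exp(g)
    beta = sigmoid(beta_logit)

    # 1. Decay the recurrent state.
    for i in range(d_k):
        for j in range(d_v):
            state[i][j] *= decay

    # 2. What the decayed state currently predicts for this key.
    pred = [sum(state[i][j] * k[i] for i in range(d_k)) for j in range(d_v)]

    # 3. Delta rule: write only the gated prediction error.
    delta = [beta * (v[j] - pred[j]) for j in range(d_v)]
    for i in range(d_k):
        for j in range(d_v):
            state[i][j] += k[i] * delta[j]

    # 4. Read out with the query.
    return [sum(state[i][j] * q[i] for i in range(d_k)) for j in range(d_v)]


# Tiny usage example with hand-checkable numbers.
state = [[0.0, 0.0], [0.0, 0.0]]
out1 = gated_delta_rule_step(state, q=[1.0, 0.0], k=[1.0, 0.0],
                             v=[2.0, 4.0], g=0.0, beta_logit=0.0)
out2 = gated_delta_rule_step(state, q=[1.0, 0.0], k=[1.0, 0.0],
                             v=[2.0, 4.0], g=0.0, beta_logit=0.0)
print(out1, out2)
```

With no decay and a write strength of 0.5, each repeated write moves the stored value halfway toward `v`, which is the error-correcting behavior that distinguishes the delta rule from a plain additive linear-attention update.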

Tests

PyTorch reference tests under op_tests/triton_tests/:

  • test_fused_rearrange_sigmoid_gdr.py
  • test_causal_conv1d_update_single_token.py
  • quant/test_fused_rms_gated_fp8_group_quant.py

Test command

python3 -m pytest \
  op_tests/triton_tests/test_fused_rearrange_sigmoid_gdr.py \
  op_tests/triton_tests/test_causal_conv1d_update_single_token.py \
  op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py \
  -v

Effect on the vLLM Qwen3-Next model

Overall effect

Baseline: around 39 us
PR: around 15.4 us
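
To put the timings quoted in this description on a common scale, the short script below computes speedup ratios and absolute savings. The numbers are copied from this PR description (the overall figures are approximate, and the conv1d baseline is given as a range); the dictionary keys and the helper name are illustrative only.

```python
# Timings in microseconds, copied from this PR description.
# Each entry is (baseline_us, pr_us).
timings_us = {
    "overall": (39.0, 15.4),
    "gated delta rule": (6.218, 5.840),
    "causal conv1d update (baseline low end)": (20.0, 4.9),
    "causal conv1d update (baseline high end)": (24.0, 4.9),
    "RMSNorm + FP8 quant": (9.4, 4.5),
}


def speedup(baseline_us, pr_us):
    """Ratio of baseline time to PR time (higher is better)."""
    return baseline_us / pr_us


for name, (baseline_us, pr_us) in timings_us.items():
    saved_us = baseline_us - pr_us
    print(f"{name}: {speedup(baseline_us, pr_us):.2f}x, {saved_us:.2f} us saved")
```

On these numbers the conv1d fusion accounts for most of the saved time and the gated delta rule kernel for the least.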

Effect from fused_rearrange_sigmoid_gated_delta_rule

Baseline with fused_recurrent_gated_delta_rule_packed_decode_kernel: 6.218 us (averaged)

PR with fused_rearrange_sigmoid_gated_delta_rule: 5.840 us (averaged)

Effect from fused_reshape_causal_conv1d_update_single_token

Baseline: 4-5 kernels, 20-24 us

PR: fused into 1 kernel, 4.9 us (averaged)
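
As a rough picture of what a single-token causal conv1d "update" computes, here is a plain-Python reference sketch. It assumes the cached state holds the last `width - 1` inputs per channel (oldest first) and that an optional SiLU activation follows the convolution; the function name and argument layout are hypothetical and do not mirror the launcher signatures in this PR.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def conv1d_update_reference(x, conv_state, weight, bias=None, activation=True):
    """Single-token depthwise causal conv update (illustrative only).

    x          : [dim] new input, one value per channel
    conv_state : [dim][width - 1] previous inputs, oldest first; updated in place
    weight     : [dim][width] per-channel filter taps
    bias       : optional [dim]
    """
    out = []
    for c, x_c in enumerate(x):
        # Sliding window = cached history followed by the new token.
        window = conv_state[c] + [x_c]
        acc = bias[c] if bias is not None else 0.0
        acc += sum(w * v for w, v in zip(weight[c], window))
        out.append(silu(acc) if activation else acc)
        # Shift the state: drop the oldest input, keep the newest.
        conv_state[c][:] = window[1:]
    return out


# Two channels, filter width 4.
state = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
weight = [[1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.5, 0.5]]
y = conv1d_update_reference([1.0, 2.0], state, weight, activation=False)
print(y, state)
```

Because decode handles one token at a time, the convolution collapses to one dot product per channel plus a state shift, which is why it is a natural candidate for fusing with the surrounding reshape work.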

Effect from rmsnorm_input_quant_fp8

Baseline: 2 kernels, around 9.4 us

PR: fused into 1 kernel, 4.5 us (averaged)
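
The fused operation can be sketched as a plain-Python reference for one row. Several things here are assumptions rather than facts about this PR's kernel: that the gate is applied as `silu(z)` after normalization, that each group's scale is `amax / fp8_max`, and that `fp8_max` is 448.0 (the largest finite value of the e4m3fn format; other FP8 formats use a different bound). The sketch returns the scaled values before the actual cast to FP8, so rounding is not modeled.

```python
import math


def silu(x):
    return x / (1.0 + math.exp(-x))


def rms_gated_group_quant_reference(x, z, weight, group_size,
                                    eps=1e-6, fp8_max=448.0):
    """RMSNorm -> gate -> per-group scaling for one row (illustrative only).

    Returns (scaled_values, scales) where scaled_values[i] * scale_of_group(i)
    recovers the normalized, gated activation.
    """
    n = len(x)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    # Assumed ordering: normalize, apply weight, then gate with silu(z).
    y = [x[i] * inv_rms * weight[i] * silu(z[i]) for i in range(n)]

    scaled, scales = [], []
    for start in range(0, n, group_size):
        group = y[start:start + group_size]
        amax = max(abs(v) for v in group)
        scale = max(amax, 1e-10) / fp8_max  # guard against all-zero groups
        scales.append(scale)
        scaled.extend(min(max(v / scale, -fp8_max), fp8_max) for v in group)
    return scaled, scales


x = [1.0, -1.0, 2.0, -2.0]
z = [1.0, 2.0, -1.0, 0.5]
w = [1.0, 1.0, 1.0, 1.0]
scaled, scales = rms_gated_group_quant_reference(x, z, w, group_size=2)
print(len(scaled), len(scales))
```

Doing the normalization, gating, and scale computation in one pass avoids writing the intermediate activation back to memory between two separate kernels.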

Submission checklist

@hellozhuo-amd hellozhuo-amd requested review from a team and Copilot March 23, 2026 07:40
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2423 --add-label <label>

Comment thread aiter/ops/triton/fusions/fused_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/_triton_kernels/fusions/fused_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/_triton_kernels/fusions/fused_conv1d_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/fusions/fused_conv1d_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/fusions/fused_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/fusions/fused_conv1d_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/fusions/fused_conv1d_rearrange_recurrent.py Outdated
Comment thread aiter/ops/triton/_triton_kernels/fusions/fused_conv1d_rearrange_recurrent.py Outdated
@hellozhuo-amd hellozhuo-amd marked this pull request as draft March 23, 2026 07:54
@hellozhuo-amd hellozhuo-amd changed the title from "Zhuo/qwen3 triton gdn" to "Zhuo/qwen3 triton gdn: fused conv1d with recurrent gated delta rule" Mar 23, 2026
@hellozhuo-amd hellozhuo-amd changed the title from "Zhuo/qwen3 triton gdn: fused conv1d with recurrent gated delta rule" to "Zhuo/Performance enhancement for Qwen3-Next model with Triton kernels" Apr 10, 2026
@hellozhuo-amd hellozhuo-amd force-pushed the zhuo/qwen3_triton_gdn branch from 43bf452 to da98e37 Compare April 10, 2026 21:30
@hellozhuo-amd hellozhuo-amd force-pushed the zhuo/qwen3_triton_gdn branch from da98e37 to 3216bce Compare April 10, 2026 21:39
@ROCm ROCm deleted a comment from Copilot AI Apr 10, 2026
Remove unused variable in rmsnorm FP8 test ref. Apply Black to
kernels, launchers, tests, and gated_delta_rule decode __init__.

Made-with: Cursor
@hellozhuo-amd hellozhuo-amd self-assigned this Apr 11, 2026
@hellozhuo-amd hellozhuo-amd marked this pull request as ready for review April 13, 2026 11:12
@hellozhuo-amd hellozhuo-amd changed the title from "Zhuo/Performance enhancement for Qwen3-Next model with Triton kernels" to "[Triton] optimized decode kernels for Qwen3-Next model" Apr 14, 2026
juuso-oskari

This comment was marked as outdated.

@juuso-oskari juuso-oskari dismissed their stale review April 21, 2026 12:06

rereviewing

juuso-oskari
juuso-oskari previously approved these changes Apr 22, 2026
Contributor

@juuso-oskari juuso-oskari left a comment

LGTM

…group_quant

Colocate the gated RMSNorm + FP8 group quant path with the other fused FP8
ops. The Triton kernel is now _fused_rms_gated_fp8_group_quant_kernel in
_triton_kernels/quant/fused_fp8_quant.py; the Python entry point is
fused_rms_gated_fp8_group_quant in quant/fused_fp8_quant.py, with a docstring
that contrasts it with fused_rms_fp8_group_quant. Remove the old
rmsnorm_input_quant_fp8 module and rms_norm_input_quant_fp8 kernel file.
Re-export the new symbol and helpers (get_fp8_min_max_bounds,
calc_rows_per_block) from aiter.ops.triton.quant. Rename the test file to
test_fused_rms_gated_fp8_group_quant.py and update test.sh.

BREAKING CHANGE: rmsnorm_input_quant_fp8 is removed; use
fused_rms_gated_fp8_group_quant instead.

Made-with: Cursor
tpopp added a commit to tpopp/vllm that referenced this pull request Apr 23, 2026
Follow upstream aiter rename (ROCm/aiter#2423). The kernel moved from
aiter.ops.triton.quant.rmsnorm_input_quant_fp8 to
aiter.ops.triton.quant.fused_fp8_quant.fused_rms_gated_fp8_group_quant.
Update the vLLM custom op registration, impl, fake, getter, and fusion
pass references accordingly.

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
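
For downstream code that has to run against aiter builds from both before and after the rename, one possible approach is to resolve the function at import time. This is only a sketch: the helper name is hypothetical, the new module path and symbol are taken from the commit messages above, and the assumption that the old module exposed a function also named `rmsnorm_input_quant_fp8` is not confirmed by this thread.

```python
import importlib


def resolve_fused_rms_gated_fp8_group_quant():
    """Return the gated RMSNorm + FP8 group quant entry point, or None.

    Tries the post-rename location first, then the pre-rename one.
    The old attribute name below is an assumption.
    """
    candidates = [
        # New location, as stated in the commit messages above.
        ("aiter.ops.triton.quant.fused_fp8_quant",
         "fused_rms_gated_fp8_group_quant"),
        # Old location; the attribute name is assumed to match the module name.
        ("aiter.ops.triton.quant.rmsnorm_input_quant_fp8",
         "rmsnorm_input_quant_fp8"),
    ]
    for module_name, attr_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # aiter missing, or this layout does not exist
        fn = getattr(module, attr_name, None)
        if fn is not None:
            return fn
    return None


fused_quant = resolve_fused_rms_gated_fp8_group_quant()
if fused_quant is None:
    print("fused gated RMSNorm + FP8 quant kernel not available")
```

Since the old and new functions may not share a signature, a caller would still need to check the arguments for whichever one is found.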
@sunway513
Collaborator

This PR's content was bulk-merged via #3005 ([Silo] Bulk merge: kernel fixes and features, merged 2026-05-05 03:34 UTC). Please close this PR as superseded.

Tracking issue: ROCm/AI-Frameworks-Dashboard#141

sunway513 added a commit that referenced this pull request May 5, 2026
…st1 (#3005)

Squash-merged from main commit 2c855fb.

Includes 8 atomic Silo PRs (4 bug fixes + 3 features + 1):
Bug fixes:
- #2457 MoE dispatch fix for Quark W4A6 (MXFP4 weights with QuantType.No)
- #2464 CK MoE tuner cascading bugs
- #2547 ck_moe_stage1 split-K buffer overflow (memory safety)
- #2866 pa_mqa_logits OOB stores fix (memory safety)
Features:
- #2423 Triton optimized decode for Qwen3-Next (GDN, conv1d, fused FP8 quant)
- #2541 SplitK support for CK/CKTile Block-Scale GEMMs
- #2687 Allow preallocated MoE sorting buffer

Conflict resolutions (3 files):
- .github/workflows/aiter-test.yaml (3 blocks): took HEAD to preserve
  release/v0.1.13's prebuilt-image-extract CI flow. Theirs would replace
  it with main's inline  flow which has not
  been validated for this release branch.
- .github/workflows/vllm_benchmark.yaml (1 block): took HEAD for the same
  CI architecture preservation reason.
- aiter/ops/triton/gated_delta_net/__init__.py (1 block): took HEAD.
  Theirs would expose  and
  , but those functions exist only on main
  (added in a separate commit not part of this PR), so taking theirs
  would break import on this release branch with NameError. The new
  fused_rearrange_sigmoid_gdr.py wrapper that #3005 introduces is
  importable directly from its module path; not exporting it via
  __init__.py simply means library consumers must use the longer import
  path. Acceptable trade-off vs broken imports.

28 files changed, +5433/-47 (3 of original 31 dropped to HEAD per above).

Driver: vLLM 0.21 freeze 2026-05-08 — Silo customers need these kernel
fixes (especially #2547 / #2866 memory safety) on the AITER release
wheel, not nightly.

Verification gates added before tag:
- ATOM 5-model accuracy unchanged within +/- 0.005 vs v0.1.13-rc1
- New Qwen3-Next decode codepath smoke (GDN + causal_conv1d_single_token
  + fused_fp8_quant must JIT-compile and produce coherent output)
- Memory safety regression check on Kimi-K2.5-MXFP4 (exercises ck_moe
  stage1) and DeepSeek-V3.2 (exercises pa_mqa_logits)
- Perf delta sample on Kimi/MiniMax/DSv3.2 c=1 + c=64 vs rc1 baseline

(cherry picked from commit 2c855fb)

5 participants