[Aiter][ROCm] gdn_linear_attn kernel fusion #40711
vllm-bot merged 9 commits into vllm-project:main
Conversation
Code Review
This pull request implements a fast path for Gated Delta Net (GDN) attention using Triton kernels on ROCm, optimizing the decode path. It updates the FLA (Fast Linear Attention) operations to support in-place output computation and introduces a new _forward_core_decode_fast method in the GDN linear attention layer. Feedback highlighted critical issues in the fast-path implementation, specifically an out-of-bounds access when indexing state tensors and the need to correctly index the gating tensors a and b to match the selected tokens during speculative decoding.
Force-pushed 6ba76fe to 9a24a86
    else:
        core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

    def _forward_core_decode_fast(
Why does this decode function have prefill logic?
There wasn't a reason for it. The prefill logic has been removed.
Resolved in commit "Move the prefill and spec part outside _forward_core_decode_fast": the fast decode path is now strictly decode-only; prefill and spec-decode are dispatched in the parent forward_* methods above it.
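The resulting dispatch can be sketched as a pure predicate (a hedged sketch with hypothetical names; in the PR the check lives inline in the parent forward_* methods):

```python
def use_decode_fast_path(num_prefill_tokens: int,
                         num_spec_tokens: int,
                         kernels_available: bool) -> bool:
    # The fast path is strictly decode-only: any prefill or speculative
    # tokens must be handled by the standard path in the caller.
    return (kernels_available
            and num_prefill_tokens == 0
            and num_spec_tokens == 0)
```

Keeping the predicate in the caller means the CUDA-graphed decode region never contains prefill branches.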
    query, key, value = torch.split(
        mixed_qkv,

    @torch.compile(fullgraph=True)
    def prepare_gdn_attention_core_inputs(
Could you explain why we need to change this so much?
This was to fuse more Triton kernels together. PyTorch's compiler declines to fuse functions that split and rearrange tensors and then return multiple values with different indexing dimensions. Flattening the tensors, concatenating them, and then operating on slices convinces the compiler to create a single kernel for this computation and its consumers, rather than three kernels, without requiring more custom Triton kernels.
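A minimal sketch of the flatten-concatenate-slice idea described above (hypothetical names and shapes, not the PR's actual helper): instead of returning several strided views, everything is copied once into one contiguous buffer and the consumers index slices of it, a layout torch.compile can fuse into a single kernel.

```python
import torch

def prepare_core_inputs_fused(query: torch.Tensor,
                              key: torch.Tensor,
                              value: torch.Tensor):
    # Flatten each tensor per token and concatenate into one buffer.
    # Consumers then read contiguous slices of `flat`, so the compiler
    # sees one uniform indexing pattern instead of three layouts.
    n = query.shape[0]
    q_flat = query.reshape(n, -1)
    k_flat = key.reshape(n, -1)
    v_flat = value.reshape(n, -1)
    flat = torch.cat([q_flat, k_flat, v_flat], dim=-1)
    q_end = q_flat.shape[1]
    k_end = q_end + k_flat.shape[1]
    return flat[:, :q_end], flat[:, q_end:k_end], flat[:, k_end:]
```

Wrapping such a function in @torch.compile(fullgraph=True) is then enough to get one fused kernel for the copy plus its consumers.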
The bulk of the diff is the new ROCm fast path (_forward_core_decode_fast + _forward_core_rocm); the rest is a mechanical refactor to thread core_attn_out through the chunk kernels so prefill writes directly into the pre-allocated output buffer (eliminating a copy). See commit "Clean up GDN attention code: imports, dead code, docs, and guards" for inline comments walking through the new structure.
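The core_attn_out threading amounts to an optional out-buffer parameter on the kernel entry points. A toy model of the pattern (plain softmax attention stands in for the chunked GDN kernels; names are illustrative):

```python
from typing import Optional

import torch

def chunk_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    core_attn_out: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Toy stand-in for the chunked kernel: when the caller supplies a
    # pre-allocated buffer, the result is written into it in place,
    # eliminating the copy the caller would otherwise perform afterwards.
    scores = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
    if core_attn_out is None:
        return scores @ v
    torch.matmul(scores, v, out=core_attn_out)
    return core_attn_out
```

The caller allocates core_attn_out once per batch and passes the same buffer to every chunk call, so prefill output lands directly where the layer expects it.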
Force-pushed 9a24a86 to d6fec1e
I will provide pictures of traces to show what fusions are occurring, but probably not today.
Force-pushed d6fec1e to e1a7465
Force-pushed e1a7465 to 80e9303
I've added some comments to try to explain why the torch code was made more complex. I thought it was still nicer to leave this as PyTorch rather than writing a Triton kernel, which is the other common way of forcing operations to be fused. Accuracy is correct with and without AITER. I don't have an NVIDIA GPU to test against.
Hi @tpopp, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
Force-pushed e185fa5 to 6c08e40
Rebased to retrigger CI. Failures were existing failures.
Hi @tpopp, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
Sorry, I didn't realize before: @ZJY0516 seems to be the relevant codeowner for this PR.
Force-pushed ef81bb3 to 05c1ef9
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed fef2eac to 1155c17
Current failures are existing failures on main.
Force-pushed 1155c17 to a5341ac
@tpopp added me as a collaborator, so I rebased this branch onto current main. @vadiklyutiy could you please re-review when you have a chance? Your suggested dispatch is applied:

    elif current_platform.is_rocm():
        self._forward_method = self.forward_hip
    else:
        self._forward_method = self.forward_cuda

I also worked with @ZJY0516 to address your two earlier review comments.

Tres is offline tonight (Finland time); I'll handle review feedback on his behalf until he's back.

Signed-off-by: Chuan Li <chuali@amd.com>
Quick update for context: CI re-running on the rebased branch is so far green where it has completed. @vadiklyutiy, heads up that landing this also unblocks your #41966. cc @gshtras @dllehr-amd for visibility on the RC5 timing.
Add are_gdn_triton_kernels_available() to centralise the import probe for optional AITER Triton kernels used by GatedDeltaNetAttention.

Made-with: Cursor
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Integrate optional ROCm AITER Triton kernels for GatedDeltaNetAttention:

- causal_conv1d_update_single_token for decode fast-path
- fused_rearrange_sigmoid_gated_delta_rule for recurrence
- Gated via are_gdn_triton_kernels_available() for older aiter compat
- prepare_gdn_attention_core_inputs for GQA interleaved layout unpacking
- FLA chunk ops updated for core_attn_out parameter
- Lint fixes (E741 l->seq_len, mypy no-redef type hints)

Made-with: Cursor
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
- Simplify module-level AITER Triton import indirection; remove unused gdn_aiter_causal_conv1d_update_single_token import and Any type hint.
- Remove dead variables in _forward_core_decode_fast (spec-related fields that are never read in the decode-only path).
- Add docstrings to _forward_core and gdn_attention_core explaining the dual parameter semantics controlled by fast_kernel.
- Use keyword args (fast_kernel=True/False) at gdn_attention_core call sites and _encode_layer_name consistently in both paths.
- Replace hasattr(self, "in_proj_qkv") checks with an explicit self.has_lora_projections boolean set in __init__.
- Add rearrange_mixed_qkv docstring describing the flatten-cat-slice pattern and why it replaces the original rearrange+contiguous calls.
- Add comment to prepare_gdn_attention_core_inputs explaining the contiguity-forcing cat+slice approach.
- Add bounds assertion in chunk_fwd_o for core_attn_out buffer reuse.
- Add inference-only guard in ChunkGatedDeltaRuleFunction.forward rejecting core_attn_out when grad is enabled.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
…ilable

- Add forward_rocm dispatching AITER Triton fused projection+attention when available, falling back to forward_cuda otherwise.
- Extract _output_projection to avoid duplicating the RMSNormGated + out_proj sequence across forward methods.
- Update _forward_method dispatch to use forward_rocm on ROCm.
- Make are_gdn_triton_kernels_available a classmethod with @if_aiter_supported and _AITER_ENABLED check, matching the pattern of all other is_* methods. Simplify the caller accordingly.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Eliminate the polymorphic fast_kernel parameter from _forward_core by splitting it into two methods with clear, non-overloaded signatures:

- _forward_core(mixed_qkv, b, a, core_attn_out): standard conv1d + recurrent attention path used by both CUDA and ROCm.
- _forward_core_rocm(qkvz, ba, z_out, core_attn_out): ROCm AITER fast path that either dispatches to _forward_core_decode_fast or unpacks the packed layout and delegates to _forward_core.

The gdn_attention_core custom op now dispatches between the two based on the fast_kernel flag.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Force-pushed a5341ac to 3a7cfd2
There don't seem to be any tests related to GDN in the CUDA CI, and this PR changes behaviour for CUDA as well. So just to be sure, I have triggered the Qwen3-Next test on B200 to validate the changes. Qwen3-Next test results on CUDA: https://buildkite.com/vllm/ci/builds/65043#019e05df-d38b-4432-b813-5f66b42e419a
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
Overview:
This change merges various Triton-compiled kernels into single kernels. For all backends, the Triton kernels are merged into one; with AITER, they are further merged into kernels optimized for AMD.
The fusions remove 20-26 us of launch overhead per layer (36 GDN layers and 12 attention layers per decode step), which corresponds to 5-8% improvements in TPOT and throughput depending on workload size and input/output length ratios.
Summary
This PR integrates AITER's optimized Triton kernels into the GatedDeltaNetAttention decode path and restructures the forward pass to reduce overhead. When running on ROCm with AITER enabled, decode-only batches use fused Triton kernels for the
convolution update, gating + delta rule update, and QKV rearrangement, replacing multiple individual PyTorch operations with fewer, more efficient kernel launches. The prefill and speculative-decode paths are factored out of the decode-only hot
path to keep the CUDA-graphed decode region lean.
Changes
• Import and gate AITER Triton kernels at module level (gdn_linear_attn.py): Import causal_conv1d_update_single_token, fused_reshape_causal_conv1d_update_single_token, and fused_rearrange_sigmoid_gated_delta_rule from AITER behind a
GDN_AITER_TRITON_AVAILABLE flag that checks rocm_aiter_ops.are_gdn_triton_kernels_available(). This avoids per-call import overhead and gracefully falls back on systems without these kernels.
• Add _forward_core_decode_fast (gdn_linear_attn.py): A new decode fast-path method that uses the AITER Triton kernels for:
• Fused reshape + causal conv1d single-token update (replacing separate reshape, conv1d, and state update operations)
• Fused rearrange + sigmoid gating + delta rule state update (replacing separate sigmoid, gating, rearrangement, and recurrence operations)
• Restructure forward (gdn_linear_attn.py): When AITER fast-path kernels are available, the forward method packs QKV and Z together, passes them through a prepare_gdn_attention_core_inputs compiled helper for prefill, and routes decode-only
batches to the new _forward_core_decode_fast. The prefill and speculative-decode logic is factored out of the decode-only branch to minimize the CUDA-graphed region.
• Pass core_attn_out into chunk kernels (chunk.py, chunk_o.py): Thread an optional core_attn_out tensor through the chunked GDN kernel pipeline so that prefill output can be written directly into the pre-allocated output buffer, eliminating a
separate copy.
• Add are_gdn_triton_kernels_available() (_aiter_ops.py): Centralized check for the optional AITER Triton kernels (conv1d single-token, gated delta net). Returns False on older AITER builds that lack these modules.
AITER Dependency
The AITER Triton kernels used in this PR are provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fast-path is gated behind
rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR — the existing decode path is used as-is.
Benchmark Results
Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423)
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos
• Accuracy: lm_eval --model local-completions --tasks gsm8k --num_fewshot 5 against the running server
• Baseline: same image and configuration without this PR applied
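One way the setup above might be wired together (a sketch, not the exact command used in the benchmark; environment-variable and flag spellings should be checked against the installed vLLM build):

```shell
# Enable AITER paths on ROCm and select the AITER FA attention backend
export VLLM_ROCM_USE_AITER=1
export VLLM_ATTENTION_BACKEND=ROCM_AITER_FA

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 \
  --tensor-parallel-size 1 \
  --compilation-config '{
    "cudagraph_mode": "FULL_AND_PIECEWISE",
    "custom_ops": ["-rms_norm", "-silu_and_mul", "+quant_fp8"],
    "pass_config": {"fuse_norm_quant": true}
  }'
```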
Throughput (ISL=1024, OSL=1024, concurrency=4):
┌─────────────────────────────────┬──────────┬─────────┬───────┐
│ Metric │ Baseline │ With PR │ Delta │
├─────────────────────────────────┼──────────┼─────────┼───────┤
│ Output token throughput (tok/s) │ 456.52 │ 482.55 │ +5.7% │
│ Total token throughput (tok/s) │ 913.04 │ 965.09 │ +5.7% │
│ Mean TPOT (ms) │ 8.66 │ 8.15 │ −5.9% │
│ P99 TPOT (ms) │ 8.98 │ 8.28 │ −7.8% │
│ Mean E2EL (ms) │ 8,971 │ 8,488 │ −5.4% │
└─────────────────────────────────┴──────────┴─────────┴───────┘
Accuracy (lm_eval, gsm8k, 5-shot):
┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ Baseline │ With PR │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8506 ±0.0098 │ 0.8605 ±0.0095 │ +0.0099 (within error bars) │
│ strict-match │ 0.8097 ±0.0108 │ 0.8203 ±0.0106 │ +0.0106 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘
Accuracy is statistically identical — all deltas are within the standard error bounds.
Test plan
• [x] vllm bench serve — 5.7% output throughput improvement, 5.9% TPOT reduction
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline (within stderr)
• [x] Server starts and serves requests correctly with ROCM_AITER_FA attention backend
• [x] Verified graceful fallback when AITER lacks GDN kernels (GDN_AITER_TRITON_AVAILABLE == False)
• [x] torch.profiler trace generated and inspected for correct kernel dispatch