
[Aiter][ROCm] gdn_linear_attn kernel fusion #40711

Merged
vllm-bot merged 9 commits into vllm-project:main from tpopp:tpopp/gdn-linear-attn-cleanup
May 8, 2026

Conversation

@tpopp
Contributor

@tpopp tpopp commented Apr 23, 2026

Overview:
This change merges various Triton-compiled kernels into single kernels. For all backends, the Triton kernels are merged into one; for AITER, they are further merged into kernels optimized for AMD.

The fusions remove 20-26 us of launch overhead per layer (36 GDN layers and 12 attention layers per decode step), which corresponds to 5-8% improvements in TPOT and throughput, depending on workload size and input/output length ratios.
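As a back-of-envelope check on those numbers (a sketch only; it assumes the 20-26 us saving applies to each of the 36 GDN layers, per the description above):

```python
# Rough estimate of launch overhead removed per decode step.
# Assumption: the 20-26 us per-layer saving applies to all 36 GDN layers.
gdn_layers = 36
saving_us = (20, 26)

per_step_ms = [gdn_layers * s / 1000 for s in saving_us]
print(per_step_ms)  # [0.72, 0.936]
```

That 0.72-0.94 ms range brackets the observed mean TPOT drop (8.66 - 8.15 = 0.51 ms); not all launch overhead sits on the critical path, so landing inside the range is plausible.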

Summary

This PR integrates AITER's optimized Triton kernels into the GatedDeltaNetAttention decode path and restructures the forward pass to reduce overhead. When running on ROCm with AITER enabled, decode-only batches use fused Triton kernels for the
convolution update, gating + delta rule update, and QKV rearrangement, replacing multiple individual PyTorch operations with fewer, more efficient kernel launches. The prefill and speculative-decode paths are factored out of the decode-only hot
path to keep the CUDA-graphed decode region lean.

Changes

• Import and gate AITER Triton kernels at module level (gdn_linear_attn.py): Import causal_conv1d_update_single_token, fused_reshape_causal_conv1d_update_single_token, and fused_rearrange_sigmoid_gated_delta_rule from AITER behind a
GDN_AITER_TRITON_AVAILABLE flag that checks rocm_aiter_ops.are_gdn_triton_kernels_available(). This avoids per-call import overhead and gracefully falls back on systems without these kernels.

• Add _forward_core_decode_fast (gdn_linear_attn.py): A new decode fast-path method that uses the AITER Triton kernels for:
• Fused reshape + causal conv1d single-token update (replacing separate reshape, conv1d, and state update operations)
• Fused rearrange + sigmoid gating + delta rule state update (replacing separate sigmoid, gating, rearrangement, and recurrence operations)

• Restructure forward (gdn_linear_attn.py): When AITER fast-path kernels are available, the forward method packs QKV and Z together, passes them through a prepare_gdn_attention_core_inputs compiled helper for prefill, and routes decode-only
batches to the new _forward_core_decode_fast. The prefill and speculative-decode logic is factored out of the decode-only branch to minimize the CUDA-graphed region.

• Pass core_attn_out into chunk kernels (chunk.py, chunk_o.py): Thread an optional core_attn_out tensor through the chunked GDN kernel pipeline so that prefill output can be written directly into the pre-allocated output buffer, eliminating a
separate copy.

• Add are_gdn_triton_kernels_available() (_aiter_ops.py): Centralized check for the optional AITER Triton kernels (conv1d single-token, gated delta net). Returns False on older AITER builds that lack these modules.
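For reference, the recurrence that the fused gating + delta rule kernel collapses into one launch is, per token, the standard gated delta rule update. A NumPy sketch, under the assumption that the decay gate alpha and write strength beta have already been mapped into (0, 1) (the real fused kernel also folds in the sigmoid and the rearrangement):

```python
import numpy as np

def gated_delta_rule_step(S, q, k, v, alpha, beta):
    """One single-token recurrent update of the gated delta rule (sketch).

    S: (d_k, d_v) recurrent state; q, k: (d_k,); v: (d_v,).
    alpha decays the old state, beta scales the delta-rule write.
    """
    # Subtract the component of the state already predicted by k, then
    # write the new key/value association, all under the decay gate.
    S = alpha * (S - beta * np.outer(k, k @ S)) + beta * np.outer(k, v)
    o = S.T @ q  # read out this token's attention output
    return S, o
```

With alpha = 1, beta = 1, a zero initial state, and a unit-norm key, querying with the same key reads back exactly v, which is a quick sanity check on the update.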

AITER Dependency

The AITER Triton kernels used in this PR are provided by ROCm/aiter#2423 (https://github.com/ROCm/aiter/pull/2423) ("[Triton] optimized decode kernels for Qwen3-Next model"). The fast path is gated behind rocm_aiter_ops.are_gdn_triton_kernels_available(), so it is a no-op on AITER versions that do not include this PR — the existing decode path is used as-is.

Benchmark Results

Setup:
• Model: Qwen/Qwen3-Next-80B-A3B-Instruct-FP8, TP=1
• GPU: AMD MI355x (gfx950), single GPU
• Base image: vllm/vllm-openai-rocm:nightly (vLLM v0.19.2rc1) with AITER rebuilt from aiter:main + ROCm/aiter#2423 (​https://github.com/ROCm/aiter/pull/2423​)
• Attention backend: ROCM_AITER_FA
• Compilation: cudagraph_mode=FULL_AND_PIECEWISE, custom_ops=["-rms_norm", "-silu_and_mul", "+quant_fp8"], pass_config={"fuse_norm_quant": true}
• Benchmark: vllm bench serve --dataset_name random --random_input_len 1024 --random_output_len 1024 --max_concurrency 4 --num_prompts 32 --num_warmups 4 --seed 1 --temperature 0 --ignore_eos
• Accuracy: lm_eval --model local-completions --tasks gsm8k --num_fewshot 5 against the running server
• Baseline: same image and configuration without this PR applied

Throughput (ISL=1024, OSL=1024, concurrency=4):

┌─────────────────────────────────┬──────────┬─────────┬───────┐
│ Metric │ Baseline │ With PR │ Delta │
├─────────────────────────────────┼──────────┼─────────┼───────┤
│ Output token throughput (tok/s) │ 456.52 │ 482.55 │ +5.7% │
│ Total token throughput (tok/s) │ 913.04 │ 965.09 │ +5.7% │
│ Mean TPOT (ms) │ 8.66 │ 8.15 │ −5.9% │
│ P99 TPOT (ms) │ 8.98 │ 8.28 │ −7.8% │
│ Mean E2EL (ms) │ 8,971 │ 8,488 │ −5.4% │
└─────────────────────────────────┴──────────┴─────────┴───────┘

Accuracy (lm_eval, gsm8k, 5-shot):

┌──────────────────┬────────────────┬────────────────┬─────────────────────────────┐
│ Filter │ Baseline │ With PR │ Delta │
├──────────────────┼────────────────┼────────────────┼─────────────────────────────┤
│ flexible-extract │ 0.8506 ±0.0098 │ 0.8605 ±0.0095 │ +0.0099 (within error bars) │
│ strict-match │ 0.8097 ±0.0108 │ 0.8203 ±0.0106 │ +0.0106 (within error bars) │
└──────────────────┴────────────────┴────────────────┴─────────────────────────────┘

Accuracy is statistically identical — all deltas are within the standard error bounds.

Test plan

• [x] vllm bench serve — 5.7% output throughput improvement, 5.9% TPOT reduction
• [x] lm_eval --tasks gsm8k --num_fewshot 5 — accuracy unchanged vs. baseline (within stderr)
• [x] Server starts and serves requests correctly with ROCM_AITER_FA attention backend
• [x] Verified graceful fallback when AITER lacks GDN kernels (GDN_AITER_TRITON_AVAILABLE == False)
• [x] torch.profiler trace generated and inspected for correct kernel dispatch

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify Bot added the rocm Related to AMD ROCm label Apr 23, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Apr 23, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a fast-path for Gated Delta Net (GDN) attention using Triton kernels on ROCm, optimizing the decode path. It updates the FLA (Fast Linear Attention) operations to support in-place output computation and introduces a new _forward_core_decode_fast method in the GDN linear attention layer. Feedback highlights critical issues in the fast-path implementation, specifically an out-of-bounds access when indexing state tensors and the need to correctly index gating tensors a and b to match the selected tokens during speculative decoding.

Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py Outdated
Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py Outdated
Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py Outdated
@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from 6ba76fe to 9a24a86 Compare April 23, 2026 15:35
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

def _forward_core_decode_fast(
Member

Why does this decode function have prefill logic?

Contributor Author

There wasn't a reason for it. The prefill logic has been removed.

Collaborator

Resolved in commit Move the prefill and spec part outside _forward_core_decode_fast — the fast decode path is now strictly decode-only; prefill / spec-decode are dispatched in the parent forward_* methods above it.
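Structurally, the resolution amounts to a guard in the parent forward method: only pure-decode batches (no prefill tokens, no speculative tokens) take the fast path. A schematic sketch with hypothetical names:

```python
def route_forward_core(num_prefill_tokens: int,
                       num_spec_tokens: int,
                       fast_kernels_available: bool) -> str:
    """Route a batch to the decode fast path or the general path (sketch)."""
    if (fast_kernels_available
            and num_prefill_tokens == 0
            and num_spec_tokens == 0):
        # Strictly decode-only: safe for the lean CUDA-graphed region.
        return "decode_fast"
    # Prefill and spec-decode are handled outside the fast path.
    return "general"
```

Keeping the guard in the parent means the fast-path function itself never needs prefill branches, which is what the review comment asked for.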

query, key, value = torch.split(
mixed_qkv,
@torch.compile(fullgraph=True)
def prepare_gdn_attention_core_inputs(
Member

Could you explain why we need to change this so much?

Contributor Author

This was to fuse more Triton kernels together. PyTorch doesn't want to fuse functions that split, rearrange, and then return multiple values, because of the differing indexing dimensions. Flattening the tensors, concatenating them, and then operating on slices convinces the compiler to create a single kernel for this computation and its consumers, rather than three kernels, without requiring more custom Triton kernels.
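The flatten-concatenate-slice layout can be illustrated with array shapes alone. A NumPy sketch of the mechanics only; the actual code operates on torch tensors under torch.compile, and the shapes here are made up:

```python
import numpy as np

tokens, heads, head_dim = 4, 2, 3
q = np.arange(tokens * heads * head_dim, dtype=np.float32)
q = q.reshape(tokens, heads, head_dim)
k, v = q + 100.0, q + 200.0

# Returning three separately-rearranged tensors gives the compiler three
# outputs with different indexing, which blocks fusion. Flattening each to
# (tokens, -1) and concatenating yields one contiguous buffer...
flat = np.concatenate([t.reshape(tokens, -1) for t in (q, k, v)], axis=1)

# ...and consumers read slices of that single buffer.
hd = heads * head_dim
q2 = flat[:, :hd].reshape(tokens, heads, head_dim)
k2 = flat[:, hd:2 * hd].reshape(tokens, heads, head_dim)
v2 = flat[:, 2 * hd:].reshape(tokens, heads, head_dim)
```

Numerically nothing changes (q2 equals q elementwise); what changes is that downstream ops all index one buffer, which the compiler can lower together with its consumers as a single kernel.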

Collaborator

The bulk of the diff is the new ROCm fast path (_forward_core_decode_fast + _forward_core_rocm); the rest is a mechanical refactor to thread core_attn_out through the chunk kernels so prefill writes directly into the pre-allocated output buffer (eliminating a copy). See commit Clean up GDN attention code: imports, dead code, docs, and guards for inline comments walking through the new structure.
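Threading core_attn_out through the chunk kernels is the usual optional out-buffer pattern. A plain-Python sketch with hypothetical names of why passing the pre-allocated buffer eliminates the copy:

```python
def chunk_fwd(x, core_attn_out=None):
    """Write results into core_attn_out when given, else allocate (sketch).

    When the caller passes its pre-allocated output buffer, the result
    lands there directly and the follow-up copy disappears.
    """
    out = core_attn_out if core_attn_out is not None else [0.0] * len(x)
    for i, xi in enumerate(x):
        out[i] = 2.0 * xi  # stand-in for the chunked attention math
    return out

buf = [0.0] * 3
res = chunk_fwd([1.0, 2.0, 3.0], core_attn_out=buf)
assert res is buf  # written in place: no separate copy back into buf
```

The bounds assertion and inference-only guard mentioned in the commit messages are what keep this reuse safe when the buffer is shared with the caller.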

@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from 9a24a86 to d6fec1e Compare April 24, 2026 07:10
@tpopp
Contributor Author

tpopp commented Apr 24, 2026

I will provide pictures of traces to show what fusions are occurring, but probably not today.

@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from d6fec1e to e1a7465 Compare May 4, 2026 06:37
@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from e1a7465 to 80e9303 Compare May 4, 2026 14:00
@tpopp
Contributor Author

tpopp commented May 5, 2026

I've added some comments to try to explain why the torch code was made more complex. I thought it was still preferable to leave this as PyTorch rather than writing a Triton kernel, the other common way of forcing operations to be fused.

Accuracy is correct with and without AITER. I don't have an NVIDIA GPU to test against.

@gshtras gshtras added the ready ONLY add when PR is ready to merge/full CI is needed label May 5, 2026
@mergify
Contributor

mergify Bot commented May 5, 2026

Hi @tpopp, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch 2 times, most recently from e185fa5 to 6c08e40 Compare May 6, 2026 08:06
@tpopp
Contributor Author

tpopp commented May 6, 2026

Rebased to retrigger CI. Failures were existing failures.

@mergify
Contributor

mergify Bot commented May 6, 2026

Hi @tpopp, the pre-commit checks have failed. Please run:

uv pip install 'pre-commit>=4.5.1'
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

@tpopp tpopp requested a review from ZJY0516 May 6, 2026 08:19
@tpopp
Contributor Author

tpopp commented May 6, 2026

Sorry, I didn't realize it before: @ZJY0516 seems to be the relevant code owner for this PR.

Comment thread vllm/model_executor/layers/mamba/gdn_linear_attn.py
@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from ef81bb3 to 05c1ef9 Compare May 6, 2026 18:32
@tpopp tpopp requested a review from vadiklyutiy May 6, 2026 18:42
@mergify
Contributor

mergify Bot commented May 7, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tpopp.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label May 7, 2026
@tpopp tpopp force-pushed the tpopp/gdn-linear-attn-cleanup branch from fef2eac to 1155c17 Compare May 7, 2026 06:55
@mergify mergify Bot removed the needs-rebase label May 7, 2026
@tpopp tpopp requested a review from gshtras May 7, 2026 08:30
@tpopp
Contributor Author

tpopp commented May 7, 2026

Current failures are existing failures on main.

@ChuanLi1101 ChuanLi1101 force-pushed the tpopp/gdn-linear-attn-cleanup branch from 1155c17 to a5341ac Compare May 7, 2026 20:45
@ChuanLi1101
Collaborator

@tpopp added me as a collaborator, so I rebased this branch onto current main (force-push, no code changes — same 9 commits with sign-offs preserved). The 21-commit gap is now closed; mergeable: true. CI is re-running.

@vadiklyutiy could you please re-review when you have a chance? Your CHANGES_REQUESTED was for separating the ROCm AITER fast path out of forward_cuda — that's already done in commits Separate ROCm forward path... + Split _forward_core into standard and ROCm AITER paths + Rename forward_rocm to forward_hip (per @gshtras's naming correction). The dispatch is now the standard platform-check pattern at gdn_linear_attn.py:297-309:

elif current_platform.is_rocm():
    self._forward_method = self.forward_hip
else:
    self._forward_method = self.forward_cuda

with forward_hip (now at L704) falling back to forward_cuda when AITER Triton kernels aren't available — so this is a true no-op on non-ROCm and on ROCm builds without ROCm/aiter#2423.

@ZJY0516 to address your two earlier review comments:

  • "Why this decode function has prefill logic?" — that's been resolved by commit Move the prefill and spec part outside _forward_core_decode_fast. The fast decode path is now strictly decode-only; prefill/spec are dispatched separately above it.
  • "Could you explain why we need to change this so much?" — the bulk is the new ROCm fast path (_forward_core_decode_fast + _forward_core_rocm); the rest is mechanical refactor to thread core_attn_out through the chunk kernels so prefill writes directly into the pre-allocated output buffer (eliminating a copy). @tpopp added inline comments in commit Clean up GDN attention code: imports, dead code, docs, and guards to walk through the new structure.

Tres is offline tonight (Finland time); I'll handle review feedback on his behalf until he's back.

Signed-off-by: Chuan Li <chuali@amd.com>

@ChuanLi1101
Collaborator

Quick update for context:

CI re-running on the rebased branch is so far green where it has completed (pre-commit, pre-run-check, Meta, Summary); rest of the Buildkite suite is queued. Mergeable, no conflicts.

@vadiklyutiy — heads up that landing this also unblocks your #41966 ([Qwen3.5] Add FlashInfer GDN decode backend): once this merges, you rebase onto a clean forward_cuda with the ROCm path factored out into forward_hip, which gives FlashInfer GDN decode a clear seam to plug into instead of branching mid-method. So a re-review whenever you're back would unblock both PRs at once. We'd like to land for RC5 tonight if possible — but if your timezone doesn't allow, no worries, just flag and we'll plan around it.

cc @gshtras @dllehr-amd for visibility on the RC5 timing.

hellozhuo-amd and others added 9 commits May 7, 2026 16:37
Add are_gdn_triton_kernels_available() to centralise the import probe for
optional AITER Triton kernels used by GatedDeltaNetAttention.

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Integrate optional ROCm AITER Triton kernels for GatedDeltaNetAttention:
- causal_conv1d_update_single_token for decode fast-path
- fused_rearrange_sigmoid_gated_delta_rule for recurrence
- Gated via are_gdn_triton_kernels_available() for older aiter compat
- prepare_gdn_attention_core_inputs for GQA interleaved layout unpacking
- FLA chunk ops updated for core_attn_out parameter
- Lint fixes (E741 l->seq_len, mypy no-redef type hints)

Made-with: Cursor

Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
- Simplify module-level AITER Triton import indirection; remove unused
  gdn_aiter_causal_conv1d_update_single_token import and Any type hint.
- Remove dead variables in _forward_core_decode_fast (spec-related
  fields that are never read in the decode-only path).
- Add docstrings to _forward_core and gdn_attention_core explaining the
  dual parameter semantics controlled by fast_kernel.
- Use keyword args (fast_kernel=True/False) at gdn_attention_core call
  sites and _encode_layer_name consistently in both paths.
- Replace hasattr(self, "in_proj_qkv") checks with an explicit
  self.has_lora_projections boolean set in __init__.
- Add rearrange_mixed_qkv docstring describing the flatten-cat-slice
  pattern and why it replaces the original rearrange+contiguous calls.
- Add comment to prepare_gdn_attention_core_inputs explaining the
  contiguity-forcing cat+slice approach.
- Add bounds assertion in chunk_fwd_o for core_attn_out buffer reuse.
- Add inference-only guard in ChunkGatedDeltaRuleFunction.forward
  rejecting core_attn_out when grad is enabled.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
…ilable

- Add forward_rocm dispatching AITER Triton fused projection+attention
  when available, falling back to forward_cuda otherwise.
- Extract _output_projection to avoid duplicating the RMSNormGated +
  out_proj sequence across forward methods.
- Update _forward_method dispatch to use forward_rocm on ROCm.
- Make are_gdn_triton_kernels_available a classmethod with
  @if_aiter_supported and _AITER_ENABLED check, matching the pattern
  of all other is_* methods. Simplify the caller accordingly.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Eliminate the polymorphic fast_kernel parameter from _forward_core by
splitting it into two methods with clear, non-overloaded signatures:

- _forward_core(mixed_qkv, b, a, core_attn_out): standard conv1d +
  recurrent attention path used by both CUDA and ROCm.
- _forward_core_rocm(qkvz, ba, z_out, core_attn_out): ROCm AITER
  fast path that either dispatches to _forward_core_decode_fast or
  unpacks the packed layout and delegates to _forward_core.

The gdn_attention_core custom op now dispatches between the two based
on the fast_kernel flag.

Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
@ChuanLi1101 ChuanLi1101 force-pushed the tpopp/gdn-linear-attn-cleanup branch from a5341ac to 3a7cfd2 Compare May 7, 2026 21:37
@tjtanaa
Collaborator

tjtanaa commented May 8, 2026

I've added some comments to try to explain why the torch code was made more complex. I thought it was still nice to leave this as pytorch rather than making a triton kernel as the other common way of forcing operations to be fused.

Accuracy is correct with and without aiter. I don't have an nv-gpu to test against.

There don't seem to be any tests related to GDN in the CUDA CI, and this PR changes behaviour for CUDA as well. So just to be sure, I have triggered the Qwen3.5-Next test on B200 to validate the changes.

Qwen3.5 Test results on CUDA https://buildkite.com/vllm/ci/builds/65043#019e05df-d38b-4432-b813-5f66b42e419a

Collaborator

@tjtanaa tjtanaa left a comment

LGTM

@vllm-bot vllm-bot merged commit ed582b6 into vllm-project:main May 8, 2026
56 of 59 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD May 8, 2026
libinta pushed a commit to libinta/vllm that referenced this pull request May 8, 2026
Signed-off-by: Tres Popp <tres.popp@amd.com>
Signed-off-by: Chuan Li <chuali@amd.com>
Co-authored-by: hellozhuo <zhuo.su@amd.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Libin Tang <libin.tang@intel.com>
Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done
