
[None][feat] Optimize qwen3.5 decode delta kernel#12740

Merged
nv-guomingz merged 1 commit into NVIDIA:main from nv-guomingz:user/guomingz/fla-decode-kernel-opt
Apr 8, 2026

Conversation

@nv-guomingz
Collaborator

@nv-guomingz nv-guomingz commented Apr 3, 2026

  • keep decode qkv views and make the fused recurrent kernel stride-aware
  • restore the decode tile choice that wins on the representative bs256 pure-decode benchmark

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced tensor-parallel support for Qwen3 model inference
    • Improved variable-length sequence handling in FLA modules
  • Performance Improvements

    • Optimized memory access patterns and kernel execution for faster inference
    • Better resource utilization in multi-GPU distributed setups

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@coderabbitai
Contributor

coderabbitai bot commented Apr 3, 2026

📝 Walkthrough

Walkthrough

This pull request optimizes tensor-parallel dimension handling in a Qwen3 model's gated attention layer and refactors a fused Triton kernel for improved pointer arithmetic and grid-striding efficiency, extending kernel signatures to support dynamic stride parameters.

Changes

Qwen3 Model Tensor-Parallel Optimization — tensorrt_llm/_torch/models/modeling_qwen3_next.py
Added per-tensor-parallel derived dimensions (num_k_heads_per_tp, num_v_heads_per_tp, key_dim_per_tp, value_dim_per_tp) in __init__ and updated forward_decode to use these precomputed values instead of re-deriving them inline from tensor shapes on every call.
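The precomputation described above can be sketched as follows (plain Python; the constructor parameters and example numbers are assumptions for illustration, not the actual TRT-LLM attributes):

```python
# Sketch of precomputing per-tensor-parallel dimensions once in __init__,
# as the walkthrough describes, so forward_decode need not re-derive them
# from tensor shapes. All names here (head_k_dim, attn_tp_size, ...) are
# assumptions for illustration, not the actual TRT-LLM code.
class GatedAttentionDims:
    def __init__(self, num_k_heads, num_v_heads, head_k_dim, head_v_dim,
                 attn_tp_size):
        assert num_k_heads % attn_tp_size == 0
        assert num_v_heads % attn_tp_size == 0
        # Computed once at construction instead of on every decode step.
        self.num_k_heads_per_tp = num_k_heads // attn_tp_size
        self.num_v_heads_per_tp = num_v_heads // attn_tp_size
        self.key_dim_per_tp = self.num_k_heads_per_tp * head_k_dim
        self.value_dim_per_tp = self.num_v_heads_per_tp * head_v_dim


dims = GatedAttentionDims(num_k_heads=16, num_v_heads=32,
                          head_k_dim=128, head_v_dim=128, attn_tp_size=4)
print(dims.key_dim_per_tp, dims.value_dim_per_tp)  # 512 1024
```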
Triton Kernel Control Flow and Pointer Arithmetic — tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
Refactored the Triton kernel's program-ID mapping from direct (i_nh, i_v, i_k) to (i_k, i_v, i_nh) with while-loop grid-striding; replaced fixed-stride pointer increments with caller-provided stride_q/k/v/a/b parameters; changed the launch grid shape, added varlen sequence-length computation, and extended the kernel signature with total_nh and stride parameters.
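The while-loop grid-striding described here can be illustrated with a plain-Python sketch (not Triton; only total_nh comes from the walkthrough, the other names and numbers are assumed): each program starts at its own id and steps by the grid size, so a fixed launch grid covers any number of (batch, head) pairs.

```python
# Plain-Python illustration (not Triton) of while-loop grid-striding:
# each program handles every total_nh work item congruent to its id
# modulo the grid size, so the grid can be smaller than the work count.
def grid_stride_work(program_id, grid_size, total_nh):
    handled = []
    i_nh = program_id
    while i_nh < total_nh:
        handled.append(i_nh)  # a real kernel would process pair i_nh here
        i_nh += grid_size     # stride to this program's next work item
    return handled


# With a grid of 4 programs and 10 work items, program 1 handles 1, 5, 9,
# and the four programs together cover every item exactly once.
print(grid_stride_work(1, 4, 10))  # [1, 5, 9]
```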

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 1 passed | ❌ 2 warnings

❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 50.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check ⚠️ Warning — The PR description is minimal and lacks required sections from the template: it omits a detailed Description and Test Coverage section and does not address the PR Checklist items. Resolution: expand the Description section with a detailed explanation of the changes and rationale, specify Test Coverage with relevant test names, and complete the PR Checklist items explicitly.

✅ Passed checks (1 passed)

  • Title check ✅ Passed — The title clearly relates to the main changes: optimizing the decode delta kernel for Qwen3.5, which matches the modifications to both the Qwen3Next model and the fused sigmoid gating recurrent kernel.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)

203-207: Stride extraction looks correct but merits a clarifying comment.

The stride extraction uses stride(1) for q/k/v (3D+ tensors with the token dim at index 1) and stride(-2) for a/b. For 2D tensors like a and b shaped [num_tokens, num_heads], stride(-2) equals stride(0), which gives the correct per-token stride (num_heads).

This is subtly correct but could benefit from a brief comment explaining why different stride indices are used.
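A quick way to see why both index choices land on the per-token stride is to compute contiguous row-major strides by hand (plain Python; the shapes below are illustrative assumptions, not taken from the kernel):

```python
# Hand-computed contiguous (row-major) strides, to show why stride(1) for
# q/k/v and stride(-2) for a/b both yield the per-token stride.
def contiguous_strides(shape):
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.insert(0, acc)  # stride of a dim = product of dims after it
        acc *= dim
    return strides


# q shaped [batch, num_tokens, num_heads, head_dim]: token dim is index 1.
q_strides = contiguous_strides([1, 256, 16, 128])
print(q_strides[1])   # 2048 == num_heads * head_dim

# a shaped [num_tokens, num_heads]: stride(-2) == stride(0) == num_heads.
a_strides = contiguous_strides([256, 16])
print(a_strides[-2])  # 16
```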

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around
lines 203 - 207, Add a succinct clarifying comment above the stride extraction
explaining why q/k/v use q.stride(1) (because they are 3D+ tensors with token
dimension at index 1, so per-token stride lives at dim 1) while a/b use
stride(-2) (for 2D tensors shaped [num_tokens, num_heads] stride(-2) ==
stride(0), yielding the per-token stride equal to num_heads); update the block
around stride_q/stride_k/stride_v/stride_a/stride_b so the reader understands
the shape assumptions and why different index forms are used.

72-77: Rename `all` to avoid shadowing a Python builtin.

The variable `all` shadows the Python builtin of the same name. While this is inside a Triton kernel (JIT-compiled), it is still a code-quality issue that could cause confusion during maintenance.

Suggested fix
         if IS_VARLEN:
             bos, eos = (
                 tl.load(cu_seqlens + i_n).to(tl.int64),
                 tl.load(cu_seqlens + i_n + 1).to(tl.int64),
             )
-            all = T
+            total_tokens = T
             seq_T = eos - bos
         else:
             bos, eos = i_n * T, i_n * T + T
-            all = B * T
+            total_tokens = B * T
             seq_T = T

Also update line 88:

-        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+        p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around
lines 72 - 77, The variable name "all" in the fused Triton kernel shadows the
Python builtin; rename it to a clearer identifier (for example "all_elems" or
"total_elems") wherever it's defined in the fused_sigmoid_gating_recurrent
kernel (the block assigning bos, eos, seq_T, and the branch using i_n, T, B) and
update all subsequent references (including the later usage mentioned near line
88) to the new name so the kernel logic remains identical but no longer masks
the builtin.

1-2: Missing NVIDIA copyright header.

Per coding guidelines, all modified files should have an updated NVIDIA copyright header with the current year. The source attribution should be retained, but an NVIDIA copyright header should be added.

Suggested addition at top of file
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py

As per coding guidelines: "Add NVIDIA copyright header to ALL new files; update year on modified files".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around
lines 1 - 2, Add the standard NVIDIA copyright header (with the current year
2026) at the very top of the file, preserving the existing "Adapted from ..."
source attribution line; ensure the header matches the project's canonical
NVIDIA header format (including copyright owner, year, and any required license
notice) and appears before the existing comment so the file begins with the
NVIDIA header followed by the source attribution.
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)

714-715: Consider using precomputed key_dim_per_tp for consistency (optional).

forward_extend still computes key_split_dim = self.key_dim // self.attn_tp_size inline, while forward_decode now uses the precomputed self.key_dim_per_tp. For maintainability, consider using the precomputed attribute here as well:

Suggested fix
-        key_split_dim = self.key_dim // self.attn_tp_size
-        value_split_dim = self.value_dim // self.attn_tp_size
+        key_split_dim = self.key_dim_per_tp
+        value_split_dim = self.value_dim_per_tp
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py` around lines 714 - 715, In
forward_extend replace the inline computation key_split_dim = self.key_dim //
self.attn_tp_size with the precomputed attribute self.key_dim_per_tp (and
likewise use self.value_dim_per_tp if available instead of value_split_dim =
self.value_dim // self.attn_tp_size) so the method uses the same
per-tensor-parallel dimensions as forward_decode (update references to
key_split_dim/value_split_dim in forward_extend to use
self.key_dim_per_tp/self.value_dim_per_tp to ensure consistency).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6770d7c6-b9fc-4403-b4da-5a28f0a58f7d

📥 Commits

Reviewing files that changed from the base of the PR and between 1045f38 and e6b8d66.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py

- keep decode qkv views and make the fused recurrent kernel stride-aware
- restore the decode tile choice that wins on the representative bs256 pure-decode benchmark

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
@nv-guomingz nv-guomingz changed the title Optimize qwen3.5 decode delta kernel [None][feat] Optimize qwen3.5 decode delta kernel Apr 3, 2026
@nv-guomingz nv-guomingz force-pushed the user/guomingz/fla-decode-kernel-opt branch from e6b8d66 to 8a2468a on April 3, 2026 at 15:26
@nv-guomingz nv-guomingz requested a review from a team as a code owner April 3, 2026 15:26
@nv-guomingz nv-guomingz requested a review from tomeras91 April 3, 2026 15:26
@nv-guomingz
Collaborator Author

/bot run --add-multi-gpu-test --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41694 [ run ] triggered by Bot. Commit: 8a2468a Link to invocation

@nv-guomingz
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #41777 [ run ] triggered by Bot. Commit: 8a2468a Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41777 [ run ] completed with state SUCCESS. Commit: 8a2468a
/LLM/main/L0_MergeRequest_PR pipeline #32672 completed with status: 'SUCCESS'

CI Report

Link to invocation

Collaborator

@rosenrodt rosenrodt left a comment


LGTM

Collaborator

@QiJune QiJune left a comment


LGTM

@nv-guomingz nv-guomingz merged commit 2ff65f5 into NVIDIA:main Apr 8, 2026
6 of 7 checks passed
suyoggupta pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 8, 2026
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>

Labels

None yet

Projects

None yet


5 participants