[None][feat] Optimize qwen3.5 decode delta kernel #12740
nv-guomingz merged 1 commit into NVIDIA:main from
Conversation
📝 Walkthrough

This pull request optimizes tensor-parallel dimension handling in a Qwen3 model's gated attention layer and refactors a fused Triton kernel for improved pointer arithmetic and grid-striding efficiency, extending kernel signatures to support dynamic stride parameters.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (2 warnings)
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)
**203-207**: Stride extraction looks correct but merits a clarifying comment.

The stride extraction uses `stride(1)` for q/k/v (3D+ tensors with the token dim at index 1) and `stride(-2)` for a/b. For 2D tensors like `a` and `b` shaped `[num_tokens, num_heads]`, `stride(-2)` equals `stride(0)`, which gives the correct per-token stride (`num_heads`).

This is subtly correct but could benefit from a brief comment explaining why different stride indices are used.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around lines 203 - 207, Add a succinct clarifying comment above the stride extraction explaining why q/k/v use q.stride(1) (because they are 3D+ tensors with token dimension at index 1, so per-token stride lives at dim 1) while a/b use stride(-2) (for 2D tensors shaped [num_tokens, num_heads] stride(-2) == stride(0), yielding the per-token stride equal to num_heads); update the block around stride_q/stride_k/stride_v/stride_a/stride_b so the reader understands the shape assumptions and why different index forms are used.
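As a quick illustration of the shape assumptions behind this comment, here is a small pure-Python sketch of row-major strides. The shapes below are hypothetical, chosen only to mirror the `[batch, num_tokens, ...]` q/k/v layout and the `[num_tokens, num_heads]` a/b layout the comment describes:

```python
# Pure-Python sketch of C-contiguous element strides, mimicking what
# torch.Tensor.stride() returns for contiguous tensors. Shapes are hypothetical.
def contiguous_strides(shape):
    """Element strides of a C-contiguous tensor of the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

q_shape = (2, 8, 4, 64)   # [batch, num_tokens, num_heads, head_dim]
a_shape = (8, 4)          # [num_tokens, num_heads]

q_strides = contiguous_strides(q_shape)
a_strides = contiguous_strides(a_shape)

print(q_strides[1])    # per-token stride of q: 4 * 64 = 256 (token dim at index 1)
print(a_strides[-2])   # per-token stride of a: 4 (== strides[0] == num_heads for 2D)
```

For the 2D case, index `-2` and index `0` name the same dimension, which is why the two index forms coexist safely in the kernel launch code.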
**72-77**: Rename `all` to avoid shadowing the Python builtin.

The variable `all` shadows the Python builtin function. While this is inside a Triton kernel (JIT-compiled), it's still a code quality issue that could cause confusion during maintenance.

Suggested fix:

```diff
 if IS_VARLEN:
     bos, eos = (
         tl.load(cu_seqlens + i_n).to(tl.int64),
         tl.load(cu_seqlens + i_n + 1).to(tl.int64),
     )
-    all = T
+    total_tokens = T
     seq_T = eos - bos
 else:
     bos, eos = i_n * T, i_n * T + T
-    all = B * T
+    total_tokens = B * T
     seq_T = T
```

Also update line 88:

```diff
-p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+p_o = o + ((i_k * total_tokens + bos) * HV + i_hv) * V + o_v
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around lines 72 - 77, The variable name "all" in the fused Triton kernel shadows the Python builtin; rename it to a clearer identifier (for example "all_elems" or "total_elems") wherever it's defined in the fused_sigmoid_gating_recurrent kernel (the block assigning bos, eos, seq_T, and the branch using i_n, T, B) and update all subsequent references (including the later usage mentioned near line 88) to the new name so the kernel logic remains identical but no longer masks the builtin.
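A small plain-Python illustration of the hazard this comment flags (a hypothetical function, not the Triton kernel itself): once a local named `all` is bound, the builtin is unreachable for the rest of that scope.

```python
# Hypothetical example: binding a local named `all` makes the builtin all()
# uncallable for the remainder of the function body.
def kernel_like():
    all = 16  # shadows builtins.all from here on
    try:
        return all([True, True])  # TypeError: 'int' object is not callable
    except TypeError:
        return "builtin shadowed"

print(kernel_like())  # builtin shadowed
```

Inside a `@triton.jit` kernel the builtin is unlikely to be called, so this is purely a readability and maintenance concern, as the comment notes.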
**1-2**: Missing NVIDIA copyright header.

Per coding guidelines, all modified files should have an updated NVIDIA copyright header with the current year. The source attribution should be retained, but an NVIDIA copyright header should be added.

Suggested addition at the top of the file:

```diff
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py
```

As per coding guidelines: "Add NVIDIA copyright header to ALL new files; update year on modified files".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py` around lines 1 - 2, Add the standard NVIDIA copyright header (with the current year 2026) at the very top of the file, preserving the existing "Adapted from ..." source attribution line; ensure the header matches the project's canonical NVIDIA header format (including copyright owner, year, and any required license notice) and appears before the existing comment so the file begins with the NVIDIA header followed by the source attribution.

tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)
**714-715**: Consider using the precomputed `key_dim_per_tp` for consistency (optional).

`forward_extend` still computes `key_split_dim = self.key_dim // self.attn_tp_size` inline, while `forward_decode` now uses the precomputed `self.key_dim_per_tp`. For maintainability, consider using the precomputed attribute here as well.

Suggested fix:

```diff
-key_split_dim = self.key_dim // self.attn_tp_size
-value_split_dim = self.value_dim // self.attn_tp_size
+key_split_dim = self.key_dim_per_tp
+value_split_dim = self.value_dim_per_tp
```
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/models/modeling_qwen3_next.py` around lines 714 - 715, In forward_extend replace the inline computation key_split_dim = self.key_dim // self.attn_tp_size with the precomputed attribute self.key_dim_per_tp (and likewise use self.value_dim_per_tp if available instead of value_split_dim = self.value_dim // self.attn_tp_size) so the method uses the same per-tensor-parallel dimensions as forward_decode (update references to key_split_dim/value_split_dim in forward_extend to use self.key_dim_per_tp/self.value_dim_per_tp to ensure consistency).
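The pattern the comment recommends can be sketched as follows. The class and parameter values here are hypothetical; only the attribute names (`key_dim_per_tp`, `value_dim_per_tp`, `attn_tp_size`) mirror those in the review comment:

```python
# Hedged sketch: compute per-tensor-parallel split sizes once in __init__
# and reuse them in every forward path, instead of re-deriving them inline.
class GatedAttentionDims:
    def __init__(self, key_dim: int, value_dim: int, attn_tp_size: int):
        # Head dims must divide evenly across the attention TP group.
        assert key_dim % attn_tp_size == 0
        assert value_dim % attn_tp_size == 0
        self.key_dim_per_tp = key_dim // attn_tp_size
        self.value_dim_per_tp = value_dim // attn_tp_size

dims = GatedAttentionDims(key_dim=4096, value_dim=2048, attn_tp_size=4)
print(dims.key_dim_per_tp, dims.value_dim_per_tp)  # 1024 512
```

Precomputing once also gives a single place to assert divisibility, rather than relying on silent floor division at each call site.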
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6770d7c6-b9fc-4403-b4da-5a28f0a58f7d
📒 Files selected for processing (2)

- tensorrt_llm/_torch/models/modeling_qwen3_next.py
- tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
- keep decode qkv views and make the fused recurrent kernel stride-aware
- restore the decode tile choice that wins on the representative bs256 pure-decode benchmark

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
e6b8d66 to 8a2468a

/bot run --add-multi-gpu-test --disable-fail-fast

PR_Github #41694 [ run ] triggered by Bot. Commit:

/bot run

PR_Github #41777 [ run ] triggered by Bot. Commit:

PR_Github #41777 [ run ] completed with state
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Summary by CodeRabbit
Release Notes
New Features
Performance Improvements
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.