[pref] qwen3_next add triton ops : fused_sigmoid_gating_delta_rule_update#4818
wangxiyuan merged 1 commit into vllm-project:main
Code Review
This pull request introduces a new Triton kernel, fused_sigmoid_gating_delta_rule_update, to optimize the decode path for the Qwen3-Next model. The implementation is sound and targets a specific performance improvement scenario. My review includes suggestions for refactoring parts of the new Triton kernel for better readability and efficiency, as well as cleaning up some redundant and commented-out code in both the new kernel and the associated patch file. Addressing these points will improve the overall quality and maintainability of the code.
    idx = tl.load(h0_indices + i_n)
    # if idx >= 0:
    tmp0 = tl.where(idx < 0, 0, idx)
    p_h0 = (
        h0_source
        + tmp0 * HV * K * V
        + i_hv * K * V
        + o_k[:, None] * V
        + o_v[None, :]
    )
    temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    temp2 = tl.zeros_like(temp1)
    value0 = tl.where(idx < 0, temp2, temp1)
    b_h += value0  # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
The logic for loading the initial state is a bit convoluted and can be simplified for better readability and efficiency. The current implementation uses tl.where and performs a memory load that is subsequently discarded for negative indices. Using a simple if idx >= 0: check, as is done when storing the final state later in the kernel, would be cleaner and avoid the unnecessary load operation.
Suggested change:

    idx = tl.load(h0_indices + i_n)
    if idx >= 0:
        p_h0 = (
            h0_source
            + idx * HV * K * V
            + i_hv * K * V
            + o_k[:, None] * V
            + o_v[None, :]
        )
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
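The difference between the two loading patterns can be illustrated outside Triton. The following is a minimal plain-Python sketch, not the kernel itself: `state_pool`, `load_with_where`, `load_with_branch`, and the `loads` log are hypothetical names introduced only for illustration, and a Python list lookup stands in for `tl.load`. It only shows that both patterns produce the same `b_h` while the branch version skips the load for a negative index; whether the branch is actually faster on a given backend would need to be measured.

```python
# Plain-Python stand-in for the two ways of loading the initial state.
# All names here are hypothetical; a list lookup stands in for tl.load.

def load_with_where(state_pool, idx, b_h, loads):
    """Original pattern: clamp the index, always load, then mask the result."""
    safe_idx = 0 if idx < 0 else idx           # mirrors tl.where(idx < 0, 0, idx)
    loaded = list(state_pool[safe_idx])        # the load happens even when idx < 0
    loads.append(safe_idx)                     # record which slot was touched
    value = [0.0] * len(loaded) if idx < 0 else loaded
    return [h + v for h, v in zip(b_h, value)]


def load_with_branch(state_pool, idx, b_h, loads):
    """Suggested pattern: only touch memory when the slot index is valid."""
    if idx >= 0:
        loads.append(idx)
        return [h + v for h, v in zip(b_h, state_pool[idx])]
    return list(b_h)


state_pool = [[1.0, 2.0], [3.0, 4.0]]          # two cached states of size 2

for idx in (-1, 0, 1):
    loads_where, loads_branch = [], []
    out_where = load_with_where(state_pool, idx, [0.0, 0.0], loads_where)
    out_branch = load_with_branch(state_pool, idx, [0.0, 0.0], loads_branch)
    assert out_where == out_branch             # same accumulated state either way
    print(idx, out_branch, len(loads_where), len(loads_branch))
```

For `idx = -1` the clamped version still reads slot 0 and throws the result away, which is the discarded load the review comment refers to.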
    # # Update pointers for next timestep
    # p_q += H * K
    # p_k += H * K
    # p_o += HV * V
    # p_v += HV * V
    # p_b += HV
    # p_a += HV

    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    #from collections.abc import Iterable

        chunk_gated_delta_rule,
        fused_recurrent_gated_delta_rule,
    )
    from vllm.model_executor.layers.mamba.abstract import MambaBase
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

This pull request has conflicts, please resolve those before we can evaluate the pull request.
efefcb2 to d5b14ca

please fix the CI

This pull request has conflicts, please resolve those before we can evaluate the pull request.

46d0c30 to 7d1d08c
014cdad to 5a3bc6c
    class TestChunkGatedDeltaRule(PytestBase):

        def test_triton_fusion_ops(self):
There is no need to build an e2e test on the UT base class. You can write this test the same way the other e2e tests do.
    # Related PR (if no, explain why):
    # https://github.com/vllm-project/vllm/pull/4818
    # Future Plan:
    # .
df79c66 to a67f4b4
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
…date (vllm-project#4818)

### What this PR does / why we need it?
qwen3_next add fused_sigmoid_gating_delta_rule_update op which fused fused_gdn_gating+fused_recurrent_gated_delta_rule

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
What this PR does / why we need it?
Adds a fused_sigmoid_gating_delta_rule_update Triton op for qwen3_next, which fuses fused_gdn_gating and fused_recurrent_gated_delta_rule into a single op.
Before fusion: [image]
After fusion: [image]
Does this PR introduce any user-facing change?
How was this patch tested?