[Triton] optimized decode kernels for Qwen3-Next model #2423
Closed
Commits (23)
3216bce Add gdn fusions (hellozhuo-amd)
9811501 style: fix ruff F841 and black-format Triton PR files (hellozhuo-amd)
b26972f Update fused_rearrange_sigmoid_gdr.py (hellozhuo-amd)
8695885 Update op_tests (hellozhuo-amd)
b69cb72 Fix BLACK format problem (hellozhuo-amd)
c4db40f Fix black check failure (hellozhuo-amd)
ac48df0 Update test_fused_rearrange_sigmoid_gdr.py (hellozhuo-amd)
b9f33dd Merge branch 'origin/main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
56a2b85 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
f214128 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
5084462 Merge branch 'main' into zhuo/qwen3_triton_gdn (juuso-oskari)
3d084e2 Merge branch 'main' into zhuo/qwen3_triton_gdn (juuso-oskari)
b2ab876 Merge branch 'main' into zhuo/qwen3_triton_gdn (juuso-oskari)
bdc9a96 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
7fbd9ad Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
3ffd13c Replace _fast with _single_token for causal conv1d update kernels for… (hellozhuo-amd)
9946258 Fix black format error (hellozhuo-amd)
b8ea372 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
0f41d78 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
2aa2493 refactor(triton): rename gated RMSNorm+FP8 op to fused_rms_gated_fp8_… (hellozhuo-amd)
35035ff Merge branch 'main' into zhuo/qwen3_triton_gdn (juuso-oskari)
711c9e9 Merge branch 'main' into zhuo/qwen3_triton_gdn (hellozhuo-amd)
d5e7712 Merge branch 'main' into zhuo/qwen3_triton_gdn (nholmber)
aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py (507 additions, 0 deletions; large diff not rendered by default)
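The diff for this file is not rendered, so as background only: the name suggests a single-token causal conv1d cache update, the decode-time counterpart of a full convolution scan. Below is a minimal eager-mode PyTorch reference for that operation's usual semantics; it is a hedged sketch under assumed layouts, not this file's Triton code.

# Reference semantics only, not the Triton kernel from this file. The
# (batch, dim) / (batch, dim, width) layouts and the SiLU activation are
# assumptions for illustration.
import torch
import torch.nn.functional as F


def causal_conv1d_update_ref(x, conv_state, weight, bias=None):
    # x: (batch, dim) incoming token; conv_state: (batch, dim, width) rolling cache
    # weight: (dim, width) depthwise filter taps
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # age the window
    conv_state[:, :, -1] = x                                      # append new token
    out = torch.einsum("bdw,dw->bd", conv_state.to(x.dtype), weight)
    if bias is not None:
        out = out + bias
    return F.silu(out)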
aiter/ops/triton/_triton_kernels/gated_delta_rule/decode/fused_rearrange_sigmoid_gdr.py (165 additions, 0 deletions)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import triton
import triton.language as tl

@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_rearrange_sigmoid_gated_delta_rule_update_kernel(
    A_log,
    a,
    b,
    dt_bias,
    beta,
    threshold,
    qkv,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_qkv_l: tl.constexpr,
    stride_qkv_hd: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
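    # Launch grid is presumably (cdiv(K, BK), cdiv(V, BV), N * HV), inferred
    # from the index math below: each program owns one (BV, BK) tile of the
    # recurrent state for a single value head of a single sequence.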
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = qkv + bos * stride_qkv_l + ((i_h * K) + o_k) * stride_qkv_hd
    p_k = qkv + bos * stride_qkv_l + (H * K + (i_h * K) + o_k) * stride_qkv_hd
    p_v = qkv + bos * stride_qkv_l + (2 * H * K + (i_hv * V) + o_v) * stride_qkv_hd

    p_A_log = A_log + i_hv
    if not IS_KDA:
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    else:
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k

    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]
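
    # b_h is this program's (BV, BK) tile of the recurrent state S, carried in
    # fp32 across the entire token loop.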
    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K

        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
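
    # Per-token recurrence (sigmoid-gated delta rule): decay the state,
    # compute the beta-weighted delta correction, apply a rank-1 state
    # update, then read out o_t = S_t @ q_t.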
    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
        softplus_x = tl.where(
            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
        )
        b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x

        b_beta = tl.sigmoid(b_b.to(tl.float32))

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k * tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        if not IS_KDA:
            b_h *= tl.exp(b_g)
        else:
            b_h *= tl.exp(b_g[None, :])
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        b_v *= b_beta
        b_h += b_v[:, None] * b_k[None, :]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t * stride_indices_tok
            ).to(tl.int64)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        p_q += stride_qkv_l
        p_k += stride_qkv_l
        p_v += stride_qkv_l
        p_o += HV * V
        p_b += HV
        p_a += HV
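
In equation form, each iteration of the token loop above applies the sigmoid-gated delta rule (notation mine, read off the kernel body):

$$\alpha_t = \exp\!\big(-e^{A_{\log}}\,\operatorname{softplus}(a_t + \mathrm{dt\_bias})\big), \qquad \beta_t = \sigma(b_t)$$
$$S_t = \alpha_t S_{t-1} + \beta_t\,(v_t - \alpha_t S_{t-1} k_t)\,k_t^{\top}, \qquad o_t = S_t\,q_t$$

where softplus(x) = (1/beta) log(1 + exp(beta * x)) switches to the identity once beta * x exceeds `threshold` (for numerical stability), and q_t, k_t are optionally L2-normalized in-kernel with q_t scaled by `scale`.

For context, here is a minimal host-side launch sketch. It is an assumption for illustration, not the PR's actual wrapper: it presumes a varlen decode batch with qkv laid out as (total_tokens, 2*H*K + HV*V), per-slot states of shape (num_slots, HV, V, K), 1-D ssm_state_indices, and single K/V blocks covering the head dims.

# Hypothetical host-side wrapper, sketched only to show the launch geometry
# and argument plumbing; the function name, shapes, and defaults here are
# assumptions, not the PR's Python-side op.
import torch
import triton

from aiter.ops.triton._triton_kernels.gated_delta_rule.decode.fused_rearrange_sigmoid_gdr import (
    fused_rearrange_sigmoid_gated_delta_rule_update_kernel,
)


def gdn_decode_update(qkv, a, b, A_log, dt_bias, state, cu_seqlens,
                      ssm_state_indices, H, HV, K, V):
    N = cu_seqlens.numel() - 1       # number of sequences in the batch
    T = qkv.shape[0]                 # total tokens (1 per sequence in pure decode)
    BK = triton.next_power_of_2(K)   # assume one block spans the key dim
    BV = triton.next_power_of_2(V)   # assume one block spans the value dim
    o = torch.empty(T, HV, V, dtype=qkv.dtype, device=qkv.device)
    grid = (triton.cdiv(K, BK), triton.cdiv(V, BV), N * HV)
    fused_rearrange_sigmoid_gated_delta_rule_update_kernel[grid](
        A_log, a, b, dt_bias,
        1.0, 20.0,                   # softplus beta / threshold (assumed defaults)
        qkv, o,
        state, state,                # h0 and ht alias: state is updated in place
        cu_seqlens, ssm_state_indices,
        None,                        # num_accepted_tokens: no speculative decoding
        K ** -0.5,                   # scale (assumed 1/sqrt(K))
        N, T,
        B=N, H=H, HV=HV, K=K, V=V, BK=BK, BV=BV,
        stride_qkv_l=qkv.stride(0), stride_qkv_hd=qkv.stride(1),
        stride_init_state_token=state.stride(0),
        stride_final_state_token=state.stride(0),
        stride_indices_seq=ssm_state_indices.stride(0),
        stride_indices_tok=1,        # with one token per sequence, i_t is always 0
        INPLACE_FINAL_STATE=True,
        USE_QK_L2NORM_IN_KERNEL=True,
        IS_KDA=False,
    )
    return o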