-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[perf] qwen3_next add triton ops : fused_sigmoid_gating_delta_rule_update #4818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+539
−1
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
65 changes: 65 additions & 0 deletions
65
tests/e2e/singlecard/test_fused_sigmoid_gating_delta_rule.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import torch | ||
| from vllm.model_executor.layers.fla.ops import fused_recurrent_gated_delta_rule | ||
| from vllm.model_executor.models.qwen3_next import fused_gdn_gating | ||
|
|
||
| from vllm_ascend.ops.triton.fla.sigmoid_gating import \ | ||
| fused_sigmoid_gating_delta_rule_update | ||
|
|
||
|
|
||
def test_triton_fusion_ops():
    """Compare the fused sigmoid-gating delta-rule kernel against the
    reference two-step path (fused_gdn_gating followed by
    fused_recurrent_gated_delta_rule) on a single-token, single-sequence
    input, and require the outputs to agree within bf16 tolerance.
    """

    def _bf16_npu(values):
        # Build a bfloat16 tensor on the NPU from Python literals.
        return torch.tensor(values).bfloat16().npu()

    # One token, 4 q/k heads and 8 v heads, head dim 128 (bf16, on NPU).
    q = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
    k = torch.randn(1, 1, 4, 128, dtype=torch.bfloat16).npu()
    v = torch.randn(1, 1, 8, 128, dtype=torch.bfloat16).npu()
    # Fixed gating inputs: one scalar per v head.
    a = _bf16_npu([[
        -2.6094, -0.2617, -0.3848, 2.2656, 3.6250, -0.7383, -1.0938, -0.0505
    ]])
    b = _bf16_npu(
        [[0.4277, 0.8906, 1.6875, 2.3750, 4.1562, 0.3809, 1.0625, 3.6719]])
    # Recurrent state cache with room for several slots; slot 2 is used below.
    ssm_state = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16).npu()
    non_spec_state_indices_tensor = torch.tensor([2]).int().npu()
    # cu_seqlens for a single sequence of length 1.
    non_spec_query_start_loc = torch.tensor([0, 1]).int().npu()
    a_log = _bf16_npu([
        -2.6875, -3.2031, -3.3438, -2.7812, -3.0625, -4.0312, -5.3750, 5.7188
    ])
    dt_bias = _bf16_npu(
        [-4.7812, -5.0938, -5.5000, 9.4375, 7.6250, -4.3750, -3.0938, 0.9688])

    # Path 1: single fused kernel (gating + recurrence in one launch).
    core_attn_out_non_spec_fused = fused_sigmoid_gating_delta_rule_update(
        A_log=a_log.contiguous(),
        dt_bias=dt_bias.contiguous(),
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        a=a.contiguous(),
        b=b.contiguous(),
        initial_state_source=ssm_state,
        initial_state_indices=non_spec_state_indices_tensor,
        cu_seqlens=non_spec_query_start_loc,
        use_qk_l2norm_in_kernel=True,
        softplus_beta=1.0,
        softplus_threshold=20.0,
    )

    # Path 2 (reference): compute the gates first, then run the recurrence.
    g_non_spec, beta_non_spec = fused_gdn_gating(a_log, a, b, dt_bias)
    core_attn_out_non_spec_split, last_recurrent_state = (
        fused_recurrent_gated_delta_rule(
            q=q,
            k=k,
            v=v,
            g=g_non_spec,
            beta=beta_non_spec,
            initial_state=ssm_state,
            inplace_final_state=True,
            cu_seqlens=non_spec_query_start_loc,
            ssm_state_indices=non_spec_state_indices_tensor,
            use_qk_l2norm_in_kernel=True,
        ))

    torch.testing.assert_close(core_attn_out_non_spec_fused,
                               core_attn_out_non_spec_split,
                               rtol=1e-02,
                               atol=1e-02,
                               equal_nan=True)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
|
|
||
| import os | ||
|
|
||
| import torch | ||
| from vllm.triton_utils import tl, tldevice, triton | ||
|
|
||
| if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': | ||
|
|
@@ -169,3 +170,228 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( | |
| p_ht = ht + (bos + i_t) * stride_final_state_token | ||
| p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] | ||
| tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) | ||
|
|
||
|
|
||
@triton.heuristics({
    "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,  # [HV] log of decay magnitude, one per value head
    a,  # [total_tokens, HV] per-token gating input
    dt_bias,  # [HV] bias added to `a` before softplus
    softplus_beta,  # scalar beta of the softplus
    softplus_threshold,  # linear-region cutoff for numerically stable softplus
    q,  # [total_tokens, H, K] queries
    k,  # [total_tokens, H, K] keys
    v,  # [total_tokens, HV, V] values
    b,  # [total_tokens, HV] input to sigmoid -> beta gate
    o,  # [NK, total_tokens, HV, V] output buffer
    h0_source,  # pool of recurrent states, indexed by h0_indices
    h0_indices,  # [N] state slot per sequence (may be negative = no state)
    cu_seqlens,  # [N+1] cumulative sequence lengths (varlen mode) or None
    scale,  # query scaling factor
    T,  # total tokens (varlen) or per-batch sequence length
    B: tl.constexpr,
    H: tl.constexpr,  # number of q/k heads
    HV: tl.constexpr,  # number of value heads (HV >= H, grouped)
    K: tl.constexpr,  # q/k head dimension
    V: tl.constexpr,  # value head dimension
    BK: tl.constexpr,  # block size over K (single block: NK == 1)
    BV: tl.constexpr,  # block size over V
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    One program instance handles one (K-block, V-block, sequence x value-head)
    triple and walks that sequence token by token, carrying the [BK, BV]
    recurrent state tile in registers.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to its (grouped) q/k head.
    i_h = i_hv // (HV // H)

    if IS_VARLEN:
        # Varlen: token range [bos, eos) of this sequence from cu_seqlens.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T  # total token count across all sequences
        T = eos - bos  # this sequence's length
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    # Per-block offsets within the K and V head dimensions.
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to this sequence's first token for each tensor.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Gating computation pointers
    p_A_log = A_log + i_hv
    p_a = a + bos * HV + i_hv
    p_dt_bias = dt_bias + i_hv

    # Masks for partial blocks at the K/V boundaries.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    # Recurrent state tile, accumulated in fp32.
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # Branchless handling of idx < 0 (no initial state): clamp the index
        # to 0 so the load is in-bounds, then zero the result via tl.where.
        # NOTE(review): a plain `if idx >= 0:` (as used for the final store
        # below) would skip the discarded load entirely.
        tmp0 = tl.where(idx < 0, 0, idx)
        p_h0 = (h0_source + tmp0 * HV * K * V + i_hv * K * V +
                o_k[:, None] * V + o_v[None, :])
        temp1 = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
        temp2 = tl.zeros_like(temp1)
        value0 = tl.where(idx < 0, temp2, temp1)
        b_h += value0  # tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    # Sequential scan over the tokens of this sequence.
    for i in range(0, T):
        # Load inputs (masked loads return 0 outside the head dim).
        b_q = tl.load(p_q + i * H * K, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k + i * H * K, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v + i * HV * V, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b + i * HV).to(tl.float32)

        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        b_a = tl.load(p_a + i * HV).to(tl.float32)
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability: fall back to the linear
        # identity softplus(x) ~= x when beta*x exceeds the threshold.
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x

        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))

        # Apply L2 normalization if enabled (epsilon guards divide-by-zero).
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)

        b_q = b_q * scale

        # Apply gating to hidden state: h *= exp(g)
        b_h *= tl.exp(b_g)

        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)

        # Apply beta gating: v *= beta
        b_v *= b_beta

        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]

        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o + i * HV * V, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # NOTE: per-token offsets are applied via `i * stride` in the
        # load/store addresses above instead of advancing the pointers:
        # # Update pointers for next timestep
        # p_q += H * K
        # p_k += H * K
        # p_o += HV * V
        # p_v += HV * V
        # p_b += HV
        # p_a += HV

    # Store final state back to h0_source with bounds checking
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (h0_source + idx * HV * K * V + i_hv * K * V +
                    o_k[:, None] * V + o_v[None, :])
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
|
|
||
|
|
||
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: float = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.Tensor = None,
):
    """
    Fused triton implementation of sigmoid gating delta rule update.
    This function uses a single fused kernel that combines both sigmoid gating computation
    and the recurrent delta rule update for better performance.

    Args:
        A_log: per-value-head log decay magnitude, shape [HV].
        a: per-token gating input, shape [..., HV].
        dt_bias: bias added to `a` before softplus, shape [HV].
        softplus_beta: beta parameter of the softplus in the gate.
        softplus_threshold: linear-region cutoff for stable softplus.
        q, k: query/key tensors, shape [B, T, H, K].
        v: value tensor, shape [B, T, HV, V].
        b: sigmoid-gate input, shape [..., HV].
        initial_state_source: pool of recurrent states; updated in place
            with the final state for each sequence.
        initial_state_indices: state slot per sequence (negative = none).
        scale: query scale; defaults to K**-0.5 when None.
        use_qk_l2norm_in_kernel: L2-normalize q/k inside the kernel.
        cu_seqlens: optional [N+1] cumulative sequence lengths (varlen mode).

    Returns:
        Output tensor with the same shape as `v`.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"

    o = q.new_empty(NK, *v.shape)
    grid = (NK, NV, N * HV)

    if not initial_state_indices.is_contiguous():
        initial_state_indices = initial_state_indices.contiguous()
    # Fix: the original only assigned `initial_state_source_contiguous`
    # inside the non-contiguous branch, raising NameError whenever the
    # source was already contiguous (the common case).
    if initial_state_source.is_contiguous():
        initial_state_source_contiguous = initial_state_source
    else:
        initial_state_source_contiguous = initial_state_source.contiguous()
    # Fix: cu_seqlens is optional; guard against None before the
    # contiguity check (the original would raise AttributeError).
    if cu_seqlens is not None and not cu_seqlens.is_contiguous():
        cu_seqlens = cu_seqlens.contiguous()

    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=o,
        h0_source=initial_state_source_contiguous,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    # The kernel wrote final states into the contiguous buffer; copy them
    # back only when that buffer is a temporary (i.e. a copy was made).
    if initial_state_source_contiguous is not initial_state_source:
        initial_state_source.copy_(
            initial_state_source_contiguous.view_as(initial_state_source))
    # Drop the NK dimension (asserted to be 1 above).
    o = o.squeeze(0)
    return o
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic for loading the initial state is a bit convoluted and can be simplified for better readability and efficiency. The current implementation uses
`tl.where` and performs a memory load that is subsequently discarded for negative indices. Using a simple `if idx >= 0:` check, as is done when storing the final state later in the kernel, would be cleaner and avoid the unnecessary load operation.