From 20733385d9e79a940e05cebebff6dfc98b72f4bd Mon Sep 17 00:00:00 2001 From: shiyuan680 <917935075@qq.com> Date: Sat, 8 Nov 2025 16:28:56 +0800 Subject: [PATCH] support triton chunk_gated_delta_rule ops Signed-off-by: shiyuan680 <917935075@qq.com> --- .github/workflows/_e2e_test.yaml | 2 +- .../multicard/test_chunk_gated_delta_rule.py | 33 ++ vllm_ascend/models/qwen3_next.py | 58 +-- vllm_ascend/ops/triton/fla/chunk.py | 226 ++++++++++ vllm_ascend/ops/triton/fla/chunk_delta_h.py | 259 +++++++++++ vllm_ascend/ops/triton/fla/chunk_o.py | 168 +++++++ .../ops/triton/fla/chunk_scaled_dot_kkt.py | 147 ++++++ vllm_ascend/ops/triton/fla/cumsum.py | 145 ++++++ .../triton/fla/{fla.py => layernorm_guard.py} | 98 ---- vllm_ascend/ops/triton/fla/solve_tril.py | 419 ++++++++++++++++++ vllm_ascend/ops/triton/fla/utils.py | 79 ++++ vllm_ascend/ops/triton/fla/wy_fast.py | 131 ++++++ vllm_ascend/patch/worker/patch_triton.py | 9 +- 13 files changed, 1625 insertions(+), 149 deletions(-) create mode 100644 tests/e2e/multicard/test_chunk_gated_delta_rule.py create mode 100644 vllm_ascend/ops/triton/fla/chunk.py create mode 100644 vllm_ascend/ops/triton/fla/chunk_delta_h.py create mode 100644 vllm_ascend/ops/triton/fla/chunk_o.py create mode 100644 vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py create mode 100644 vllm_ascend/ops/triton/fla/cumsum.py rename vllm_ascend/ops/triton/fla/{fla.py => layernorm_guard.py} (62%) create mode 100644 vllm_ascend/ops/triton/fla/solve_tril.py create mode 100644 vllm_ascend/ops/triton/fla/utils.py create mode 100644 vllm_ascend/ops/triton/fla/wy_fast.py diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index 628314a4c28..3420c1581bc 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -276,7 +276,7 @@ jobs: shell: bash -l {0} run: | . 
/usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh - python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" + python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" - name: Run vllm-project/vllm-ascend Qwen3 Next test working-directory: ./vllm-ascend diff --git a/tests/e2e/multicard/test_chunk_gated_delta_rule.py b/tests/e2e/multicard/test_chunk_gated_delta_rule.py new file mode 100644 index 00000000000..a0e4b6ef9df --- /dev/null +++ b/tests/e2e/multicard/test_chunk_gated_delta_rule.py @@ -0,0 +1,33 @@ +import torch + +from tests.ut.base import PytestBase +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule + + +class TestChunkGatedDeltaRule(PytestBase): + + def test_triton_fusion_ops(self, mock_moe_env): + q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu() + k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu() + v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu() + g = torch.randn(1, 17, 8, dtype=torch.float32).npu() + beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu() + initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu() + q_start_loc = torch.range(0, 3, dtype=torch.int).npu() + + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk_gated_delta_rule(q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=q_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True) + + assert core_attn_out_non_spec.shape == (1, 17, 8, 128) + assert last_recurrent_state.shape == (3, 8, 128, 128) diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index a14d0f0b152..b1d7b5444a9 100644 --- a/vllm_ascend/models/qwen3_next.py +++ 
b/vllm_ascend/models/qwen3_next.py @@ -423,50 +423,20 @@ def _forward( non_spec_state_indices_tensor].contiguous() initial_state[~has_initial_state, ...] = 0 - batch_size = initial_state.shape[0] - core_attn_out = [] - last_recurrent_state = [] - - for b_idx in range(batch_size): - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - cur_q = query_non_spec[:, start:end, ...] - cur_k = key_non_spec[:, start:end, ...] - cur_v = value_non_spec[:, start:end, ...] - cur_g = g_non_spec[:, start:end, ...] - cur_b = beta_non_spec[:, start:end, ...] - cur_state = initial_state[b_idx].unsqueeze(0) - - ( - cur_core_attn_out_non_spec, - cur_last_recurrent_state, - ) = chunk.chunk_gated_delta_rule( - query=cur_q, - key=cur_k, - value=cur_v, - g=cur_g, - beta=cur_b, - initial_state=cur_state, - output_final_state=True, - use_qk_l2norm_in_kernel=True, - ) - - core_attn_out.append(cur_core_attn_out_non_spec) - last_recurrent_state.append(cur_last_recurrent_state) - - tar_dtype = core_attn_out[0].dtype - tar_device = core_attn_out[0].device - tar_shape = list(core_attn_out[0].shape) - tar_shape[1] = non_spec_query_start_loc[-1] - core_attn_out_non_spec = torch.empty(tar_shape, - dtype=tar_dtype, - device=tar_device) - for b_idx in range(batch_size): - cur_core_attn_out = core_attn_out[b_idx] - start, end = non_spec_query_start_loc[ - b_idx], non_spec_query_start_loc[b_idx + 1] - core_attn_out_non_spec[:, start:end, ...] 
= cur_core_attn_out - last_recurrent_state = torch.cat(last_recurrent_state, dim=0) + ( + core_attn_out_non_spec, + last_recurrent_state, + ) = chunk.chunk_gated_delta_rule( + q=query_non_spec, + k=key_non_spec, + v=value_non_spec, + g=g_non_spec, + beta=beta_non_spec, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( diff --git a/vllm_ascend/ops/triton/fla/chunk.py b/vllm_ascend/ops/triton/fla/chunk.py new file mode 100644 index 00000000000..2d3dade7741 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk.py @@ -0,0 +1,226 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +import warnings +from typing import Optional + +import torch +from einops import rearrange +from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd +from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .solve_tril import solve_tril +from .utils import input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None): + g = chunk_local_cumsum(g, chunk_size=64, 
cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, + beta=beta, + g_cumsum=g, + cu_seqlens=cu_seqlens, + output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule(q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False): + r""" + Args: + q (torch.Tensor): + 
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[float]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len( + beta.shape + ) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2) + q, k, v, beta, g = map( + lambda x: rearrange(x, 'b h t ... 
-> b t h ...'), + (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if initial_state is not None and initial_state.shape[0] != len( + cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1]**-0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, + use_qk_l2norm_in_kernel) + if head_first: + o = rearrange(o, 'b t h ... -> b h t ...') + return o, final_state \ No newline at end of file diff --git a/vllm_ascend/ops/triton/fla/chunk_delta_h.py b/vllm_ascend/ops/triton/fla/chunk_delta_h.py new file mode 100644 index 00000000000..846623ad53f --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_delta_h.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp + +_CONDITIONS = ("seq7168", ) + + +@triton.heuristics({ + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, +}) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_nh = tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = 1 * T + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + stride_v = H * V + stride_k = Hg * K + stride_w = H * K + + b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32) + b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32) + + v_start1 = 0 + v_start2 = 64 + + offs_k = tl.arange(0, 128)[:, None] + offs_v1 = v_start1 + tl.arange(0, 64)[None, :] + offs_v2 = v_start2 + tl.arange(0, 64)[None, :] + mask_kv1 = (offs_k < K) & (offs_v1 < V) + mask_kv2 = (offs_k < K) & (offs_v2 < V) + + # load initial state + if USE_INITIAL_STATE: + h0_ptr = h0 + i_nh 
* K * V + ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1 + b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, + other=0.0).to(tl.float32) + + ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1 + b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, + other=0.0).to(tl.float32) + + # main recurrence + for i_t in range(NT): + h_base = h + (boh + i_t) * H * K * V + i_h * K * V + + p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), + (128, 64), (1, 0)) + tl.store(p_h1_bv1, + b_h1_bv1.to(p_h1_bv1.dtype.element_ty), + boundary_check=(0, 1)) + + p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), + (128, 64), (1, 0)) + tl.store(p_h1_bv2, + b_h1_bv2.to(p_h1_bv2.dtype.element_ty), + boundary_check=(0, 1)) + + offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None] + offs_k_wv = tl.arange(0, 128)[None, :] + mask_w = (offs_t_wv < T) & (offs_k_wv < K) + + w_base = w + bos * H * K + i_h * K + ptr_w = w_base + offs_t_wv * stride_w + offs_k_wv * 1 + b_w = tl.load(ptr_w, mask=mask_w, other=0.0) + + k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K + p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), + (128, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + v_new_base = v_new + bos * H * V + i_h * V + + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos + i_h * T_max + last_idx) + + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_g = safe_exp(b_g_last - b_g) + b_g_last = tl.exp(b_g_last) + + offs_t_v = (i_t * BT + tl.arange(0, BT))[:, None] + mask_v1 = (offs_t_v < T) & (offs_v1 < V) + + v_base = v + bos * H * V + i_h * V + ptr_v1 = v_base + offs_t_v * stride_v + offs_v1 * 1 + b_v1 = tl.load(ptr_v1, mask=mask_v1, other=0.0) + b_v_new1 = b_v1.to(tl.float32) + b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype)) + + if SAVE_NEW_VALUE: + p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), + (i_t * BT, v_start1), (BT, 
64), + (1, 0)) + tl.store(p_v_new1, + b_v_new1.to(p_v_new1.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + b_v_new1 = b_v_new1 * b_g[:, None] + b_h1_bv1 = b_h1_bv1 * b_g_last + + b_v_new1 = b_v_new1.to(k.dtype.element_ty) + b_h1_bv1 += tl.dot(b_k, b_v_new1) + + mask_v2 = (offs_t_v < T) & (offs_v2 < V) + ptr_v2 = v_base + offs_t_v * stride_v + offs_v2 * 1 + b_v2 = tl.load(ptr_v2, mask=mask_v2, other=0.0) + b_v_new2 = b_v2.to(tl.float32) + b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype)) + + if SAVE_NEW_VALUE: + p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), + (i_t * BT, v_start2), (BT, 64), + (1, 0)) + tl.store(p_v_new2, + b_v_new2.to(p_v_new2.dtype.element_ty), + boundary_check=(0, 1)) + + if USE_G: + b_v_new2 = b_v_new2 * b_g[:, None] + b_h1_bv2 = b_h1_bv2 * b_g_last + + b_v_new2 = b_v_new2.to(k.dtype.element_ty) + b_h1_bv2 += tl.dot(b_k, b_v_new2) + + # epilogue + if STORE_FINAL_STATE: + ht_ptr = ht + i_nh * K * V + + p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), + (128, 64), (1, 0)) + tl.store(p_ht1_bv1, + b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), + boundary_check=(0, 1)) + + p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), + (128, 64), (1, 0)) + tl.store(p_ht1_bv2, + b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), + boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. 
+ B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is not None else None) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = (k.new_empty(N, H, K, V, dtype=torch.float32) + if output_final_state else None) + + v_new = torch.empty_like(u) if save_new_value else None + g = g.transpose(1, 2).contiguous() + + def grid(meta): + return (1, N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + num_warps=4, + num_stages=2, + ) + return h, v_new, final_state diff --git a/vllm_ascend/ops/triton/fla/chunk_o.py b/vllm_ascend/ops/triton/fla/chunk_o.py new file mode 100644 index 00000000000..5a3578a8261 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_o.py @@ -0,0 +1,168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_offsets, safe_exp + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + T_max = T + + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int64) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + + for i_t in range(NT): + i_tg = boh + i_t + h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), + (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), + (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h_base, (K, V), (V, 1), + (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = 
tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + offs_t = i_t * BT + tl.arange(0, BT) + mask_t = offs_t < T + g_ptr = g + bos + i_h * T_max + b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0) + + b_o = b_o * tl.exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT).to(tl.float32) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), + (BT, BV), (1, 0)) + + b_v = tl.load(p_v, boundary_check=(0, 1)) + # to fix mma -> mma layout conversion + # already solved by fla v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = chunk_size + + if scale is None: + scale = k.shape[-1]**-0.5 + + o = torch.empty_like(v) + if cu_seqlens is None: + N, chunk_offsets = B, None + else: + N, chunk_offsets = ( + len(cu_seqlens) - 1, + prepare_chunk_offsets(cu_seqlens, BT), + ) + + def grid(meta): + return (triton.cdiv(V, meta['BV']), N * H) + + g = g.transpose(1, 2).contiguous() + chunk_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + o=o, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=128, + BV=128, + num_warps=4, + num_stages=2, + ) + return o diff --git a/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py 
b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py new file mode 100644 index 00000000000..aa183149a67 --- /dev/null +++ b/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py @@ -0,0 +1,147 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices, safe_exp + + +@triton.heuristics({ + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_G': lambda args: args['g_cumsum'] is not None, +}) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, # [H, B, T] + g_cumsum, # [H, B, T] + A, + cu_seqlens, + chunk_indices, + T, + B, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + bt_stride = B * T + i_t_i, _ = tl.program_id(0), tl.program_id(1) + + for i_bh in range(B * H): + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to( + tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + i_t = i_t_i + o_t = tl.arange(0, BT) + o_t_fp32 = o_t.to(tl.float32) + + p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ), + (i_t * BT, ), (BT, ), (0, )) + b_beta = tl.load(p_beta, boundary_check=(0, )) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = 
tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), + (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ), + (1, ), (i_t * BT, ), (BT, ), (0, )) + b_g = tl.load(p_g, boundary_check=(0, )) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A *= safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), + (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + gk (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. 
+ """ + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + if cu_seqlens is not None: + cu_seqlens = cu_seqlens.cpu() + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + chunk_indices = chunk_indices.npu() + cu_seqlens = cu_seqlens.npu() + else: + chunk_indices = None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + + chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)]( + k=k, + beta=torch.permute(beta, (2, 0, 1)).contiguous(), + g_cumsum=torch.permute(g_cumsum, (2, 0, 1)).contiguous(), + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=128, + num_warps=8, + num_stages=3, + multibuffer=True, + ) + return A diff --git a/vllm_ascend/ops/triton/fla/cumsum.py b/vllm_ascend/ops/triton/fla/cumsum.py new file mode 100644 index 00000000000..e93a2438ffa --- /dev/null +++ b/vllm_ascend/ops/triton/fla/cumsum.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BLOCK_T: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, + CHUNK_SIZE: tl.constexpr = 64, +): + i_block, i_b = tl.program_id(0), tl.program_id(1) + N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE + + if IS_VARLEN: + i_s, i_block = tl.load(chunk_indices + i_block * 2).to( + tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_s).to( + tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1), + (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) + ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1), + (0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0)) + b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32) + b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE)) + b_s = tl.trans(b_s, (2, 0, 1)) + b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE) + if HAS_SCALE: + b_o *= scale + b_o = tl.trans(b_o, (2, 0, 1)) + b_o = tl.reshape(b_o, (H, BLOCK_T)) + else: + ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1), + (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1), + (i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0)) + b_s = tl.load(ptr_s, 
boundary_check=(0, )).to(tl.float32) + b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H)) + b_s = tl.trans(b_s, (1, 0, 2)) + b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE) + if HAS_SCALE: + b_o *= scale + b_o = tl.trans(b_o, (1, 0, 2)) + b_o = tl.reshape(b_o, (BLOCK_T, H)) + + tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, )) + return + + +def chunk_local_cumsum_scalar( + g, + chunk_size, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.Tensor] = torch.float, +): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2**(chunk_size.bit_length() - + 1), "chunk_size must be a power of 2" + OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size)) + block_indices = prepare_chunk_indices( + cu_seqlens, + chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None + num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv( + T, OPTIM_BLOCK_SIZE) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (num_blocks, B) + chunk_local_cumsum_scalar_kernel[grid](s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=block_indices, + T=T, + B=B, + H=H, + BLOCK_T=OPTIM_BLOCK_SIZE, + CHUNK_SIZE=chunk_size, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=8, + num_stages=3) + return g + + +def chunk_local_cumsum(g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[ + 0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + 
output_dtype=output_dtype) + else: + raise ValueError(f"Unsupported input shape {g.shape}, " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise") diff --git a/vllm_ascend/ops/triton/fla/fla.py b/vllm_ascend/ops/triton/fla/layernorm_guard.py similarity index 62% rename from vllm_ascend/ops/triton/fla/fla.py rename to vllm_ascend/ops/triton/fla/layernorm_guard.py index 79039002d1f..c99f9e08d4b 100644 --- a/vllm_ascend/ops/triton/fla/fla.py +++ b/vllm_ascend/ops/triton/fla/layernorm_guard.py @@ -7,7 +7,6 @@ # mypy: ignore-errors import torch -import torch.nn.functional as F from vllm.triton_utils import tl, triton MAX_CORES = 65535 @@ -200,100 +199,3 @@ def forward( is_rms_norm=is_rms_norm, ) return y.reshape(x_shape_og) - - -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = F.normalize(query, p=2, dim=-1) - key = F.normalize(key, p=2, dim=-1) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - num_heads % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_heads = num_heads + pad_size - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = 
g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -( - (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - - last_recurrent_state = (torch.zeros(batch_size, sequence_length, - k_head_dim, v_head_dim).to(value) if - initial_state is None else initial_state.to(value)) - - core_attn_out = torch.zeros_like(value) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=1) - - # for each chunk - for i in range(0, tot_heads // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * - decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() + - (k_i * - (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2) @ v_new) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], - core_attn_out.shape[1], -1, - core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :num_heads] - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state diff --git 
a/vllm_ascend/ops/triton/fla/solve_tril.py b/vllm_ascend/ops/triton/fla/solve_tril.py new file mode 100644 index 00000000000..a80003207ca --- /dev/null +++ b/vllm_ascend/ops/triton/fla/solve_tril.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, + LARGE_BLOCK_T: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + base_t = i_t * LARGE_BLOCK_T + + NTASKS: tl.constexpr = 2 + N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS + + for taskid in range(0, NTASKS): + base_t += taskid * (LARGE_BLOCK_T // NTASKS) + + # use make_block_ptr to reduce vector computation + b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32) + for blkid in range(0, N_BLOCKS): + row_start_o = base_t + blkid * 16 + col_start_o = row_start_o % BT + + # 1 Create in-block offset + 
offs_rows_in_block = tl.arange(0, 16) + offs_cols_in_block = tl.arange(0, 16) + + # 2 Calculate the pointer of each element + ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o + + offs_rows_in_block[:, None] * H * BT + + offs_cols_in_block[None, :]) + + # 3 Create a mask to prevent out-of-bounds access + global_rows = row_start_o + offs_rows_in_block[:, None] + global_cols = col_start_o + offs_cols_in_block[None, :] + load_mask = (global_rows < T) & (global_cols < BT) + + # 4 Use mask to safely load data + b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, + other=0.0).to(tl.float32) + b_A = tl.insert_slice( + ful=b_A, + sub=b_A_subrec16[None, :, :], # (1, 16, 16) + offsets=[blkid, 0, 0], + sizes=[1, 16, 16], + strides=[1, 1, 1]) + + local_ori_A = tl.trans(b_A, (1, 0, 2)) + local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS)) + + # Convert mask into matrix multiplication to avoid for loops ub oom + tmp = tl.arange(0, 16).to(tl.float32) + rows = tmp[:, None] + cols = tmp[None, :] + is_lower = (rows > cols).to(b_A.dtype) + b_A = -b_A * is_lower + + # for loop to update N_BLOCKS row vector + for i in range(1, 16): + nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), + (1, 16 * N_BLOCKS), + (16 * N_BLOCKS, 1)) + b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16)) + + dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2)) + dot_product = tl.sum(dot_tmp, 0) + b_a = b_a + dot_product + + b_a_new_expanded = b_a[:, None, :] + b_A = tl.insert_slice(ful=b_A, + sub=b_a_new_expanded, + offsets=[0, i, 0], + sizes=[N_BLOCKS, 1, 16], + strides=[1, 1, 1]) + + on_diagonal = (rows == cols) + b_A = tl.where(on_diagonal, b_A + 1.0, b_A) + + b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), + (N_BLOCKS * 16, 16), (1, 0)) + + # 1 Create in-block offset + offs_rows_to_store = tl.arange(0, N_BLOCKS * 16) + offs_cols_to_store = tl.arange(0, 16) + + # 2 Calculate the pointer of each element + p_Ai = (Ad + base_t * H * 16 + 
0 + + offs_rows_to_store[:, None] * H * 16 + + offs_cols_to_store[None, :]) + # 3 Create a mask to prevent out-of-bounds access, only check rows + global_store_rows = base_t + offs_rows_to_store[:, None] + store_mask = global_store_rows < T + # 4 use mask to save data safely + tl.store(p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + mask=store_mask) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), + (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), + (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), + (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot( + tl.dot(Ai_22, A_21, input_precision="ieee"), + Ai_11, + input_precision="ieee", + ) 
+ tl.store( + p_Ai_11, + Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t_val = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + i_t = i_t_val + else: + bos, eos = i_b * T, i_b * T + T + + # Base pointers (already offset by batch and head) + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + # load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16) + offs_m = i_t * 64 + 16 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] + Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + + # load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16) + mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) + ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] + A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + tmp = tl.dot(Ai_22, A_21, input_precision="ieee") + + # load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16) + offs_m = i_t * 64 + tl.arange(0, 16) + offs_n = tl.arange(0, 16) + mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) + 
 ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] +    Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + +    Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee") + +    # load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16) +    offs_m = i_t * 64 + 48 + tl.arange(0, 16) +    offs_n = tl.arange(0, 16) +    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) +    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] +    Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + +    # load A_43 (A block at row i_t * 64 + 48, col 32, 16 * 16) +    offs_n = 32 + tl.arange(0, 16) +    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) +    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] +    A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) +    tmp = tl.dot(Ai_44, A_43, input_precision="ieee") + +    # load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16) +    offs_m = i_t * 64 + 32 + tl.arange(0, 16) +    offs_n = tl.arange(0, 16) +    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16) +    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :] +    Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32) + +    Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee") + +    # build Ai_22_32 (32 * 32) +    Ai_22_32 = tl.zeros((32, 32), tl.float32) +    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1)) +    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1)) +    Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1)) + +    # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32) +    offs_m = i_t * 64 + 32 + tl.arange(0, 32) +    offs_n = tl.arange(0, 32) +    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64) +    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :] +    A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) +    tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee") + +    # build Ai_11_32 (32 * 32) +    Ai_11_32 = tl.zeros((32, 32), tl.float32) +    Ai_11_32 = tl.insert_slice(Ai_11_32,
 Ai_11, (0, 0), (16, 16), (1, 1)) +    Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1)) +    Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1)) + +    Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee") + +    # store Ai_11_32 to (i_t * 64, 0) +    offs_m = i_t * 64 + tl.arange(0, 32) +    offs_n = tl.arange(0, 32) +    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) +    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] +    tl.store(ptr_Ai, +             Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), +             mask=mask_store) + +    # store Ai_22_32 to (i_t * 64 + 32, 32) +    offs_m = i_t * 64 + 32 + tl.arange(0, 32) +    offs_n = 32 + tl.arange(0, 32) +    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) +    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] +    tl.store(ptr_Ai, +             Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), +             mask=mask_store) + +    # store Ai_21_32 to (i_t * 64 + 32, 0) +    offs_n = tl.arange(0, 32) +    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64) +    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :] +    tl.store(ptr_Ai, +             Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), +             mask=mask_store) + +    # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63) +    offs_m = i_t * 64 + tl.arange(0, 32) +    offs_n = 32 + tl.arange(0, 32) +    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT) +    ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :] +    zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty) +    tl.store(ptr_Ai, zero_block, mask=mask_store) + + +def solve_tril( +    A: torch.Tensor, +    cu_seqlens: Optional[torch.Tensor] = None, +    output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: +    """ +    Compute the inverse of the matrix I + A +    A should be strictly lower triangular, i.e., A.triu() == 0. + +    Args: +        A (torch.Tensor): +            [B, T, H, BT], where BT should only be 16, 32, or 64.
+ cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, + T, + H, + 16, + device=A.device, + dtype=torch.float if BT != 16 else output_dtype) + + LARGE_BLOCK_T = 608 * 2 + + chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv( + T, LARGE_BLOCK_T) + + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + LARGE_BLOCK_T=LARGE_BLOCK_T, + num_warps=1, + num_stages=4, + ) + + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = (merge_16x16_to_32x32_inverse_kernel + if BT == 32 else merge_16x16_to_64x64_inverse_kernel) + chunk_indices = (prepare_chunk_indices(cu_seqlens, BT) + if cu_seqlens is not None else None) + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + num_warps=4, + num_stages=3, + ) + return Ai diff --git a/vllm_ascend/ops/triton/fla/utils.py b/vllm_ascend/ops/triton/fla/utils.py new file mode 100644 index 00000000000..4d2cd1350ff --- /dev/null +++ b/vllm_ascend/ops/triton/fla/utils.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +from typing import Callable + +import torch +from vllm.triton_utils import tl, triton + + +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + indices = torch.cat([ + torch.arange(n) + for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() + ]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], + 1).to(cu_seqlens) + + +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, + chunk_size: int) -> torch.LongTensor: + return torch.cat([ + cu_seqlens.new_tensor([0]), + triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + ]).cumsum(-1) + + +def input_guard( + fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else + i.contiguous() for i in args) + contiguous_kwargs = { + k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) + for k, v in kwargs.items() + } + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.npu.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@triton.jit +def safe_exp(x): + return tl.exp(tl.where(x <= 0, x, float("-inf"))) diff --git a/vllm_ascend/ops/triton/fla/wy_fast.py b/vllm_ascend/ops/triton/fla/wy_fast.py new file mode 100644 index 00000000000..1d4c295553f --- /dev/null +++ b/vllm_ascend/ops/triton/fla/wy_fast.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. 
+# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +# mypy: ignore-errors +from typing import Optional, Tuple + +import torch +from vllm.triton_utils import tl, triton + +from .utils import prepare_chunk_indices + + +@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None}) +@triton.jit(do_not_specialize=['T']) +def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices, + T, H: tl.constexpr, Hg: tl.constexpr, + K: tl.constexpr, V: tl.constexpr, + BT: tl.constexpr, BK: tl.constexpr, + BV: tl.constexpr, IS_VARLEN: tl.constexpr): + T_max = T + i_t_o = tl.program_id(0) + + for i_bh in range(H): + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to( + tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to( + tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + offs_t = tl.arange(0, BT) + global_offs_t = i_t * BT + offs_t + mask_t = global_offs_t < T + + offs_t_2d = global_offs_t[:, None] + offs_bt = tl.arange(0, BT)[None, :] + ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1) + mask_A = mask_t[:, None] + b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32) + + ptr_g = g + bos + i_h * T_max + global_offs_t + b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32) + + ptr_beta = beta + bos + i_h * T_max + global_offs_t + b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + offs_v = i_v * BV + tl.arange(0, BV)[None, :] + mask_v = (mask_t[:, None]) & (offs_v < V) + + ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) + + offs_v * 1) + b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32) + + b_vb = (b_v * b_beta[:, None]) + b_u = tl.dot(b_A, 
b_vb, allow_tf32=False) + + ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) + + offs_v * 1) + tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v) + + for i_k in range(tl.cdiv(K, BK)): + offs_k = i_k * BK + tl.arange(0, BK)[None, :] + mask_k = (mask_t[:, None]) & (offs_k < K) + ptr_k = (k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * + (Hg * K) + offs_k * 1) + b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32) + + b_kb = (b_k * b_beta[:, None] * b_g[:, None]) + b_w = tl.dot(b_A, b_kb) + + ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) + + offs_k * 1) + tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \ + if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BK = 64 + BV = 64 + + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + beta = beta.transpose(1, 2).contiguous() + g_cumsum = g_cumsum.transpose(1, 2).contiguous() + recompute_w_u_fwd_kernel[(NT, B)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=4, + num_stages=3, + ) + return w, u diff --git a/vllm_ascend/patch/worker/patch_triton.py b/vllm_ascend/patch/worker/patch_triton.py index eb3f300bfac..2f5af43be48 100644 --- a/vllm_ascend/patch/worker/patch_triton.py +++ b/vllm_ascend/patch/worker/patch_triton.py @@ -1,10 +1,7 @@ -import vllm.model_executor.layers.fla.ops.chunk -import vllm.model_executor.layers.fla.ops.fused_recurrent -import vllm.model_executor.layers.fla.ops.layernorm_guard import 
vllm.model_executor.layers.mamba.ops.causal_conv1d -from vllm_ascend.ops.triton.fla.fla import (LayerNormFn, - torch_chunk_gated_delta_rule) +from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule +from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn from vllm_ascend.ops.triton.fla.sigmoid_gating import \ fused_recurrent_gated_delta_rule_fwd_kernel from vllm_ascend.ops.triton.mamba.casual_conv1d import ( @@ -14,4 +11,4 @@ vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn -vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = chunk_gated_delta_rule