diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7f3f3a8cc..ca951c1c4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,12 +35,12 @@ repos: - id: clang-format types_or: [c++, c] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.14.10 # sync with requirements-lint.txt + rev: v0.14.11 # sync with requirements-lint.txt hooks: - id: ruff-check args: [--fix, --exit-non-zero-on-fix] - id: ruff-format - args: [--exit-non-zero-on-format] + args: [--exit-non-zero-on-format, --diff] - repo: https://github.com/codespell-project/codespell rev: v2.4.1 # sync with requirements-lint.txt hooks: diff --git a/examples/kda/FLA_KDA/cumsum.py b/examples/kda/FLA_KDA/cumsum.py new file mode 100644 index 000000000..0fb3368f6 --- /dev/null +++ b/examples/kda/FLA_KDA/cumsum.py @@ -0,0 +1,469 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, autotune_cache_kwargs, input_guard + +BS_LIST = [32, 64] + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_o = tl.cumsum(b_s, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [32, 64, 128, 256] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_ss = tl.sum(b_s, 0) + if REVERSE: + b_o = -b_o + b_ss + b_s + b_o += b_z + if i_c >= 0: + b_z += b_ss + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [16, 32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "S", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([BS], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0, reverse=True) + else: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_c *= scale + tl.store(p_o, b_c.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + b_z += tl.sum(b_s, 0) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_global_cumsum_scalar( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = s.shape + else: + B, T, H = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (N * H,) + chunk_global_cumsum_scalar_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum_vector( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = s.shape + else: + B, T, H, S = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BS = min(32, triton.next_power_of_2(S)) + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (triton.cdiv(S, BS), N * H) + chunk_global_cumsum_vector_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + S=S, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if cu_seqlens is not None: + assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(s.shape) == 3: + return chunk_global_cumsum_scalar( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + elif len(s.shape) == 4: + return chunk_global_cumsum_vector( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError( + f"Unsupported input shape {s.shape}, " + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise", + ) + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + else: + raise ValueError( + f"Unsupported input shape {g.shape}, which should be (B, T, H, D) if `head_first=False` or (B, H, T, D) otherwise", + ) diff --git a/examples/kda/FLA_KDA/fla_chunk_delta.py b/examples/kda/FLA_KDA/fla_chunk_delta.py new file mode 100644 index 000000000..3b0fc908d --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_delta.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl +from .fla_utils import prepare_chunk_indices, exp, exp2, USE_CUDA_GRAPH, autotune_cache_kwargs + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] + b_g_last = exp2(b_g_last) + else: + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 *= b_g_last + if K > 64: + b_h2 *= b_g_last + if K > 128: + b_h3 *= b_g_last + if K > 192: + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k1, mask=(o_k1 < K), other=0.0) + if USE_EXP2: + b_h1 *= exp2(b_gk_last1)[:, None] + else: + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k2, mask=(o_k2 < K), other=0.0) + if USE_EXP2: + b_h2 *= exp2(b_gk_last2)[:, None] + else: + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k3, mask=(o_k3 < K), other=0.0) + if USE_EXP2: + b_h3 *= exp2(b_gk_last3)[:, None] + else: + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k4, mask=(o_k4 < K), other=0.0) + if USE_EXP2: + b_h4 *= exp2(b_gk_last4)[:, None] + else: + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["dh0"] is not None, + "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in ([4, 3, 2]) + for BV in [64, 32] + ], + key=["H", "K", "V", "BT", "BV", "USE_G", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + w, + g, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + q += ((bos * H + i_h) * K).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + do += ((bos * H + i_h) * V).to(tl.int64) + dv += ((bos * H + i_h) * V).to(tl.int64) + dv2 += ((bos * H + i_h) * V).to(tl.int64) + dh += ((boh * H + i_h) * K * V).to(tl.int64) + if USE_GK: + gk += ((bos * H + i_h) * K).to(tl.int64) + + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + dh0 += i_nh * K * V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K * V + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + bg_last = tl.load(g + (bos + last_idx) * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + bg_last_exp = exp2(bg_last) + b_g_exp = exp2(b_g) + else: + bg_last_exp = exp(bg_last) + b_g_exp = exp(b_g) + + p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0) + b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + if USE_EXP2: + b_dv *= tl.where(m_t, exp2(bg_last - b_g), 0)[:, None] + else: + b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None] + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # Update dh + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + if USE_G: + b_dh1 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh1 *= exp2(b_gk_last1[:, None]) + else: + b_dh1 *= exp(b_gk_last1[:, None]) + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh2 *= exp2(b_gk_last2[:, None]) + else: + b_dh2 *= exp(b_gk_last2[:, None]) + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh3 *= exp2(b_gk_last3[:, None]) + else: + b_dh3 *= exp(b_gk_last3[:, None]) + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh4 *= exp2(b_gk_last4[:, None]) + else: + b_dh4 *= exp(b_gk_last4[:, None]) + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dv: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + h0: torch.Tensor = None, + dht: torch.Tensor = None, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( + q=q, + k=k, + w=w, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return dh, dh0, dv2 diff --git a/examples/kda/FLA_KDA/fla_chunk_inter.py b/examples/kda/FLA_KDA/fla_chunk_inter.py new file mode 100644 index 000000000..e6de9bb28 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_inter.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs, check_shared_mem + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + o_t = i_t * BT + tl.arange(0, BT) + m_k = o_k < K + m_t = o_t < T + m_last = o_t == min(T, i_t * BT + BT) - 1 + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dg += (bos * H + i_h) * K + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + p_dw = tl.make_block_ptr(dw, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + b_dgk *= exp2(b_gn) + b_dq *= scale + b_dq = b_dq * exp2(b_g) + b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dw = torch.empty_like(w) + dg = torch.empty_like(g) + + def grid(meta): + return (triton.cdiv(K, meta["BK"]), NT, B * H) + + chunk_kda_bwd_kernel_inter[grid]( + q=q, + k=k, + v=v, + g=g, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + dv=dv, + dw=dw, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dq, dk, dw, dg diff --git a/examples/kda/FLA_KDA/fla_chunk_intra.py b/examples/kda/FLA_KDA/fla_chunk_intra.py new file mode 100644 index 000000000..244f05f1c --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_intra.py @@ -0,0 +1,650 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import autotune_cache_kwargs, exp2, prepare_chunk_indices +from .cumsum import chunk_local_cumsum + +IS_TF32_SUPPORTED = False +if IS_TF32_SUPPORTED: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") +else: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("ieee") +SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") +# ============================================================================ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +# ============================================================================ + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK}, num_warps=num_warps) for BK in [32, 64] for num_warps in [1, 2, 4]], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akk_diag (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akk_diag += (bos * H + i_h) * BC + + m_tc1 = (i_tc1 + tl.arange(0, BC)) < T + m_tc2 = (i_tc2 + tl.arange(0, BC)) < T + m_tc3 = (i_tc3 + tl.arange(0, BC)) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # 1. off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr(k, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + p_g0 = tl.make_block_ptr(g, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + b_kt0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_gt0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + b_kt1, b_gt1 = b_kt0, b_gt0 + b_kt2, b_gt2 = b_kt0, b_gt0 + if i_tc1 < T: + p_q1 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_k1 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_g1 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + b_kt1 = tl.trans(b_k1) + b_gt1 = tl.trans(b_g1) + + b_gn1 = tl.load(g + i_tc1 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn1 = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + b_qg1 = b_q1 * b_gqn1 + b_kg1 = b_k1 * b_gqn1 + b_kgt = b_kt0 * exp2(b_gn1[:, None] - b_gt0) + b_Aqk10 += tl.dot(b_qg1, b_kgt) + b_Akk10 += tl.dot(b_kg1, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_g2 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + b_kt2 = tl.trans(b_k2) + b_gt2 = tl.trans(b_g2) + + b_gn2 = tl.load(g + i_tc2 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + b_kgt = b_kt0 * exp2(b_gn2[:, None] - b_gt0) + b_Aqk20 += tl.dot(b_qg2, b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn2[:, None] - b_gt1) + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_k3 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_g3 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + + b_gn3 = tl.load(g + i_tc3 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + b_kgt = b_kt0 * exp2(b_gn3[:, None] - b_gt0) + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn3[:, None] - b_gt1) + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt2 * exp2(b_gn3[:, None] - b_gt2) + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # 2. save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b1 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Aqk21 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + tl.store(p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Aqk31 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Aqk32 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + tl.store(p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b3 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + ################################################################################ + # 3. load diagonal Akk blocks + ################################################################################ + p_Akk00 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) + # each diagonal block is stored contiguously: row i of block s is at Akk_diag[t=i_t*BT+s*BC+i, :BC] + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # 4. forward substitution on diagonals + ################################################################################ + o_i = tl.arange(0, BC) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + # Forward substitution: load from Akk_diag (stride H*BC, columns 0:BC) + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2 * BC, T - i_tc0)): + b_a11 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.0) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2 * BC + 2, min(3 * BC, T - i_tc0)): + b_a22 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a22 = tl.where(o_i < i - 2 * BC, b_a22, 0.0) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2 * BC)[:, None], b_a22, b_Ai22) + for i in range(3 * BC + 2, min(4 * BC, T - i_tc0)): + b_a33 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a33 = tl.where(o_i < i - 3 * BC, b_a33, 0.0) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3 * BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + # ################################################################################ + # # 5. compute merged inverse using off-diagonals + # ################################################################################ + + # we used tf32x3 to maintain matrix inverse's precision whenever possible. + b_Ai10 = -tl.dot(tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai21 = -tl.dot(tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai32 = -tl.dot(tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai22, input_precision=SOLVE_TRIL_DOT_PRECISION) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ################################################################################ + # 6. store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0)) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Akk31 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Akk32 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0)) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]], + key=["BK", "NC", "BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["B", "T"]) +def chunk_kda_bwd_kernel_intra( + q, + k, + g, + beta, + dAqk, + dAkk, + dq, + dq2, + dk, + dk2, + dg, + dg2, + db, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_k, i_i = i_kc // NC, i_kc % NC + + all = B * T + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dAqk += (bos * H + i_h) * BT + dAkk += (bos * H + i_h) * BT + dq += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + dg2 += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) + b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = g + i_ti * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp2(b_gn[None, :] - b_gk) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + # [BC, BK] + b_dq2 += tl.dot(b_dAqk, b_kg) + b_dk2 += tl.dot(b_dAkk, b_kg) + b_gqn = exp2(b_g - b_gn[None, :]) + b_dq2 *= b_gqn + b_dk2 *= b_gqn + + o_i = tl.arange(0, BC) + m_dA = (i_ti + o_i) < T + o_dA = (i_ti + o_i) * H * BT + i_i * BC + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC] + b_dAqk = tl.load(dAqk + o_dA + j, mask=m_dA, other=0) + b_dAkk = tl.load(dAkk + o_dA + j, mask=m_dA, other=0) + # [BK] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_kgj = b_kj[None, :] * exp2(b_g - b_gkj[None, :]) + b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kgj, 0.0) + b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kgj, 0.0) + + p_kj += H * K + p_gkj += H * K + b_db = tl.sum(b_dk2 * b_k, 1) + b_dk2 *= b_b[:, None] + + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_dg2 = b_q * b_dq2 + b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + tl.debug_barrier() + b_dkt = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (min(i_ti + BC, T) - 1) * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kb = tl.load(p_k, boundary_check=(0, 1)) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + + o_j = i_t * BT + i_j * BC + o_i + m_j = o_j < T + # [BC, BK] + b_gkn = tl.where(m_j[:, None], exp2(b_gk - b_gn[None, :]), 0) + b_qg = b_q * b_gkn + b_kbg = b_kb * b_gkn + # [BC, BK] + b_dkt += tl.dot(b_dAqk, b_qg) + tl.dot(b_dAkk, b_kbg) + b_dkt *= exp2(b_gn[None, :] - b_g) + + o_dA = i_ti * H * BT + i_i * BC + o_i + p_qj = q + i_ti * H * K + o_k # [bs, i_ti, i_h*block_h, i_k*bk:(i_k+1)*bk] + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + p_bj = beta + i_ti * H + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dAqk = tl.load(dAqk + o_dA + j * H * BT) + b_dAkk = tl.load(dAkk + o_dA + j * H * BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_gkq = exp2(b_gkj[None, :] - b_g) + b_dkt += tl.where(m_i, (b_dAkk[:, None] * b_kbj[None, :] + b_dAqk[:, None] * b_qj[None, :]) * b_gkq, 0.0) + + p_qj += H * K + p_kj += H * K + p_gkj += H * K + p_bj += H + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + + b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) + b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) + b_dk2 += b_dkt + + tl.store(p_dk2, b_dk2.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg2.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_intra( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + dAqk: torch.Tensor, + dAkk: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + db: torch.Tensor, + dg: torch.Tensor, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + chunk_size: int = 64, +): + B, T, H, K = k.shape + BT = chunk_size + BC = min(16, BT) + BK = min(32, triton.next_power_of_2(K)) + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq2 = torch.empty_like(q) + dk2 = torch.empty_like(k) + db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) + dg2 = torch.empty_like(dg, dtype=torch.float) + grid = (NK * NC, NT, B * H) + chunk_kda_bwd_kernel_intra[grid]( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dq2=dq2, + dk=dk, + dk2=dk2, + dg=dg, + dg2=dg2, + db=db2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + dq = dq2 + dk = dk2 + db = db2.sum(0).add_(db) + dg = chunk_local_cumsum( + dg2, + chunk_size=chunk_size, + reverse=True, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + + return dq, dk, db, dg + + +def chunk_kda_fwd_inter_solve_fused( + q, + k, + gk, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K = k.shape + assert K <= 256 + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + BC = 16 + + grid = (NT, B * H) + chunk_kda_fwd_kernel_inter_solve_fused[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk_diag=Akk_diag, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) diff --git a/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py new file mode 100644 index 000000000..1dba20282 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py @@ -0,0 +1,168 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# Token-parallel implementation of KDA intra chunk kernel + +import torch +import triton +import triton.language as tl + +from .fla_utils import exp2, autotune_cache_kwargs + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BH": BH}, num_warps=num_warps) for BH in [1, 2, 4, 8] for num_warps in [1, 2, 4, 8]], + key=["K", "H"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T", "N"]) +def chunk_kda_fwd_kernel_intra_token_parallel( + q, + k, + g, + beta, + Aqk, + Akk, + scale, + cu_seqlens, + N, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BH: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_tg, i_hg = tl.program_id(0), tl.program_id(1) + + if IS_VARLEN: + i_n = 0 + left, right = 0, N + + # Unrolled binary search (max B=2^32) + # We can limit iterations based on expected max batch size if needed + # 20 iterations covers B=1M, usually enough + for _ in range(20): + if left < right: + mid = (left + right) // 2 + if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32): + right = mid + else: + left = mid + 1 + i_n = left + + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + i_t = i_tg - bos + else: + bos = (i_tg // T) * T + i_t = i_tg % T + + if i_t >= T: + return + + i_c = i_t // BT # chunk indices + i_s = (i_t % BT) // BC # sub_chunk indices + i_tc = i_c * BT # chunk 首坐标 + i_ts = i_tc + i_s * BC # subchunk 首坐标 + + q += bos * H * K + k += bos * H * K + g += bos * H * K + Aqk += bos * H * BT + Akk += bos * H * BC + beta += bos * H + + BK: tl.constexpr = triton.next_power_of_2(K) + o_h = tl.arange(0, BH) + o_k = tl.arange(0, BK) + m_h = (i_hg * BH + o_h) < H + m_k = o_k < K + + p_q = tl.make_block_ptr(q + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,)) + # [BH, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None] + + for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))): + p_kj = tl.make_block_ptr(k + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + p_gj = tl.make_block_ptr(g + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0)) + # [BH, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32) + b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store(Aqk + i_t * H * BT + (i_hg * BH + o_h) * BT + j % BT, b_Aqk.to(Aqk.dtype.element_ty), mask=m_h) + tl.store(Akk + i_t * H * BC + (i_hg * BH + o_h) * BC + j - i_ts, b_Akk.to(Akk.dtype.element_ty), mask=m_h) + + +def chunk_kda_fwd_intra_token_parallel( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor, + beta: torch.Tensor, + Aqk: torch.Tensor, + Akk: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): + return (B * T, triton.cdiv(H, meta["BH"])) + + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) + return Aqk, Akk diff --git a/examples/kda/FLA_KDA/fla_chunk_o.py b/examples/kda/FLA_KDA/fla_chunk_o.py new file mode 100644 index 000000000..c29db9508 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_o.py @@ -0,0 +1,546 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + + +from .fla_utils import prepare_chunk_indices, exp, exp2, autotune_cache_kwargs, check_shared_mem + + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + if USE_EXP2: + b_qg = (b_q * exp2(b_g)).to(b_q.dtype) + else: + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype) + b_o += tl.dot(b_A, b_v) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_dv( + k, + g, + A, + do, + dh, + dv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.0) + # (SY 09/17) important to disallow tf32 here to maintain a good precision. + b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1) * H * K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BV] + # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps) for BK in BK_LIST for BV in BV_LIST for num_warps in [2, 4, 8]], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + b_dgk *= exp(b_gn) + b_dq *= scale + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gk) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + # tl.debug_barrier() + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] + # Buggy due to strange triton compiler issue. + # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) + # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] + p_dq = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_gla_fwd_kernel_o[grid]( + q=q, + v=v, + g=g, + h=h, + o=o, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return o + + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "USE_A": lambda args: args["A"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in NUM_WARPS for num_stages in [2, 3, 4]], + key=["H", "K", "V", "BT", "BK", "BV", "USE_G"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + g_gamma, + A, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_A: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + do += (bos * H + i_h) * V + dv += (bos * H + i_h) * V + + if USE_A: + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + do: torch.Tensor, + g: torch.Tensor = None, + g_gamma: torch.Tensor = None, + A: torch.Tensor = None, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + # H100 can have larger block size + if check_shared_mem("hopper", k.device.index): + CONST_TILING = 128 + elif check_shared_mem: + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_bwd_kernel_dv_local[grid]( + q=q, + k=k, + g=g, + g_gamma=g_gamma, + A=A, + do=do, + dv=dv, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return dv + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]], + key=["BV", "BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_dA( + v, + do, + dA, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H * V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_dA += tl.dot(b_do, b_v) + + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA = tl.where(m_s, b_dA * scale, 0.0) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_bwd_dA( + v: torch.Tensor, + do: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, V = v.shape + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BV = min(64, triton.next_power_of_2(V)) + + dA = v.new_empty(B, T, H, BT, dtype=torch.float32) + grid = (NT, B * H) + chunk_gla_bwd_kernel_dA[grid]( + v=v, + do=do, + dA=dA, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + ) + return dA diff --git a/examples/kda/FLA_KDA/fla_utils.py b/examples/kda/FLA_KDA/fla_utils.py new file mode 100644 index 000000000..b278aec90 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_utils.py @@ -0,0 +1,240 @@ +import contextlib +import functools +import inspect +import os +import warnings +from collections.abc import Callable +from typing import Any +from packaging import version +from enum import Enum + +import torch +import triton +import triton.language.extra.libdevice as tldevice + + +device = "cuda" +device_torch_lib = getattr(torch, device) + +exp = tldevice.fast_expf +exp2 = tldevice.exp2 +log = tldevice.fast_logf +log2 = tldevice.fast_log2f + +IS_NVIDIA_HOPPER = True and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9) +USE_CUDA_GRAPH = True and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1" +SUPPORTS_AUTOTUNE_CACHE = "cache_results" in inspect.signature(triton.autotune).parameters +autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if SUPPORTS_AUTOTUNE_CACHE else {} + + +# error check,copy from +def get_abs_err(x, y): + return (x.detach() - y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item() + base = (x.detach()).flatten().square().mean().sqrt().item() + return err / (base + 1e-8) + + +def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6): + abs_atol = get_abs_err(ref, tri) + msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}" + print(msg) + error_rate = get_err_ratio(ref, tri) + if abs_atol <= err_atol: + return + if warning or (error_rate < 0.01 or abs_atol <= 0.3): + if error_rate > ratio: + warnings.warn(msg, stacklevel=2) + else: + assert error_rate < ratio, msg + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + last_args is not None + and last_kwargs is not None + and len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, + chunk_size: int, +) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +# @functools.cache +# def get_multiprocessor_count(tensor_idx: int = 0) -> int: +# try: +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] +# except BaseException: +# # Maybe we use a NPU device. +# if triton.runtime.driver.active.get_current_target().backend == 'npu': +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['num_vectorcore'] +# else: +# return 1 +@functools.cache +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + """ + Compatible across Triton versions: + - 2.0.x + - 2.1.0 + - 2.2.x and above + Supports CUDA and NPU. + """ + + # ---- Try the newer Triton 2.2+ API ---- + try: + drv = triton.runtime.driver.active + props = drv.utils.get_device_properties(tensor_idx) + return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1 + except Exception: + pass + + # ---- Fallback: Triton 2.0 / 2.1 API ---- + try: + cuda = triton.runtime.driver.CudaDriver + dev = cuda.get_current_device() + props = cuda.get_device_properties(dev) + return props.get("multiprocessor_count", 1) + except Exception: + pass + + return 1 + + +def input_guard( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def check_pytorch_version(version_s: str = "2.4") -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +if check_pytorch_version("2.4"): + device = "cuda" + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0." + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/examples/kda/FLA_KDA/fla_wy_fast.py b/examples/kda/FLA_KDA/fla_wy_fast.py new file mode 100644 index 000000000..a042c2a5f --- /dev/null +++ b/examples/kda/FLA_KDA/fla_wy_fast.py @@ -0,0 +1,312 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs + + +@triton.heuristics( + { + "STORE_QG": lambda args: args["qg"] is not None, + "STORE_KG": lambda args: args["kg"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"DOT_PRECISION": DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + for DOT_PRECISION in (["tf32x3", "ieee"]) + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_b = tl.load(p_b, boundary_check=(0,)) + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_b[:, None] # 乘beta + + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kb *= exp2(b_gk) + if STORE_QG: + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_qg = b_q * exp2(b_gk) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1)) + if STORE_KG: + last_idx = min(i_t * BT + BT, T) - 1 + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + b_gn = tl.load(gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0) # chunk的最后一个g + b_kg = b_k * tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], exp2(b_gn[None, :] - b_gk), 0) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1)) + + b_w = tl.dot(b_A, b_kb.to(b_k.dtype)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4] for num_stages in [2, 3, 4]], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def prepare_wy_repr_bwd_kernel( + k, + v, + beta, + gk, + A, + dA, + dw, + du, + dk, + dk2, + dv, + db, + dg, + dg2, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk_exp = exp2(tl.load(p_gk, boundary_check=(0, 1))) + b_kbg = b_k * b_b[:, None] * b_gk_exp + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) + b_dkbg = tl.dot(b_A, b_dw) + b_dk = b_dkbg * b_gk_exp * b_b[:, None] + tl.load(p_dk, boundary_check=(0, 1)) + b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) + b_dg = b_kbg * b_dkbg + tl.load(p_dg, boundary_check=(0, 1)) + + tl.store(p_dk2, b_dk.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_vb)) + b_dvb = tl.dot(b_A, b_du) + b_dv = b_dvb * b_b[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + b_dA = tl.where(m_A, -b_dA, 0) + + # if using gk, save dA first and handle dk in another kernel + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + q: torch.Tensor = None, + gk: torch.Tensor = None, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = A.shape[-1] + BK = 64 + BV = 64 + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + w = torch.empty_like(k) + u = torch.empty_like(v) + qg = torch.empty_like(q) if q is not None else None + kg = torch.empty_like(k) if gk is not None else None + recompute_w_u_fwd_kernel[(NT, B * H)]( + q=q, + k=k, + qg=qg, + kg=kg, + v=v, + beta=beta, + w=w, + u=u, + A=A, + gk=gk, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u, qg, kg + + +def prepare_wy_repr_bwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + gk: torch.Tensor, + A: torch.Tensor, + dk: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dg: torch.Tensor, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = 64 + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + CONST_TILING = 64 + BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) + BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) + + dk2 = torch.empty_like(dk, dtype=torch.float) + dv = torch.empty_like(v) + dg2 = torch.empty_like(gk, dtype=torch.float) + dA = torch.empty_like(A, dtype=torch.float) + db = torch.empty_like(beta, dtype=torch.float) + prepare_wy_repr_bwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + gk=gk, + A=A, + dA=dA, + dw=dw, + du=du, + dk=dk, + dk2=dk2, + dv=dv, + db=db, + dg=dg, + dg2=dg2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + dk = dk2 + dg = dg2 + + return dk, dv, db, dg, dA diff --git a/examples/kda/README.md b/examples/kda/README.md new file mode 100644 index 000000000..f445a9f09 --- /dev/null +++ b/examples/kda/README.md @@ -0,0 +1,7 @@ +# KDA kernel implementation with TileLang +## Requirement +- TileLang: 0.1.6.post2+cuda.git729e66ca +- triton: 3.2.0 +- FLA: commit 9714c5(used for comparison) + +We copy the needed files and function from flash-linear-attention to the FLA_KDA/ for easily comparison. diff --git a/examples/kda/chunk_bwd_dqkwg.py b/examples/kda/chunk_bwd_dqkwg.py new file mode 100644 index 000000000..d3d4df4b4 --- /dev/null +++ b/examples/kda/chunk_bwd_dqkwg.py @@ -0,0 +1,274 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_inter import chunk_kda_bwd_dqkwg +from test_utils_kda import do_bench, compare_tensors + +import torch + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + gate_dtype, +): + BS = S // chunk_size + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + v_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + w = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + do = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + + return q, k, v_new, w, g, h, dv, do, dh + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + gate_dtype, +): + dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, dw, dg + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-4, -3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def chunk_bwd_dqkwg( + B, + S, + H, + DK, + DV, + scale, + chunk_size, + input_dtype, + gate_dtype, + block_DK=32, + block_DV=32, + threads=32, + num_stages=0, +): + block_S = chunk_size + BS = S // block_S + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + H_shape = (B, BS, H, DK, DV) + + @T.prim_func + def kernel( + Q: T.Tensor(K_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + G: T.Tensor(K_shape, dtype=gate_dtype), + h: T.Tensor(H_shape, dtype=input_dtype), + dv: T.Tensor(V_shape, dtype=input_dtype), + DO: T.Tensor(V_shape, dtype=input_dtype), + Dh: T.Tensor(H_shape, dtype=input_dtype), + dq: T.Tensor(K_shape, dtype=T.float32), + dk: T.Tensor(K_shape, dtype=T.float32), + dw: T.Tensor(K_shape, dtype=gate_dtype), + dg: T.Tensor(K_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): + bb, bh = bbh // H, bbh % H + chunk_last_idx = T.min(S, (bs + 1) * block_S) - 1 + + dgkn_fragment = T.alloc_fragment((block_DK), dtype=T.float32) + dgkn_fragment_tmp = T.alloc_fragment((block_DK,), dtype=T.float32) + dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dgk_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + + h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dgkn_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) # d of last token in a chunk + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DV_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + + dkkn_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + pp_shared = T.alloc_shared((block_DK), dtype=T.float32) + + T.clear(dgkn_fragment) + T.clear(dq_fragment) + T.clear(dk_fragment) + T.clear(dw_fragment) + + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], G_shared) + T.copy(G[bb, chunk_last_idx, bh, bk * block_DK : (bk + 1) * block_DK], Gn_shared) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(Dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DV_shared) + # += reduce_sum + for i_k1, i_v1 in T.Parallel(block_DK, block_DV): + dgkn_shared[i_k1, i_v1] = h_shared[i_k1, i_v1] * dh_shared[i_k1, i_v1] + T.reduce_sum(dgkn_shared, dgkn_fragment_tmp, dim=1, clear=True) # [block_DK] + for i_ks in T.Parallel(block_DK): + dgkn_fragment[i_ks] += dgkn_fragment_tmp[i_ks] + T.gemm(DO_shared, h_shared, dq_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(DV_shared, h_shared, dw_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + # chunk last token + for i_k0 in T.Parallel(block_DK): + dgkn_fragment[i_k0] = dgkn_fragment[i_k0] * T.exp2(Gn_shared[i_k0]) + + for i_s, i_k in T.Parallel(block_S, block_DK): + dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale * T.exp2(G_shared[i_s, i_k]) + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp2(Gn_shared[i_k] - G_shared[i_s, i_k]) + + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], K_shared) + + for i_s2, i_k2 in T.Parallel(block_S, block_DK): + dkkn_shared[i_s2, i_k2] = dk_fragment[i_s2, i_k2] * K_shared[i_s2, i_k2] + T.reduce_sum(dkkn_shared, pp_shared, dim=0, clear=True) + for i_k3 in T.Parallel(block_DK): + pp_shared[i_k3] += dgkn_fragment[i_k3] + + for i_s4, i_k4 in T.Parallel(block_S, block_DK): + dgk_shared[i_s4, i_k4] = ( + Q_shared[i_s4, i_k4] * dq_fragment[i_s4, i_k4] + - K_shared[i_s4, i_k4] * dk_fragment[i_s4, i_k4] + + T.if_then_else(chunk_last_idx == bs * block_S + i_s4, pp_shared[i_k4], 0.0) + ) + + T.copy(dgk_shared, dg[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + gate_dtype, + qk_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + q, k, v_new, w, g, h, dv, do, dh = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, gate_dtype)) + + dq_ref, dk_ref, dw_ref, dg_ref = chunk_kda_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + + dq, dk, dw, dg = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype)) + kernel = chunk_bwd_dqkwg( + B=B, S=S, H=H, DK=DK, DV=DV, scale=scale, chunk_size=chunk_size, input_dtype=input_dtype, gate_dtype=gate_dtype + ) + dq, dk, dw, dg = kernel(q, k, v_new, g, h, dv, do, dh) + + compare_tensors("dq", dq_ref, dq) + compare_tensors("dk", dk_ref, dk) + compare_tensors("dw", dw_ref, dw) + compare_tensors("dg", dg_ref, dg) + + fla_time = do_bench( + chunk_kda_bwd_dqkwg, + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + tilelang_time = do_bench(kernel, q, k, v_new, g, h, dv, do, dh) + print("fla_time:", fla_time) + print("tilelang_time:", tilelang_time) + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="float32", + gate_dtype="float32", # gate must be float32 + qk_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_dv.py b/examples/kda/chunk_bwd_dv.py new file mode 100644 index 000000000..cdbe0a899 --- /dev/null +++ b/examples/kda/chunk_bwd_dv.py @@ -0,0 +1,150 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +import sys # noqa: F401 + +from FLA_KDA.fla_chunk_o import chunk_bwd_dv_local +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + do_dtype, +): + q = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + A = torch.randn(B, S, H, chunk_size, dtype=input_dtype).cuda() + return q, k, DO, A + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + output_dtype, +): + dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dv + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_bwd_kernel_dv_local( + B, + S, + H, + DV, + input_dtype, + output_dtype, + do_dtype, + chunk_size, + block_DV=128, + threads=128, + num_stages=1, +): + block_S = BS = chunk_size + DO_shape = (B, S, H, DV) + A_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + DO: T.Tensor(DO_shape, dtype=do_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dv: T.Tensor(DO_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((BS, BS), dtype=do_dtype) + DO_shared = T.alloc_shared((BS, block_DV), dtype=do_dtype) + dv_fragment = T.alloc_fragment((BS, block_DV), dtype=T.float32) + dv_shared = T.alloc_shared((BS, block_DV), dtype=output_dtype) + + T.copy(A[bb, bs * BS : (bs + 1) * BS, bh, :], A_shared) + for i_s1, i_s2 in T.Parallel(BS, BS): + A_shared[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, A_shared[i_s1, i_s2], 0.0) + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(DO[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.gemm(A_shared, DO_shared, dv_fragment, transpose_A=True, clear_accum=True) # transpose_A: A^T + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + do_dtype, + output_dtype, + chunk_size, +): + q, k, DO, A = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) + dv_ref = chunk_bwd_dv_local(q, k, do=DO, A=A) + + dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype)) + kernel = tilelang_chunk_bwd_kernel_dv_local( + B=B, + S=S, + H=H, + DV=DV, + input_dtype=input_dtype, + output_dtype=output_dtype, + do_dtype=do_dtype, + chunk_size=chunk_size, + ) + dv_tilelang = kernel(DO, A) + compare_tensors("dv", dv_ref, dv_tilelang) + + fla_time = do_bench(chunk_bwd_dv_local, q, k, do=DO, A=A) + tilelang_time = do_bench(kernel, DO, A) + print("fla_time: ", fla_time) + print("tilelang_time: ", tilelang_time) + + +def main(): + run_test( + B=1, + S=1024 * 8, # 32768 + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="bfloat16", + do_dtype="float32", + output_dtype="bfloat16", + chunk_size=64, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_gla_dA.py b/examples/kda/chunk_bwd_gla_dA.py new file mode 100644 index 000000000..913fa9171 --- /dev/null +++ b/examples/kda/chunk_bwd_gla_dA.py @@ -0,0 +1,147 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_o import chunk_gla_bwd_dA +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DV, + chunk_size, + input_dtype, + do_dtype, +): + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + V_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return DO, V_new + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + d_type, +): + dA = torch.empty(B, S, H, chunk_size, dtype=d_type).cuda() + return dA + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_bwd_kernel_dv_local( + B, + S, + H, + DV, + scale, + input_dtype, + da_dtype, + do_dtype, + chunk_size, + block_DV=128, + threads=128, + num_stages=1, +): + block_S = BS = chunk_size + DO_shape = (B, S, H, DV) + V_shape = (B, S, H, DV) + dA_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + DO: T.Tensor(DO_shape, dtype=do_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + dA: T.Tensor(dA_shape, dtype=da_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) + V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype) + dA_fragment = T.alloc_fragment((block_S, block_S), dtype=T.float32) + + T.clear(dA_fragment) + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], do_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.gemm(do_shared, V_shared, dA_fragment, transpose_B=True) + for i_s1, i_s2 in T.Parallel(block_S, block_S): + dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0) # 下三角矩阵 + T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, 0:block_S]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + do_dtype, + da_dtype, + chunk_size, +): + DO, V_new = prepare_input(B, S, H, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) + print(DO.dtype, V_new.dtype) + dA_ref = chunk_gla_bwd_dA(v=V_new, do=DO, scale=scale) + + dA_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, da_dtype)) + kernel = tilelang_chunk_bwd_kernel_dv_local( + B=B, + S=S, + H=H, + DV=DV, + scale=scale, + input_dtype=input_dtype, + da_dtype=da_dtype, + do_dtype=do_dtype, + chunk_size=chunk_size, + ) + dA_tilelang = kernel(DO, V_new) + compare_tensors("dA", dA_ref, dA_tilelang) + fla_time = do_bench(chunk_gla_bwd_dA, v=V_new, do=DO, scale=scale) + tilelang_time = do_bench(kernel, DO, V_new) + print("fla_time:", fla_time) + print("tilelang_time:", tilelang_time) + + +def main(): + run_test( + B=1, + S=1024 * 8, # 32768 + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="bfloat16", + do_dtype="bfloat16", + da_dtype="float32", + chunk_size=64, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_intra.py b/examples/kda/chunk_bwd_intra.py new file mode 100644 index 000000000..6c66732b4 --- /dev/null +++ b/examples/kda/chunk_bwd_intra.py @@ -0,0 +1,493 @@ +# Reference: FLA_KDA/fla_chunk_intra.py +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_intra import chunk_kda_bwd_intra +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(0) +torch.set_printoptions(profile="full") + + +def prepare_input( + B, + S, + H, + DK, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + BT = chunk_size + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + + # dAqk and dAkk are gradients w.r.t. Aqk and Akk + # Shape: (B, S, H, BT) + dAqk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + dAkk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + + # Initial gradients (will be updated by the kernel) + dq = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + dk = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + db = torch.randn(B, S, H, dtype=input_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + + return q, k, g, beta, dAqk, dAkk, dq, dk, db, dg + + +def prepare_output( + B, + S, + H, + DK, + chunk_size, + NK, + output_dtype, + gate_dtype, + state_dtype, +): + dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + db = torch.empty(NK, B, S, H, dtype=output_dtype).cuda() + dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, db, dg + + +def get_configs(): + import itertools + + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(threads, num_stages)) + + configs = [{"threads": c[0], "num_stages": c[1]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=5, rep=5) +@tilelang.jit( + out_idx=[-4, -3, -2, -1], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True}, +) +def tilelang_chunk_bwd_intra( + # task config + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK, + block_BC=16, + threads=128, + num_stages=0, +): + BT = chunk_size + BC = block_BC # sub-chunk size, typically 16 + + NC = BT // BC # number of sub-chunks + NT = T.ceildiv(S, BT) + NK = T.ceildiv(DK, block_DK) # number of K blocks + + K_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + G_shape = (B, S, H, DK) + BT_shape = (B, S, H, BT) # for dAqk and dAkk + + dq_shape = (B, S, H, DK) + dk_shape = (B, S, H, DK) + db_shape = (B, S, H) + db2_shape = (NK, B, S, H) + dg_shape = (B, S, H, DK) + + @T.prim_func + def kernel( + # input + q: T.Tensor(K_shape, dtype=input_dtype), + k: T.Tensor(K_shape, dtype=input_dtype), + g: T.Tensor(G_shape, dtype=gate_dtype), + beta: T.Tensor(Beta_shape, dtype=input_dtype), + dAqk: T.Tensor(BT_shape, dtype=input_dtype), + dAkk: T.Tensor(BT_shape, dtype=input_dtype), + dq: T.Tensor(dq_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=input_dtype), + db: T.Tensor(db_shape, dtype=input_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + # output + dq2: T.Tensor(dq_shape, dtype=output_dtype), + dk2: T.Tensor(dk_shape, dtype=output_dtype), + db2: T.Tensor(db2_shape, dtype=output_dtype), + dg2: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK) * NC, NT, B * H, threads=threads) as (i_kc, i_t, i_bh): + i_k, i_i = i_kc // NC, i_kc % NC + bb, bh = i_bh // H, i_bh % H + + # actual sub-chunk index + i_ti = i_t * BT + i_i * BC + + # current sub-chunk data + q_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + beta_shared = T.alloc_shared((BC,), dtype=input_dtype) + g_current_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # last token's g in current sub-chunk + + dq_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dk_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dg_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + + # Allocate fragments + dq2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + dk2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + dg2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + db_fragment = T.alloc_fragment((BC,), dtype=accum_dtype) + + # Initialize fragments + T.clear(dq2_fragment) + T.clear(dk2_fragment) + T.clear(dg2_fragment) + T.clear(db_fragment) + + # Temporary shared memory for previous sub-chunks + k_prev_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_prev_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + dAqk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragment for b_kg computation + kg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + + kj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + gkj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kgj_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + dAqk_col = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col = T.alloc_shared((BC,), dtype=input_dtype) + + # Load g, q, k for current sub-chunk + T.copy(q[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_shared) + T.copy(k[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_shared) + T.copy(g[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_current_shared) + T.copy(beta[bb, i_ti : i_ti + BC, bh], beta_shared) + + if i_i > 0: + chunk_first_idx = i_ti # chunk first token idx + + T.copy(g[bb, chunk_first_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_shared) # Get the first token's g value (b_gn) + + # Loop over previous sub-chunks (i_j from 0 to i_i-1) + # Since i_i is computed from i_kc % NC and NC is small, we can use conditional blocks + # Process each possible previous sub-chunk with conditional execution + for i_j in T.Pipelined(i_i, num_stages=num_stages): # i_j is index ofprevious sub_chunks + prev_ti = i_t * BT + i_j * BC + T.copy(k[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_prev_shared) + T.copy(g[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_prev_shared) + + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAqk_prev_shared) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAkk_prev_shared) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kg_fragment[i_bc, i_k2] = k_prev_shared[i_bc, i_k2] * T.exp2(gn_shared[i_k2] - g_prev_shared[i_bc, i_k2]) + + T.gemm(dAqk_prev_shared, kg_fragment, dq2_fragment, clear_accum=False) + T.gemm(dAkk_prev_shared, kg_fragment, dk2_fragment, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gqn = T.exp2(g_current_shared[i_bc, i_k2] - gn_shared[i_k2]) + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] * gqn + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * gqn + + # Process current sub-chunk diagonal + loop_length = T.min(BC, S - i_t * BT - i_i * BC) + for j in T.Pipelined(loop_length, num_stages=num_stages): + token_j_idx = i_ti + j + + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gkj_shared) + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAqk_col) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAkk_col) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kgj_fragment[i_bc, i_k2] = kj_shared[i_k2] * T.exp2(g_current_shared[i_bc, i_k2] - gkj_shared[i_k2]) + dq2_fragment[i_bc, i_k2] += T.if_then_else(i_bc >= j, dAqk_col[i_bc] * kgj_fragment[i_bc, i_k2], 0.0) + dk2_fragment[i_bc, i_k2] += T.if_then_else(i_bc >= j, dAkk_col[i_bc] * kgj_fragment[i_bc, i_k2], 0.0) + + # Compute b_db = sum(b_dk2 * b_k, dim=1) + dk2_k_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_k_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * k_shared[i_bc, i_k2] + T.reduce_sum(dk2_k_fragment, db_fragment, dim=1, clear=True) + + # b_dk2 *= b_b[:, None] + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * beta_shared[i_bc] + + # Compute b_dg2 = b_q * b_dq2 (before adding dq to dq2) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = q_shared[i_bc, i_k2] * dq2_fragment[i_bc, i_k2] + + # Load dq and compute b_dq2 = b_dq2 + b_dq + T.copy(dq[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dq_shared) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] + dq_shared[i_bc, i_k2] + + # # Store results + T.copy(dq2_fragment, dq2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + T.copy(db_fragment, db2[i_k, bb, i_ti : i_ti + BC, bh]) + + # Initialize dkt_fragment for processing subsequent sub-chunks and lower triangular part + dkt_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + T.clear(dkt_fragment) + + # Temporary shared memory for subsequent sub-chunks + q_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_next_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + beta_next_shared = T.alloc_shared((BC,), dtype=input_dtype) + dAqk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragments for computation + gkn_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + qg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + kbg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + kbg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + dkt_temp_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + # T.use_swizzle(10) + + NC_actual = T.min(NC, T.ceildiv(S - i_t * BT, BC)) # Process subsequent sub-chunks (i_j from i_i+1 to NC-1) + if i_i < NC_actual - 1: + # Get the last token's g value in current sub-chunk + chunk_last_idx = T.min(S, i_ti + BC) - 1 + gn_last_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) + T.copy(g[bb, chunk_last_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_last_shared) + + # Loop over subsequent sub-chunks + for i_j in T.Pipelined(i_i + 1, NC_actual, num_stages=num_stages): + i_tj = i_t * BT + i_j * BC + + T.copy(q[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_next_shared) + T.copy(k[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_next_shared) + T.copy(g[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_next_shared) + T.copy(beta[bb, i_tj : i_tj + BC, bh], beta_next_shared) + + T.copy(dAqk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAqk_next_shared) # [BC, BC] need transpose + T.copy(dAkk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAkk_next_shared) # [BC, BC] need transpose + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + # kbg = k * beta + kbg_fragment[i_bc, i_k2] = k_next_shared[i_bc, i_k2] * beta_next_shared[i_bc] + gkn_shared[i_bc, i_k2] = T.if_then_else( + i_tj + i_bc < S, T.exp2(g_next_shared[i_bc, i_k2] - gn_last_shared[i_k2]), 0.0 + ) + + # Compute qg and kbg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + qg_shared[i_bc, i_k2] = q_next_shared[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + kbg_shared[i_bc, i_k2] = kbg_fragment[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + + # Accumulate: dkt += dAqk^T @ qg + dAkk^T @ kbg + # Use transpose_A=True because dAqk/dAkk are loaded in (T, BT) layout but we need (BT, T) for gemm + T.gemm(dAqk_next_shared, qg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=True) + T.gemm(dAkk_next_shared, kbg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] + dkt_temp_fragment[i_bc, i_k2] + + # Scale dkt by exp2(gn_last - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + g_scale = T.exp2(gn_last_shared[i_k2] - g_current_shared[i_bc, i_k2]) + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] * g_scale + + # Process lower triangular part of current sub-chunk diagonal + # This corresponds to j <= i_bc in the diagonal block + qj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + gj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + bj_local = T.alloc_local((1), dtype=input_dtype) + dAqk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + + gkq_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + # dkt_lower_temp = T.alloc_fragment((BC, block_DK), dtype=T.float32) + kbj_fragment = T.alloc_fragment((block_DK,), dtype=T.float32) + + max_token_j_idx = T.min(S, i_ti + BC) + for j in T.Pipelined(BC, num_stages=num_stages): + token_j_idx = i_ti + j + + if token_j_idx < max_token_j_idx: + T.copy(q[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], qj_shared) # [BK] + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared_lower) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gj_shared_lower) + + bj_local[0] = beta[bb, token_j_idx, bh] + T.copy(dAqk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAqk_col_lower) # [BC] + T.copy(dAkk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAkk_col_lower) + + # Compute kbj = kj * bj + for i_k2 in T.Parallel(block_DK): + kbj_fragment[i_k2] = kj_shared_lower[i_k2] * bj_local[0] + # Compute gkq = exp2(gj - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gkq_fragment[i_bc, i_k2] = T.exp2(gj_shared_lower[i_k2] - g_current_shared[i_bc, i_k2]) + + # Accumulate: dkt += (dAkk * kbj + dAqk * qj) * gkq for i_bc <= j + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] += T.if_then_else( + i_bc <= j, + (dAkk_col_lower[i_bc] * kbj_fragment[i_k2] + dAqk_col_lower[i_bc] * qj_shared[i_k2]) * gkq_fragment[i_bc, i_k2], + 0.0, + ) + + # Load dk and dg + T.copy(dk[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) + T.copy(dg[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dg_shared) + + # Update dg2: dg2 += (dk2 - dkt) * k + dg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = ( + dg2_fragment[i_bc, i_k2] + + (dk2_fragment[i_bc, i_k2] - dkt_fragment[i_bc, i_k2]) * k_shared[i_bc, i_k2] + + dg_shared[i_bc, i_k2] + ) + + # Update dk2: dk2 += dk + dkt + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] += dk_shared[i_bc, i_k2] + dkt_fragment[i_bc, i_k2] + + # Store dk2 and dg2 + T.copy(dk2_fragment, dk2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + T.copy(dg2_fragment, dg2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + threads=128, + num_stages=0, + cu_seqlens=None, + chunk_indices=None, +): + q, k, g, beta, dAqk, dAkk, dq, dk, db, dg = prepare_input( + B, + S, + H, + DK, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + # Reference implementation + dq_ref, dk_ref, db_ref, dg_ref = chunk_kda_bwd_intra( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + block_DK = min(64, tilelang.math.next_power_of_2(DK)) + NK = (DK + block_DK - 1) // block_DK + # TileLang implementation + kernel = tilelang_chunk_bwd_intra( + B=B, + S=S, + H=H, + DK=DK, + input_dtype=input_dtype, + output_dtype=output_dtype, + accum_dtype=accum_dtype, + gate_dtype=gate_dtype, + state_dtype=state_dtype, + chunk_size=chunk_size, + block_DK=block_DK, + ) + + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + db_tilelang = db_tilelang.sum(0).add_(db) + dg_tilelang = chunk_local_cumsum( + dg_tilelang, + chunk_size=chunk_size, + reverse=True, + ) + + compare_tensors("dq", dq_tilelang, dq_ref) + compare_tensors("dk", dk_tilelang, dk_ref) + compare_tensors("db", db_tilelang, db_ref) + compare_tensors("dg", dg_tilelang, dg_ref) + + fla_time = do_bench( + chunk_kda_bwd_intra, + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + tilelang_time = do_bench(kernel, q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + print(f"Fla time: {fla_time}") + print(f"Tilelang time: {tilelang_time}") + + +def main(): + DK = 128 + run_test( + B=1, + S=8192, + H=8, + DK=DK, + input_dtype=T.float32, + output_dtype=T.float32, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_bwd.py b/examples/kda/chunk_delta_bwd.py new file mode 100644 index 000000000..8c22488ca --- /dev/null +++ b/examples/kda/chunk_delta_bwd.py @@ -0,0 +1,309 @@ +# Reference: fla/ops/common/chunk_delta_h.py +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_bwd_dhu +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import do_bench, compare_tensors + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() * 0.01 + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + # Note: G should be in logspace and do chunkwise cumsum + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, chunk_size) + + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() * 0.01 + + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dh, dh0, dv2 + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-3, -2, -1]) +def tilelang_chunk_gated_delta_rule_bwd_dhu( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + # kernel config + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + # Should support cu_seqlen + BS = S // block_S + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + W_shape = (B, S, H, DK) + G_shape = (B, S, H, DK) + h0_shape = (B, H, DK, DV) + dht_shape = (B, H, DK, DV) + dO_shape = (B, S, H, DV) + dv_shape = (B, S, H, DV) + + dh_shape = (B, BS, H, DK, DV) + dh0_shape = (B, H, DK, DV) + dv2_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + GK: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) + b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + GK_last_shared = T.alloc_shared((DK,), dtype=gate_dtype) + + if use_final_state_gradient: + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) + T.copy(b_dh_shared, b_dh_fragment) + else: + T.clear(b_dh_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # The gradient should be stored in the reverse order + i_s_inv = T.ceildiv(S, block_S) - i_s - 1 # reverse indices + # Store the updated dh + T.copy(b_dh_fragment, b_dh_shared) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Update dv + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) + T.copy( + dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared + ) # copy old dv + T.copy(dv_shared, dv_fragment_2) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] + # Store the updated dv + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + # Update dh + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) # [block_S, DK] + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) # [block_S, DK] + T.copy( + dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared + ) # [block_S, block_DV] + + if use_gk: + last_idx = T.min((i_s_inv + 1) * block_S, S) - 1 # chunk last token gk + T.copy(GK[bb, last_idx, bh, :], GK_last_shared) + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + T.gemm(Q_shared, dO_shared, b_dh_fragment_1, transpose_A=True, clear_accum=True) # [DK, block_DV] + + # dv_shared: [block_S, block_DV] + T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True, clear_accum=True) # [DK, block_DV] + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] * scale - b_dh_fragment_2[i_k, i_v] + + if use_initial_state: + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=64, + threads=256, + num_stages=0, + use_torch=False, +): + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + + # fla ref + print("fla running...", flush=True) + if use_gk: + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu( + q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, use_exp2=True + ) + + # tilelang + print("tilelang running...", flush=True) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk, + use_initial_state, + use_final_state_gradient, + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) + + fla_time = do_bench( + chunk_gated_delta_rule_bwd_dhu, q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, chunk_size=chunk_size + ) + tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) + + print(f"fla time: {fla_time} ms") + print(f"tilelang time: {tilelang_time} ms") + + compare_tensors("dh", dh_ref, dh_tilelang) + compare_tensors("dh0", dh0_ref, dh0_tilelang) + compare_tensors("dv2", dv2_ref, dv2_tilelang) + + +def main(): + DK = 128 + run_test( + B=1, + S=1024 * 8, + H=64, + DK=DK, + DV=128, + input_dtype="bfloat16", + output_dtype="bfloat16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + scale=DK**-0.5, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=32, + threads=128, + num_stages=1, + use_torch=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_h_fwd.py b/examples/kda/chunk_delta_h_fwd.py new file mode 100644 index 000000000..fbb8bd988 --- /dev/null +++ b/examples/kda/chunk_delta_h_fwd.py @@ -0,0 +1,306 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/your/path/to/flash-linear-attention") + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_fwd_h +from FLA_KDA.cumsum import chunk_local_cumsum + +import torch +import torch.nn.functional as F + +from test_utils_kda import compare_tensors, do_bench + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + W = F.normalize(W, dim=-1, p=2) + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + U = F.normalize(U, dim=-1, p=2) + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, chunk_size) + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + return K, W, U, G, initial_state + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + state_dtype, +): + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return h, final_state, V_new + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_gated_delta_rule_fwd_h( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + # kernel config + block_DK=64, + block_DV=32, + threads=128, + num_stages=1, +): + block_S = chunk_size + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + U_shape = (B, S, H, DV) + GK_shape = (B, S, H, DK) + h_shape = (B, BS, H, DK, DV) + initial_state_shape = (B, H, DK, DV) + final_state_shape = (B, H, DK, DV) + + @T.prim_func + def kernel( + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + GK: T.Tensor(GK_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype) + b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + + U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + GK_last_shared = T.alloc_shared((DK), dtype=gate_dtype) + + if use_initial_state: + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) + T.copy(b_h_shared, b_h_fragment) + else: + T.clear(b_h_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # Store previous result to the hidden tensor, like the epilogue + T.copy(b_h_shared, h[bb, i_s, bh, :, bv * block_DV : (bv + 1) * block_DV]) + + # Recurrence + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, :], W_shared) + T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) + + # U - W * S + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) + T.copy(U_shared, U_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] + + # Save V_new + if save_new_value: + T.copy(V_new_fragment, dst=V_new_shared) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) + # use_gk + if use_gk: + T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared) # block last token + for i_k, i_v in T.Parallel(DK, block_DV): + b_h_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + # Update intermediate results + T.copy(V_new_fragment, V_new_shared) + T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True) + + T.copy(b_h_fragment, b_h_shared) + + # Save final state + if store_final_state: + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + + # fla ref + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + + # tilelang + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) + + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + tilelang_time = do_bench(kernel, K, W, U, G, initial_state) + + # check correctness + compare_tensors("h", h_ref, h_tilelang) + compare_tensors("final_state", final_state_ref, final_state_tilelang) + compare_tensors("V_new", V_new_ref, V_new_tilelang) + + print(f"tilelang time: {tilelang_time} ms") + print(f"fla time: {fla_time} ms") + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + input_dtype="float16", + output_dtype="float16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_inter_solve_fused.py b/examples/kda/chunk_inter_solve_fused.py new file mode 100644 index 000000000..940dc20c8 --- /dev/null +++ b/examples/kda/chunk_inter_solve_fused.py @@ -0,0 +1,566 @@ +import tilelang +import tilelang.language as T + +from FLA_KDA.fla_chunk_intra import chunk_kda_fwd_inter_solve_fused +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import compare_tensors, do_bench + +import torch +import torch.nn.functional as F + + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + chunk_size, + sub_chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + gk = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() # 需要是cumsum + gk = F.logsigmoid(gk) + gk = chunk_local_cumsum(gk, chunk_size) + + Aqk = torch.empty(B, S, H, chunk_size, dtype=input_dtype).cuda() + Akk_diag = torch.ones(B, S, H, sub_chunk_size, dtype=torch.float32).cuda() + + return q, k, gk, beta, Aqk, Akk_diag + + +def prepare_output( + B, + S, + H, + chunk_size, + sub_chunk_size, + output_dtype, +): + Akk = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda() + return Akk + + +@tilelang.jit(out_idx=[-2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_kda_fwd_inter_fused( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + sub_chunk_size, + scale, + block_DK=32, + threads=32, + num_stages=1, +): + block_S = BS = chunk_size + BC = sub_chunk_size + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + GK_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + Aqk_shape = (B, S, H, BS) + Akk_diag_shape = (B, S, H, BC) + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akk_diag (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + + @T.prim_func + def kernel( + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + GK: T.Tensor(GK_shape, dtype=gate_dtype), + Beta: T.Tensor(Beta_shape, dtype=input_dtype), + Akk_diag: T.Tensor(Akk_diag_shape, dtype=T.float32), + Aqk: T.Tensor(Aqk_shape, dtype=output_dtype), + Akk: T.Tensor(Aqk_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + Aqk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Aqk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype) + Akk10_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk20_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk21_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk30_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk31_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Akk32_shared = T.alloc_shared((BC, BC), dtype=T.float32) + + K0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + Q3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + GK3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + Q_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + K_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_kt_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + b_gn1_shared = T.alloc_shared((block_DK,), dtype=T.float32) + b_gn2_shared = T.alloc_shared((block_DK,), dtype=T.float32) + b_gn3_shared = T.alloc_shared((block_DK,), dtype=T.float32) + + b_gqn1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_gqn2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + b_gqn3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32) + + beta_1_shared = T.alloc_shared((BC,), dtype=T.float32) + beta_2_shared = T.alloc_shared((BC,), dtype=T.float32) + beta_3_shared = T.alloc_shared((BC,), dtype=T.float32) + # Akk_inv + Ai_00_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_10_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_11_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_20_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_21_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_22_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_30_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_31_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_32_shared = T.alloc_shared((BC, BC), dtype=T.float32) + Ai_33_shared = T.alloc_shared((BC, BC), dtype=T.float32) + + T.clear(Aqk10_fragment) + T.clear(Akk10_fragment) + T.clear(Aqk20_fragment) + T.clear(Akk20_fragment) + T.clear(Aqk21_fragment) + T.clear(Akk21_fragment) + T.clear(Aqk30_fragment) + T.clear(Akk30_fragment) + T.clear(Aqk31_fragment) + T.clear(Akk31_fragment) + T.clear(Aqk32_fragment) + T.clear(Akk32_fragment) + + i_tc0 = bs * BS + i_tc1 = bs * BS + BC + i_tc2 = bs * BS + 2 * BC + i_tc3 = bs * BS + 3 * BC + + ################################################################################ + # 1. off-diagonal blocks + ################################################################################ + + for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): + T.copy(K[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K0_shared) + T.copy(GK[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK0_shared) + if i_tc1 < S: + T.copy(Q[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q1_shared) + T.copy(K[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K1_shared) + T.copy(GK[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK1_shared) + T.copy(GK[bb, i_tc1, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn1_shared) # subblock第一个token的GK + for i_c1, i_k1 in T.Parallel(BC, block_DK): + b_gqn1_shared[i_c1, i_k1] = T.if_then_else( + i_tc1 + i_c1 < S, T.exp2(GK1_shared[i_c1, i_k1] - b_gn1_shared[i_k1]), 0.0 + ) + Q_GK_scaled_shared[i_c1, i_k1] = Q1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1] + K_GK_scaled_shared[i_c1, i_k1] = K1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1] + b_kt_shared[i_c1, i_k1] = K0_shared[i_c1, i_k1] * T.exp2(b_gn1_shared[i_k1] - GK0_shared[i_c1, i_k1]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk10_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk10_fragment, transpose_B=True) + if i_tc2 < S: + T.copy(Q[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q2_shared) + T.copy(K[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K2_shared) + T.copy(GK[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK2_shared) + T.copy(GK[bb, i_tc2, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn2_shared) + for i_c2, i_k2 in T.Parallel(BC, block_DK): + b_gqn2_shared[i_c2, i_k2] = T.if_then_else( + i_tc2 + i_c2 < S, T.exp2(GK2_shared[i_c2, i_k2] - b_gn2_shared[i_k2]), 0.0 + ) + Q_GK_scaled_shared[i_c2, i_k2] = Q2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2] + K_GK_scaled_shared[i_c2, i_k2] = K2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2] + b_kt_shared[i_c2, i_k2] = K0_shared[i_c2, i_k2] * T.exp2(b_gn2_shared[i_k2] - GK0_shared[i_c2, i_k2]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk20_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk20_fragment, transpose_B=True) + for i_c3, i_k3 in T.Parallel(BC, block_DK): + b_kt_shared[i_c3, i_k3] = K1_shared[i_c3, i_k3] * T.exp2(b_gn2_shared[i_k3] - GK1_shared[i_c3, i_k3]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk21_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk21_fragment, transpose_B=True) + if i_tc3 < S: + T.copy(Q[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q3_shared) + T.copy(K[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K3_shared) + T.copy(GK[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK3_shared) + T.copy(GK[bb, i_tc3, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn3_shared) + for i_c4, i_k4 in T.Parallel(BC, block_DK): + b_gqn3_shared[i_c4, i_k4] = T.if_then_else( + i_tc3 + i_c4 < S, T.exp2(GK3_shared[i_c4, i_k4] - b_gn3_shared[i_k4]), 0.0 + ) + Q_GK_scaled_shared[i_c4, i_k4] = Q3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4] + K_GK_scaled_shared[i_c4, i_k4] = K3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4] + b_kt_shared[i_c4, i_k4] = K0_shared[i_c4, i_k4] * T.exp2(b_gn3_shared[i_k4] - GK0_shared[i_c4, i_k4]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk30_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk30_fragment, transpose_B=True) + for i_c5, i_k5 in T.Parallel(BC, block_DK): + b_kt_shared[i_c5, i_k5] = K1_shared[i_c5, i_k5] * T.exp2(b_gn3_shared[i_k5] - GK1_shared[i_c5, i_k5]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk31_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk31_fragment, transpose_B=True) + for i_c6, i_k6 in T.Parallel(BC, block_DK): + b_kt_shared[i_c6, i_k6] = K2_shared[i_c6, i_k6] * T.exp2(b_gn3_shared[i_k6] - GK2_shared[i_c6, i_k6]) + T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk32_fragment, transpose_B=True) + T.gemm(K_GK_scaled_shared, b_kt_shared, Akk32_fragment, transpose_B=True) + + ################################################################################ + # 2. save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + + if i_tc1 < S: + T.copy(Beta[bb, i_tc1 : i_tc1 + BC, bh], beta_1_shared) + for i_c21, i_c22 in T.Parallel(BC, BC): + Aqk10_fragment[i_c21, i_c22] = Aqk10_fragment[i_c21, i_c22] * scale + Akk10_fragment[i_c21, i_c22] = Akk10_fragment[i_c21, i_c22] * beta_1_shared[i_c21] + T.copy(Aqk10_fragment, Aqk[bb, i_tc1 : i_tc1 + BC, bh, 0:BC]) + T.copy(Akk10_fragment, Akk10_shared) + if i_tc2 < S: + T.copy(Beta[bb, i_tc2 : i_tc2 + BC, bh], beta_2_shared) + for i_c23, i_c24 in T.Parallel(BC, BC): + Aqk20_fragment[i_c23, i_c24] = Aqk20_fragment[i_c23, i_c24] * scale + Aqk21_fragment[i_c23, i_c24] = Aqk21_fragment[i_c23, i_c24] * scale + Akk20_fragment[i_c23, i_c24] = Akk20_fragment[i_c23, i_c24] * beta_2_shared[i_c23] + Akk21_fragment[i_c23, i_c24] = Akk21_fragment[i_c23, i_c24] * beta_2_shared[i_c23] + T.copy(Aqk20_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, 0:BC]) + T.copy(Aqk21_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, BC : 2 * BC]) + T.copy(Akk20_fragment, Akk20_shared) + T.copy(Akk21_fragment, Akk21_shared) + if i_tc3 < S: + T.copy(Beta[bb, i_tc3 : i_tc3 + BC, bh], beta_3_shared) + for i_c25, i_c26 in T.Parallel(BC, BC): + Aqk30_fragment[i_c25, i_c26] = Aqk30_fragment[i_c25, i_c26] * scale + Aqk31_fragment[i_c25, i_c26] = Aqk31_fragment[i_c25, i_c26] * scale + Aqk32_fragment[i_c25, i_c26] = Aqk32_fragment[i_c25, i_c26] * scale + Akk30_fragment[i_c25, i_c26] = Akk30_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + Akk31_fragment[i_c25, i_c26] = Akk31_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + Akk32_fragment[i_c25, i_c26] = Akk32_fragment[i_c25, i_c26] * beta_3_shared[i_c25] + T.copy(Aqk30_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 0:BC]) + T.copy(Aqk31_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, BC : 2 * BC]) + T.copy(Aqk32_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 2 * BC : 3 * BC]) + T.copy(Akk30_fragment, Akk30_shared) + T.copy(Akk31_fragment, Akk31_shared) + T.copy(Akk32_fragment, Akk32_shared) + + ################################################################################ + # 3. load diagonal Akk blocks + ################################################################################ + + T.copy(Akk_diag[bb, i_tc0 : i_tc0 + BC, bh, :], Ai_00_shared) + T.copy(Akk_diag[bb, i_tc1 : i_tc1 + BC, bh, :], Ai_11_shared) + T.copy(Akk_diag[bb, i_tc2 : i_tc2 + BC, bh, :], Ai_22_shared) + T.copy(Akk_diag[bb, i_tc3 : i_tc3 + BC, bh, :], Ai_33_shared) + for i_c1, i_c2 in T.Parallel(BC, BC): + Ai_00_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_00_shared[i_c1, i_c2], 0) + Ai_11_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_11_shared[i_c1, i_c2], 0) + Ai_22_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_22_shared[i_c1, i_c2], 0) + Ai_33_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_33_shared[i_c1, i_c2], 0) + + ################################################################################ + # 4. forward substitution on diagonals + ################################################################################ + a_00_shared = T.alloc_shared((BC,), dtype=T.float32) + Aa_mul_shared = T.alloc_shared((BC, BC), dtype=T.float32) + reduce_shared = T.alloc_shared((BC,), dtype=T.float32) + for i_i in T.Pipelined(2, T.min(BC, S - i_tc0), num_stages=num_stages): + T.copy(Akk_diag[bb, i_tc0 + i_i, bh, :], a_00_shared) # load row + for i_c in T.Parallel(BC): + a_00_shared[i_c] = T.if_then_else(i_c < i_i, -a_00_shared[i_c], 0.0) # mask:i_c