"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse

import numpy as np
import torch

from flashinfer.gdn_prefill import chunk_gated_delta_rule
from flashinfer.testing.utils import bench_gpu_time


def gdn_flops(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
) -> int:
    """Calculate FLOPs for Gated Delta Rule (GDN) attention.

    Delta Rule formula:
        state_t  = alpha_t * state_{t-1} + beta_t * (k_t @ v_t^T)
        output_t = q_t @ state_t

    Matrix multiplications per token per head:
        1. k @ v^T (outer product): 2 * d^2 FLOPs
        2. q @ state:               2 * d^2 FLOPs

    Note: alpha/beta gating are element-wise scalar multiplications and are
    not counted in TFLOPS.  ``num_k_heads`` and ``num_seqs`` are accepted for
    signature symmetry with :func:`gdn_bytes` but do not affect the count.
    """
    num_o_heads = max(num_q_heads, num_v_heads)

    # k @ v^T (outer product): 2 * d^2 per token per head.
    outer_product_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    # q @ state: 2 * d^2 per token per head.
    output_flops = 2 * total_seq_len * num_o_heads * head_size * head_size

    return outer_product_flops + output_flops


def gdn_bytes(
    total_seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    num_seqs: int,
    dtype: torch.dtype,
) -> int:
    """Calculate memory bytes moved by GDN attention.

    Includes:
    - Q, K, V tensors (input, ``dtype``-sized elements)
    - Output tensor
    - State tensor (float32, one d x d matrix per sequence per head)
    - Alpha, Beta gate tensors (float32)
    """
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads
    elem_size = dtype.itemsize

    # Input tensors.
    q_bytes = total_seq_len * num_q_heads * head_size * elem_size
    k_bytes = total_seq_len * num_k_heads * head_size * elem_size
    v_bytes = total_seq_len * num_v_heads * head_size * elem_size

    # Output tensor.
    o_bytes = total_seq_len * num_o_heads * head_size * elem_size

    # State tensor (float32, 4 bytes per element).
    state_bytes = num_seqs * num_sab_heads * head_size * head_size * 4

    # Alpha and Beta gates (float32).
    alpha_bytes = total_seq_len * num_sab_heads * 4
    beta_bytes = total_seq_len * num_sab_heads * 4

    return (
        q_bytes + k_bytes + v_bytes + o_bytes + state_bytes + alpha_bytes + beta_bytes
    )


def bench_gdn_prefill(
    batch_size: int,
    seq_len: int,
    num_q_heads: int,
    num_k_heads: int,
    num_v_heads: int,
    head_size: int,
    dtype: torch.dtype,
    use_alpha: bool = True,
    use_beta: bool = True,
) -> dict:
    """Benchmark the GDN prefill kernel for one (batch_size, seq_len) config.

    Returns a dict with the configuration plus ``median_ms``, ``tflops`` and
    ``tb_per_sec``.
    """
    total_seq_len = batch_size * seq_len
    num_o_heads = max(num_q_heads, num_v_heads)
    num_sab_heads = num_o_heads

    # Create inputs.
    q = torch.randn(total_seq_len, num_q_heads, head_size, dtype=dtype, device="cuda")
    k = torch.randn(total_seq_len, num_k_heads, head_size, dtype=dtype, device="cuda")
    # L2 normalize k for numerical stability.
    k = torch.nn.functional.normalize(k, p=2.0, dim=-1)
    v = torch.randn(total_seq_len, num_v_heads, head_size, dtype=dtype, device="cuda")

    # Equal-length sequences: cu_seqlens = [0, seq_len, 2*seq_len, ...].
    cu_seqlens = torch.arange(
        0, batch_size * seq_len + 1, seq_len, dtype=torch.int64, device="cuda"
    )

    alpha = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_alpha
        else None
    )
    beta = (
        torch.rand(total_seq_len, num_sab_heads, dtype=torch.float32, device="cuda")
        if use_beta
        else None
    )

    # Pre-allocate outputs so allocation cost is not benchmarked.
    output = torch.empty(
        total_seq_len, num_o_heads, head_size, dtype=dtype, device="cuda"
    )
    output_state = torch.empty(
        batch_size,
        num_sab_heads,
        head_size,
        head_size,
        dtype=torch.float32,
        device="cuda",
    )

    # Warmup (JIT/compile paths) before timing.
    chunk_gated_delta_rule(
        q, k, v, alpha, beta, None, None, True, cu_seqlens, False, output, output_state
    )
    torch.cuda.synchronize()

    # Benchmark.
    times = bench_gpu_time(
        lambda: chunk_gated_delta_rule(
            q,
            k,
            v,
            alpha,
            beta,
            None,
            None,
            True,
            cu_seqlens,
            False,
            output,
            output_state,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
        enable_cupti=True,
    )

    median_ms = np.median(times)

    # Derived metrics.
    flops = gdn_flops(
        total_seq_len, num_q_heads, num_k_heads, num_v_heads, head_size, batch_size
    )
    bytes_accessed = gdn_bytes(
        total_seq_len,
        num_q_heads,
        num_k_heads,
        num_v_heads,
        head_size,
        batch_size,
        dtype,
    )

    # flops / (median_ms * 1e-3 s) / 1e12 == flops / median_ms / 1e9.
    tflops = flops / median_ms / 1e9
    tb_per_sec = bytes_accessed / median_ms / 1e9

    return {
        "batch_size": batch_size,
        "seq_len": seq_len,
        "num_q_heads": num_q_heads,
        "num_k_heads": num_k_heads,
        "num_v_heads": num_v_heads,
        "head_size": head_size,
        "dtype": str(dtype).replace("torch.", ""),
        "median_ms": median_ms,
        "tflops": tflops,
        "tb_per_sec": tb_per_sec,
    }


def main():
    """Parse CLI arguments and sweep (batch_size, seq_len) configurations."""
    parser = argparse.ArgumentParser(description="Benchmark GDN Prefill Kernel")
    parser.add_argument("--batch-size", type=int, nargs="+", default=[1, 4, 16, 64])
    parser.add_argument("--seq-len", type=int, nargs="+", default=[128, 256, 512, 1024])
    parser.add_argument("--num-q-heads", type=int, default=16)
    parser.add_argument("--num-k-heads", type=int, default=16)
    parser.add_argument("--num-v-heads", type=int, default=32)
    parser.add_argument("--head-size", type=int, default=128)
    parser.add_argument(
        "--dtype", type=str, choices=["float16", "bfloat16"], default="bfloat16"
    )
    parser.add_argument(
        "--preset",
        type=str,
        choices=["qwen3-next", "custom"],
        default="custom",
        help="Use preset config. qwen3-next: q=k=16, v=32, d=128",
    )
    args = parser.parse_args()

    # Apply preset configurations (overrides the individual head/size flags).
    if args.preset == "qwen3-next":
        # Qwen3-Next-80B-A3B linear attention config (GVA).
        args.num_q_heads = 16
        args.num_k_heads = 16
        args.num_v_heads = 32
        args.head_size = 128

    # The kernel requires SM90 (Hopper) or later.
    device_capability = torch.cuda.get_device_capability()
    if device_capability[0] < 9:
        print(f"Current device capability: {device_capability}")
        print("GDN requires SM90 (Hopper) or later. Exiting...")
        return

    dtype = getattr(torch, args.dtype)

    print(
        f"GDN Prefill Benchmark (heads: q={args.num_q_heads}, k={args.num_k_heads}, v={args.num_v_heads}, d={args.head_size}, dtype={args.dtype})"
    )
    print("-" * 100)
    print(f"{'batch':>6} {'seq_len':>8} {'time(ms)':>10} {'TFLOPS':>10} {'TB/s':>10}")
    print("-" * 100)

    for batch_size in args.batch_size:
        for seq_len in args.seq_len:
            result = bench_gdn_prefill(
                batch_size=batch_size,
                seq_len=seq_len,
                num_q_heads=args.num_q_heads,
                num_k_heads=args.num_k_heads,
                num_v_heads=args.num_v_heads,
                head_size=args.head_size,
                dtype=dtype,
            )
            print(
                f"{result['batch_size']:>6} {result['seq_len']:>8} "
                f"{result['median_ms']:>10.3f} {result['tflops']:>10.2f} "
                f"{result['tb_per_sec']:>10.2f}"
            )

    print("-" * 100)


if __name__ == "__main__":
    main()
+ */ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "flat/cute_ext.hpp" + +namespace flat::collective { + +using namespace cute; + +template +constexpr bool is_contiguous(Layout&& layout) { + auto dim_layout = get(layout); + if constexpr (rank(dim_layout) == 0) { + return stride(dim_layout) == 1; + } else { + return stride<0>(dim_layout) == 1; + } +} + +namespace detail::SM80 { + +// SM80 version of make_acc_into_op in "flat/hopper/collective/flat_common.hpp" +template +CUTE_DEVICE constexpr auto convert_c_layout_to_a_layout(CLayout const& c, + TiledMMA const& tiled_mma) { + constexpr auto c_frag_atom_size = size<0>(CLayout{}); + constexpr auto a_frag_atom_size = size<1>(typename TiledMMA::AtomLayoutA_TV{}); + static_assert(a_frag_atom_size % c_frag_atom_size == 0); + constexpr auto ratio = a_frag_atom_size / c_frag_atom_size; + if constexpr (ratio == 1) { + return CLayout{}; + } else { + // e.g. the mma instruction shape is 16x8x16, we need to convert from ((2,2), MMA_M, MMA_N) to + // ((2,2,2), MMA_M, MMA_N/2) + + constexpr auto tiler = + make_shape(_, _, Int{}); // keep the first mode (FragAtom) and second mode (MMA_M) + constexpr auto divided = + logical_divide(CLayout{}, tiler); // (FragAtom, MMA_M, (ratio, MMA_N/ratio)) + + return make_layout(flatten(make_layout(get<0>(divided), get<2, 0>(divided))), get<1>(divided), + get<2, 1>(divided)); + } +} + +template +CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, TiledMMA const& tiled_mma) { + Tensor operand = + make_fragment_like(convert_c_layout_to_a_layout(acc.layout(), tiled_mma)); + Tensor operand_as_acc = make_tensor(operand.data(), acc.layout()); + cute::copy(acc, operand_as_acc); + return operand; +} + +} // namespace detail::SM80 + +template +struct CollectiveInverse { + // FIXME: precision is not good due to half + static_assert(std::is_same_v || std::is_same_v, + "only half is implemented"); + + CUTE_DEVICE + 
CollectiveInverse(int wg_sync_named_barrier_id) + : wg_sync_named_barrier_id_(wg_sync_named_barrier_id) {} + + template + CUTE_DEVICE void compute(TensorT&& sT) { + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == 64); + static_assert(size<1>(L) == 64); + + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + if (thread_idx < 64) { // compute 8x8 inverse on diagnal directly + auto t8X8sT = flat_divide(sT, Shape<_8, _8>{}); + compute_diagonal_inverse_NxN<8>(t8X8sT(_, _, thread_idx / 8, thread_idx / 8), thread_idx % 8); + } + + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + wg_sync_named_barrier_id_); + + auto t16X16sT = flat_divide(sT, Shape<_16, _16>{}); + blockwise_diagonal_inversed_8x8_to_16x16(t16X16sT(_, _, thread_idx / 32, thread_idx / 32)); + + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + wg_sync_named_barrier_id_); + + if (thread_idx < 64) { + auto t32X32sT = flat_divide(sT, Shape<_32, _32>{}); + blockwise_diagonal_inversed_16x16_to_32x32(t32X32sT(_, _, thread_idx / 32, thread_idx / 32)); + } + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + wg_sync_named_barrier_id_); + blockwise_diagonal_inversed_32x32_to_64x64(sT); + } + + private: + template + CUTE_DEVICE void compute_diagonal_inverse_NxN(TensorT&& mat, + int tid_in_group) { // group_size = N + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == N); + static_assert(size<1>(L) == N); + + using ElementCompute = float; + + using CopyOp = + Copy_Atom, Element>; + + auto load_row = [&](int y) { + auto row = make_tensor(Shape>{}); + copy(CopyOp{}, std::forward(mat)(y, _), row); + + auto row_cvt = make_tensor_like(row); + copy(row, row_cvt); + + if constexpr (GarbageFilledDiagonal || GarbageFilledUpperTriangular) { + CUTE_UNROLL + for (int i = 0; 
i < N; ++i) { + row_cvt(i) = i == y ? 1.0f : (i > y ? 0.0f : row_cvt(i)); + } + } + return row_cvt; + }; + + auto store_row = [&](int y, auto row) { + auto row_cvt = make_tensor_like(row); + copy(row, row_cvt); + copy(CopyOp{}, row_cvt, std::forward(mat)(y, _)); + }; + + auto row = load_row(tid_in_group); +#define LOAD(y, x) __shfl_sync(0xffffffff, row(x), y, N) + + CUTE_UNROLL + for (int src_row = 0; src_row < N - 1; ++src_row) { // idx of src row to eliminate + auto row_scale = -row(src_row); // scale the src row + CUTE_UNROLL + for (int i = 0; i < src_row; ++i) { + auto src_row_value = LOAD(src_row, i); + row(i) = tid_in_group > src_row ? row_scale * src_row_value + row(i) : row(i); + } + row(src_row) = tid_in_group > src_row ? row_scale : row(src_row); + } + +#undef LOAD + + store_row(tid_in_group, row); + } + + /* + blockwise inverse has relation as follows + inv(| A 0 |) = | inv(A) 0 | + | C D | | -inv(D)C inv(A) inv(D) | + */ + + template + CUTE_DEVICE void blockwise_diagonal_inversed_4x4_to_8x8(TensorT&& mat) { + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == 8); + static_assert(size<1>(L) == 8); + auto mat_NxN_2x2 = flat_divide(std::forward(mat), Shape<_4, _4>{}); + + // FIXME: implement + } + + template + CUTE_DEVICE void blockwise_diagonal_inversed_8x8_to_16x16(TensorT&& mat) { + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == 16); + static_assert(size<1>(L) == 16); + + static_assert(is_contiguous<0>(L) == 1 || is_contiguous<1>(L) == 1); + constexpr bool is_col_major = is_contiguous<0>(L); + + auto mat_8x8_2x2 = flat_divide(std::forward(mat), Shape<_8, _8>{}); + using MMA = SM80_16x8x8_F32F16F16F32_TN; + using TiledMMA = decltype(make_tiled_mma(MMA{}, Layout>{}, Shape<_16, _8, _8>{})); + + using CopyOpD_S2R = std::conditional_t; + using CopyOpC_S2R = std::conditional_t; + using CopyOpA_S2R = 
std::conditional_t; +#ifdef CUTE_ARCH_STSM_SM90_ENABLED + using CopyOpO_R2S = std::conditional_t; +#else + using CopyOpO_R2S = UniversalCopy; +#endif + + int lane_id = cutlass::canonical_lane_idx(); + auto tiled_mma = TiledMMA{}; + auto thr_mma = tiled_mma.get_thread_slice(lane_id); + + auto D_tiled_copy = make_tiled_copy_A(Copy_Atom{}, tiled_mma); + auto C_tiled_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma); + auto A_tiled_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma); + auto O_tiled_copy = make_tiled_copy_C(Copy_Atom{}, tiled_mma); + + auto D_thr_copy = D_tiled_copy.get_thread_slice(lane_id); + auto C_thr_copy = C_tiled_copy.get_thread_slice(lane_id); + auto A_thr_copy = A_tiled_copy.get_thread_slice(lane_id); + auto O_thr_copy = O_tiled_copy.get_thread_slice(lane_id); + + Tensor sDinv = mat_8x8_2x2(_, _, _1{}, _1{}); + Tensor sC = select_tensor<1, 0>(mat_8x8_2x2(_, _, _1{}, _0{})); + Tensor sAinv = select_tensor<1, 0>(mat_8x8_2x2(_, _, _0{}, _0{})); + Tensor sO = mat_8x8_2x2(_, _, _1{}, _0{}); + + Tensor sDinv_m_bcast = + make_tensor(sDinv.data(), logical_product(sDinv.layout(), Tile>{})); + Tensor sO_m_bcast = + make_tensor(sO.data(), logical_product(sO.layout(), Tile>{})); + + Tensor tOrDinv = make_fragment_like(partition_shape_A(tiled_mma, Shape<_16, _8>{})); + Tensor tOrC = thr_mma.partition_fragment_B(sC); + Tensor tOrAinv = thr_mma.partition_fragment_B(sAinv); + + Tensor tDCrDC = partition_fragment_C(tiled_mma, Shape<_16, _8>{}); // output of -inv(D)C + Tensor tOrO = partition_fragment_C(tiled_mma, Shape<_16, _8>{}); // output of -inv(D)C inv(A) + + Tensor tOsDinv = D_thr_copy.partition_S(sDinv_m_bcast); + Tensor tOrDinv_cv = D_thr_copy.retile_D(tOrDinv); + Tensor tOsC = C_thr_copy.partition_S(sC); + Tensor tOrC_cv = C_thr_copy.retile_D(tOrC); + Tensor tOsAinv = A_thr_copy.partition_S(sAinv); + Tensor tOrAinv_cv = A_thr_copy.retile_D(tOrAinv); + Tensor tOsO = O_thr_copy.partition_D(sO_m_bcast); + Tensor tOrO_cv = O_thr_copy.retile_S(tOrO); + + 
///////////////////////////////////////////////////////////////////////////// + // -inv(D)C + copy(D_tiled_copy, tOsDinv(make_coord(_, _0{}), _, _), tOrDinv_cv(make_coord(_, _0{}), _, _)); + copy(C_tiled_copy, tOsC, tOrC_cv); + + clear(tDCrDC); + gemm(tiled_mma, tOrDinv, tOrC, tDCrDC); + transform(tDCrDC(make_coord(_, _0{}), _, _), [](auto v) { return -v; }); + + ///////////////////////////////////////////////////////////////////////////// + // -inv(D)C inv(A) + Tensor tOrDC = detail::SM80::make_acc_into_op(tDCrDC, tiled_mma); + + copy(A_tiled_copy, tOsAinv, tOrAinv_cv); + clear(tOrO); + gemm(tiled_mma, tOrDC, tOrAinv, tOrO); + + auto tOrO_cv_cvt = make_tensor_like(tOrO_cv(make_coord(_, _0{}), _, _)); + transform(tOrO_cv(make_coord(_, _0{}), _, _), tOrO_cv_cvt, [](auto v) { return Element(v); }); + copy(O_tiled_copy, tOrO_cv_cvt, tOsO(make_coord(_, _0{}), _, _)); + } + + template + CUTE_DEVICE void blockwise_diagonal_inversed_16x16_to_32x32(TensorT&& mat) { + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == 32); + static_assert(size<1>(L) == 32); + + static_assert(is_contiguous<0>(L) == 1 || is_contiguous<1>(L) == 1); + constexpr bool is_col_major = is_contiguous<0>(L); + + using TileShape = Shape<_16, _16, _16>; + auto mat_16x16_2x2 = flat_divide(std::forward(mat), select<0, 1>(TileShape{})); + + using MMA = SM80_16x8x16_F32F16F16F32_TN; + using TiledMMA = decltype(make_tiled_mma(MMA{}, Layout>{}, TileShape{})); + + using CopyOpD_S2R = std::conditional_t; + using CopyOpC_S2R = std::conditional_t; + using CopyOpA_S2R = std::conditional_t; +#ifdef CUTE_ARCH_STSM_SM90_ENABLED + using CopyOpO_R2S = std::conditional_t; +#else + using CopyOpO_R2S = UniversalCopy; +#endif + + int lane_id = cutlass::canonical_lane_idx(); + auto tiled_mma = TiledMMA{}; + auto thr_mma = tiled_mma.get_thread_slice(lane_id); + + auto D_tiled_copy = make_tiled_copy_A(Copy_Atom{}, tiled_mma); + auto C_tiled_copy = 
make_tiled_copy_B(Copy_Atom{}, tiled_mma); + auto A_tiled_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma); + auto O_tiled_copy = make_tiled_copy_C(Copy_Atom{}, tiled_mma); + + auto D_thr_copy = D_tiled_copy.get_thread_slice(lane_id); + auto C_thr_copy = C_tiled_copy.get_thread_slice(lane_id); + auto A_thr_copy = A_tiled_copy.get_thread_slice(lane_id); + auto O_thr_copy = O_tiled_copy.get_thread_slice(lane_id); + + Tensor sDinv = mat_16x16_2x2(_, _, _1{}, _1{}); + Tensor sC = select_tensor<1, 0>(mat_16x16_2x2(_, _, _1{}, _0{})); + Tensor sAinv = select_tensor<1, 0>(mat_16x16_2x2(_, _, _0{}, _0{})); + Tensor sO = mat_16x16_2x2(_, _, _1{}, _0{}); + + Tensor tOrDinv = thr_mma.partition_fragment_A(sDinv); + Tensor tOrC = thr_mma.partition_fragment_B(sC); + Tensor tOrAinv = thr_mma.partition_fragment_B(sAinv); + + Tensor tDCrDC = + partition_fragment_C(tiled_mma, select<0, 1>(TileShape{})); // output of -inv(D)C + Tensor tOrO = + partition_fragment_C(tiled_mma, select<0, 1>(TileShape{})); // output of -inv(D)C inv(A) + + Tensor tOsDinv = D_thr_copy.partition_S(sDinv); + Tensor tOrDinv_cv = D_thr_copy.retile_D(tOrDinv); + Tensor tOsC = C_thr_copy.partition_S(sC); + Tensor tOrC_cv = C_thr_copy.retile_D(tOrC); + Tensor tOsAinv = A_thr_copy.partition_S(sAinv); + Tensor tOrAinv_cv = A_thr_copy.retile_D(tOrAinv); + Tensor tOsO = O_thr_copy.partition_D(sO); + Tensor tOrO_cv = O_thr_copy.retile_S(tOrO); + + ///////////////////////////////////////////////////////////////////////////// + // -inv(D)C + copy(D_tiled_copy, tOsDinv, tOrDinv_cv); + copy(C_tiled_copy, tOsC, tOrC_cv); + + clear(tDCrDC); + gemm(tiled_mma, tOrDinv, tOrC, tDCrDC); + transform(tDCrDC, [](auto v) { return -v; }); + + ///////////////////////////////////////////////////////////////////////////// + // -inv(D)C inv(A) + Tensor tOrDC = detail::SM80::make_acc_into_op(tDCrDC, tiled_mma); + + copy(A_tiled_copy, tOsAinv, tOrAinv_cv); + clear(tOrO); + gemm(tiled_mma, tOrDC, tOrAinv, tOrO); + + auto tOrO_cv_cvt = 
make_tensor_like(tOrO_cv); + transform(tOrO_cv, tOrO_cv_cvt, [](auto v) { return Element(v); }); + copy(O_tiled_copy, tOrO_cv_cvt, tOsO); + } + + template + CUTE_DEVICE void blockwise_diagonal_inversed_32x32_to_64x64(TensorT&& mat) { + constexpr auto L = + typename std::remove_const_t>::layout_type{}; + static_assert(rank(L) == 2); + static_assert(size<0>(L) == 64); + static_assert(size<1>(L) == 64); + + static_assert(is_contiguous<0>(L) == 1 || is_contiguous<1>(L) == 1); + constexpr bool is_col_major = is_contiguous<0>(L); + + auto mat_32x32_2x2 = flat_divide(std::forward(mat), select<0, 1>(Shape<_32, _32>{})); + auto mat_16x2X16x2_2x2 = logical_divide(mat_32x32_2x2, Shape<_16, _16>{}); + + using MMA = SM80_16x8x16_F32F16F16F32_TN; + using TiledMMA1 = + decltype(make_tiled_mma(MMA{}, Layout>{}, Shape<_16, _16, _32>{})); + using TiledMMA2 = + decltype(make_tiled_mma(MMA{}, Layout>{}, Shape<_16, _32, _16>{})); + + using CopyOpD_S2R = std::conditional_t; + using CopyOpC_S2R = std::conditional_t; + using CopyOpA_S2R = std::conditional_t; + using CopyOpO_S2R = std::conditional_t; + using CopyOpO_S2R = std::conditional_t; +#ifdef CUTE_ARCH_STSM_SM90_ENABLED + using CopyOpO_R2S = std::conditional_t; +#else + using CopyOpO_R2S = UniversalCopy; +#endif + + int warp_id_in_wg = cutlass::canonical_warp_idx() - + cutlass::NumWarpsPerWarpGroup * cutlass::canonical_warp_group_idx(); + int x = warp_id_in_wg / 2; + int y = warp_id_in_wg % 2; + + int lane_id = cutlass::canonical_lane_idx(); + auto tiled_mma1 = TiledMMA1{}; + auto thr_mma1 = tiled_mma1.get_thread_slice(lane_id); + + auto tiled_mma2 = TiledMMA2{}; + auto thr_mma2 = tiled_mma2.get_thread_slice(lane_id); + + auto D_tiled_copy = make_tiled_copy_A(Copy_Atom{}, tiled_mma1); + auto C_tiled_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma1); + auto A_tiled_copy = make_tiled_copy_B(Copy_Atom{}, tiled_mma2); + auto O_tiled_s2r = make_tiled_copy_C(Copy_Atom{}, tiled_mma2); + auto O_tiled_r2s = make_tiled_copy_C(Copy_Atom{}, 
tiled_mma2); + + auto D_thr_copy = D_tiled_copy.get_thread_slice(lane_id); + auto C_thr_copy = C_tiled_copy.get_thread_slice(lane_id); + auto A_thr_copy = A_tiled_copy.get_thread_slice(lane_id); + auto O_thr_s2r = O_tiled_s2r.get_thread_slice(lane_id); + auto O_thr_r2s = O_tiled_r2s.get_thread_slice(lane_id); + + Tensor sDinv = mat_16x2X16x2_2x2(make_coord(_, y), _, _1{}, _1{}); + Tensor sC = select_tensor<1, 0>(mat_16x2X16x2_2x2(_, make_coord(_, x), _1{}, _0{})); + Tensor sAinv = + select_tensor<1, 0>(mat_16x2X16x2_2x2(make_coord(_, x), _, _0{}, _0{})); // NOTE: not y! + Tensor sO = mat_16x2X16x2_2x2(make_coord(_, y), _, _1{}, _0{}); // needs cross-warp reduction + + Tensor tOrDinv = thr_mma1.partition_fragment_A(sDinv); + Tensor tOrC = thr_mma1.partition_fragment_B(sC); + Tensor tOrAinv = thr_mma2.partition_fragment_B(sAinv); + + Tensor tDCrDC = partition_fragment_C(tiled_mma1, Shape<_16, _16>{}); // output of -inv(D)C + Tensor tOrO = partition_fragment_C(tiled_mma2, Shape<_16, _32>{}); // output of -inv(D)C inv(A) + + Tensor tOsDinv = D_thr_copy.partition_S(sDinv); + Tensor tOrDinv_cv = D_thr_copy.retile_D(tOrDinv); + Tensor tOsC = C_thr_copy.partition_S(sC); + Tensor tOrC_cv = C_thr_copy.retile_D(tOrC); + Tensor tOsAinv = A_thr_copy.partition_S(sAinv); + Tensor tOrAinv_cv = A_thr_copy.retile_D(tOrAinv); + + ///////////////////////////////////////////////////////////////////////////// + // -inv(D)C + copy(D_tiled_copy, tOsDinv, tOrDinv_cv); + copy(C_tiled_copy, tOsC, tOrC_cv); + + clear(tDCrDC); + gemm(tiled_mma1, tOrDinv, tOrC, tDCrDC); + transform(tDCrDC, [](auto v) { return -v; }); + + ///////////////////////////////////////////////////////////////////////////// + // -inv(D)C inv(A) + Tensor tOrDC = detail::SM80::make_acc_into_op(tDCrDC, tiled_mma2); + + copy(A_tiled_copy, tOsAinv, tOrAinv_cv); + clear(tOrO); + gemm(tiled_mma2, tOrDC, tOrAinv, tOrO); + + auto tOrO_cvt = make_tensor_like(tOrO); + transform(tOrO, tOrO_cvt, [](auto v) { return Element(v); }); + 
+ // ensure tOsC consumed, tOsC and tOsO are the same buffer + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + wg_sync_named_barrier_id_); + + Tensor tOsO = O_thr_r2s.partition_D(sO); + Tensor tOrO_cvt_cv = O_thr_r2s.retile_S(tOrO_cvt); + if (x == 0) { + copy(O_tiled_r2s, tOrO_cvt_cv, tOsO); + } + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + wg_sync_named_barrier_id_); + if (x == 1) { + Tensor tOrO_red = make_tensor_like(tOrO_cvt); + Tensor tOsO_s = O_thr_s2r.partition_S(sO); + Tensor tOrO_red_cv = O_thr_s2r.retile_D(tOrO_red); + copy(O_tiled_s2r, tOsO_s, tOrO_red_cv); + transform(tOrO_cvt, tOrO_red, tOrO_cvt, [](auto a, auto b) { return a + b; }); + copy(O_tiled_r2s, tOrO_cvt_cv, tOsO); + } + } + + private: + int wg_sync_named_barrier_id_; +}; + +} // namespace flat::collective diff --git a/csrc/flat/ampere/collective/flat_collective_load.hpp b/csrc/flat/ampere/collective/flat_collective_load.hpp new file mode 100644 index 0000000000..3a7f517eff --- /dev/null +++ b/csrc/flat/ampere/collective/flat_collective_load.hpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "flat/unused.hpp" + +namespace flat::collective { + +using namespace cute; + +enum class LoadKindVector { + kAlpha, + kBeta, +}; + +CUTE_HOST_DEVICE constexpr char const* to_string(LoadKindVector kind) { + if (kind == LoadKindVector::kAlpha) { + return "alpha"; + } else if (kind == LoadKindVector::kBeta) { + return "beta"; + } else { + return "unknown loadkind"; + } +} + +template +struct CollectiveLoadVector { + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + using VectorProcessor = VectorProcessor_; + + static_assert(rank_v == 2 || rank_v == 3); + + static constexpr LoadKindVector kind = kKind; + static constexpr int VectorSize = size<0>(SmemLayout{}); + + CUTE_DEVICE + CollectiveLoadVector(ElementSrc const* src, GmemLayout layout, ElementSrc oob_value, + Pipeline& pipeline, SharedStorage& storage) + : src_(src), + src_layout_(layout), + src_oob_value_(oob_value), + pipeline_(pipeline), + storage_(storage) {} + + template + CUTE_DEVICE auto partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, + WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + + Tensor g = [&] { + auto head_idx = work_desc.o_head_idx(); // num_o_heads == num_sab_heads + DPRINTF0_W("slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), + work_desc.seq_idx, head_idx, work_desc.tok_offset); + Tensor m_varlen_head = make_tensor(make_gmem_ptr(src_), src_layout_); + + Tensor m_varlen = m_varlen_head(_, head_idx); // slice into current head_idx + Tensor m_offset = domain_offset(make_coord(work_desc.tok_offset), + m_varlen); // offset to start of the current sequence + Tensor g_full = flat_divide(m_offset, BlkSeqQ); // (blk, iter_blk) + return g_full; + }(); + // (blk, pipe) or (blk, pipe, N), N for feature rich preprocess, 
data will be stored at 0 + Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayout{}); + + auto thr_layout = Layout<_32>{}; + auto val_layout = Layout<_1>{}; + auto tiled_copy = + make_tiled_copy(Copy_Atom, ElementDst>{}, thr_layout, val_layout); + auto thr_copy = tiled_copy.get_thread_slice(cutlass::canonical_lane_idx()); + + auto coord = thr_copy.partition_S(make_identity_tensor(Shape, _1>{})); + auto len_of_last_blk = work_desc.seq_len - (ceil_div(work_desc.seq_len, BlkSeqQ) - 1) * BlkSeqQ; + + // auto mask = FunctionPredTensor([coord, len_of_last_blk](auto frag_coord) { + // auto coord_in_blk = get<0>(coord(frag_coord)); + // return coord_in_blk < len_of_last_blk; + // }); + // NOTE: old FunctionPredTensor is easier to understand, cute::lazy::transform means + // coord(runtime_input) and then transfrom with the given lambda + auto mask = cute::lazy::transform(coord, [len_of_last_blk](auto const& c) { + auto coord_in_blk = get<0>(c); + return coord_in_blk < len_of_last_blk; + }); + + auto src = thr_copy.partition_S(g); // (cpy, iter_cpy, iter_blk) + auto dst = thr_copy.partition_D(s); // (cpy, iter_cpy, pipe) + + return make_tuple(src, dst, mask); + } + + template + CUTE_DEVICE void step(SrcDst const& src_dst, int src_iter, PipelineState& dst_pipe, int num_iters, + VectorProcessor processor = {}) { + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + auto regs = make_fragment_like(take<0, 2>(shape(dst))); + if constexpr (!IsTail) { + copy(src(_, _, src_iter), regs); + } else { + auto mask = get<2>(src_dst); + fill(regs, src_oob_value_); + copy_if(mask, src(_, _, src_iter), regs); + } + + int dst_pipe_idx = dst_pipe.index(); + + DPRINTF0_WG("%s pipeline.producer_acquire smem_pipe_write:%d\n", to_string(kind), dst_pipe_idx); + pipeline_.producer_acquire(dst_pipe); + cutlass::arch::fence_view_async_shared(); + + if constexpr (rank_v == 3) { + copy(regs, dst(_, _, _0{}, dst_pipe_idx)); + } else { + copy(regs, dst(_, _, dst_pipe_idx)); + } + + 
Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayout{}); + if constexpr (!std::is_same_v) { + if constexpr (rank_v == 3) { + processor(s(_, _, dst_pipe_idx)); + } else { + processor(s(_, dst_pipe_idx)); + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_.producer_commit(dst_pipe); + ++dst_pipe; + } + + private: + ElementSrc const* src_; + GmemLayout src_layout_; // in (packed_seq, H) coordinate + ElementSrc src_oob_value_; + Pipeline& pipeline_; + SharedStorage& storage_; +}; + +} // namespace flat::collective diff --git a/csrc/flat/common.hpp b/csrc/flat/common.hpp new file mode 100644 index 0000000000..91939e6085 --- /dev/null +++ b/csrc/flat/common.hpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include +#include +#include + +#include "debug.hpp" + +#define FLAT_UNUSED_PARAMETER(x) (void)x + +#define CHECK(expr, msg) \ + do { \ + if (!(expr)) { \ + std::string buffer(1024, '\0'); \ + sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \ + throw std::runtime_error(buffer.c_str()); \ + } \ + } while (0) + +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t err = (expr); \ + if (err != cudaSuccess) { \ + std::string buffer(1024, '\0'); \ + sprintf(buffer.data(), "CUDA Error: %s, Code: %d at %s:%d\n", cudaGetErrorName(err), err, \ + __FILE__, __LINE__); \ + throw std::runtime_error(buffer.c_str()); \ + } \ + } while (0) diff --git a/csrc/flat/cute_ext.hpp b/csrc/flat/cute_ext.hpp new file mode 100644 index 0000000000..91dd0dc2ab --- /dev/null +++ b/csrc/flat/cute_ext.hpp @@ -0,0 +1,52 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/detail/layout.hpp" + +namespace flat { + +using namespace cute; + +template +__forceinline__ __host__ __device__ constexpr auto select_layout(Layout&& l) { + if constexpr (is_composed_layout::value) { + return make_composed_layout(l.layout_a(), l.offset(), select(l.layout_b())); + } else { + return select(l); + } +} + +template +__forceinline__ __host__ __device__ constexpr auto select_tensor(Tensor&& t) { + if constexpr (is_composed_layout::value) { + return make_tensor( + std::forward(t).data(), + make_composed_layout(std::forward(t).layout().layout_a(), + std::forward(t).layout().offset(), + select(std::forward(t).layout().layout_b()))); + } else { + return make_tensor(std::forward(t).data(), select(t.layout())); + } +} + +template +CUTE_DEVICE constexpr size_t alignment_for_swizzle(Layout&& layout) { + return cutlass::detail::alignment_for_swizzle(std::forward(layout)); +} + +} // namespace flat diff --git a/csrc/flat/debug.hpp b/csrc/flat/debug.hpp new file mode 100644 index 0000000000..b3f27b2c5c --- /dev/null +++ b/csrc/flat/debug.hpp @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cute/config.hpp" + +#if DEBUG_PIPE +#define PIPE_DEBUG_PRINTF(fmt, ...) \ + if (threadIdx.x == 0) printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#else +#define PIPE_DEBUG_PRINTF(...) 
+#endif + +#ifndef FLAT_DEBUG_PRINT +#define FLAT_DEBUG_PRINT 0 +#endif + +#if FLAT_DEBUG_PRINT +#define IS_PRINT_BLOCK cute::block(1) +#define DPRINTF(fmt, ...) \ + if (IS_PRINT_BLOCK) printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#define DPRINTF0(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x == 0) printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#define DPRINTF_W(fmt, ...) \ + if (IS_PRINT_BLOCK) \ + printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \ + threadIdx.x, ##__VA_ARGS__) +#define DPRINTF0_W(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x % 32 == 0) \ + printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \ + threadIdx.x, ##__VA_ARGS__) +#define DPRINTF_WG(fmt, ...) \ + if (IS_PRINT_BLOCK) \ + printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \ + threadIdx.x, ##__VA_ARGS__) +#define DPRINTF0_WG(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x % 128 == 0) \ + printf("%s:%d [WG%d][W%d][T%-3d] " fmt, __FILE__, __LINE__, threadIdx.x / 128, threadIdx.x / 32, \ + threadIdx.x, ##__VA_ARGS__) +#else +#define DPRINTF(...) +#define DPRINTF0(...) +#define DPRINTF_W(...) +#define DPRINTF0_W(...) +#define DPRINTF_WG(...) +#define DPRINTF0_WG(...) 
+#endif + +#if FLAT_DEBUG_PRINT +#define DPRINT_TMA_DESC(tma_dess_addr) \ + do { \ + auto p = reinterpret_cast(tma_dess_addr); \ + DPRINTF( \ + "\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n", \ + p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], p[9], p[10], p[11], p[12], p[13], \ + p[14], p[15], p[16], p[17], p[18], p[19], p[20], p[21], p[22], p[23], p[24], p[25], p[26], \ + p[27], p[28], p[29], p[30], p[31]); \ + } while (0) +#else +#define DPRINT_TMA_DESC(tma_dess_addr) +#endif diff --git a/csrc/flat/hopper/collective/flat_collective_load.hpp b/csrc/flat/hopper/collective/flat_collective_load.hpp new file mode 100644 index 0000000000..b587a648be --- /dev/null +++ b/csrc/flat/hopper/collective/flat_collective_load.hpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" + +namespace flat::collective { + +using namespace cute; + +enum class LoadKind { + kQ, + kK, + kV, +}; + +CUTE_HOST_DEVICE constexpr char const* to_string(LoadKind kind) { + if (kind == LoadKind::kQ) { + return "Q"; + } else if (kind == LoadKind::kK) { + return "K"; + } else if (kind == LoadKind::kV) { + return "V"; + } else { + return "unknown loadkind"; + } +} + +template +struct CollectiveLoadTma { + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + static constexpr LoadKind kind = kKind; + TMA const& tma_load; + Pipeline& pipeline; + SharedStorage& storage; + + CUTE_DEVICE + CollectiveLoadTma(TMA const& tma_load, Pipeline& pipeline, SharedStorage& storage) + : tma_load(tma_load), pipeline(pipeline), storage(storage) {} + + template + CUTE_DEVICE auto partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, + WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + constexpr auto BlkSeqKV = decltype(get<1>(tile_shape))::value; + constexpr auto HeadSize = decltype(get<2>(tile_shape))::value; + + Tensor g = [&] { + if constexpr (kind == LoadKind::kQ) { + DPRINTF0_W("slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), + work_desc.seq_idx, work_desc.q_head_idx(), work_desc.tok_offset); + Tensor m_varlen_head = tma_load.get_tma_tensor( + make_shape(problem_size.total_seqlen, problem_size.head_size, + problem_size.num_q_heads)); // global view to the packed varlen sequence + Tensor m_varlen = + m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx + Tensor m_offset = domain_offset(make_coord(work_desc.tok_offset, _0{}), + m_varlen); // offset to start of the current sequence + Tensor g_full = local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), + make_coord(_, _0{})); // (blk, d, 
iter_blk) + return g_full; + } else { + auto num_heads = + (kind == LoadKind::kK ? problem_size.num_k_heads : problem_size.num_v_heads); + auto head_idx = (kind == LoadKind::kK ? work_desc.k_head_idx() : work_desc.v_head_idx()); + DPRINTF0_W("slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", to_string(kind), + work_desc.seq_idx, head_idx, work_desc.tok_offset); + Tensor m_varlen_head = tma_load.get_tma_tensor( + make_shape(problem_size.head_size, problem_size.total_seqlen, + num_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, head_idx); // slice into current head_idx + Tensor m_offset = domain_offset(make_coord(_0{}, work_desc.tok_offset), + m_varlen); // offset to start of the current sequence + Tensor g_full = local_tile(m_offset, make_tile(HeadSize, BlkSeqKV), + make_coord(_0{}, _)); // (d, blk, iter_blk) + return g_full; + } + }(); + Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{}); + + auto block_tma = tma_load.get_slice(_0{}); // do not support cluster + return make_tuple(block_tma.partition_S(g), block_tma.partition_D(s)); + } + + template + CUTE_DEVICE void step(SrcDst const& src_dst, int src_iter, PipelineState& dst_pipe, + uint32_t lane_predicate) { + if (lane_predicate == 1) { + DPRINTF_WG("%s pipeline.producer_acquire smem_pipe_write:%d\n", to_string(kind), + dst_pipe.index()); + if constexpr (kAcquireBarrier) { + pipeline.producer_acquire(dst_pipe); + } + using BarrierType = typename Pipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(dst_pipe); + + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + copy(tma_load.with(*tma_barrier), src(_, _, _, src_iter), dst(_, _, _, dst_pipe.index())); + ++dst_pipe; + } + } +}; + +} // namespace flat::collective diff --git a/csrc/flat/hopper/collective/flat_collective_store.hpp b/csrc/flat/hopper/collective/flat_collective_store.hpp new file mode 100644 index 0000000000..8cca5b4fba --- 
/dev/null +++ b/csrc/flat/hopper/collective/flat_collective_store.hpp @@ -0,0 +1,274 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "flat/cute_ext.hpp" + +namespace flat::collective { + +using namespace cute; + +/* +NOTE: what we need is as follows + + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; using CollectiveStoreO = typename +cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, +cutlass::arch::OpClassTensorOp, TileShapeO1, ClusterShape, +cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulatorO, ElementAccumulatorO, void, +LayoutO, Alignment, // C, not exists ElementO, +decltype(select<1,0,2>(LayoutO{})), Alignment, // D + cutlass::epilogue::TmaWarpSpecializedCooperative, DefaultOperation>::CollectiveOp; + +but unfortunately the required type alias is only useful for our purpose is private so we roll out +our own. 
+*/ + +CUTE_DEVICE uint32_t smid() { +#ifdef __CUDA_ARCH__ + uint32_t virtual_smid; + asm("mov.u32 %0, %%smid;" : "=r"(virtual_smid)); + return virtual_smid; +#else + return 0; +#endif +} + +template +struct CollectiveStoreTma { + static_assert(size_v == 1); + using TileShape_MNK = TileShape_MNK_; + using TileShape_MN = decltype(select<0, 1>( + TileShape_MNK{})); // Collective work on TileShape_MN, it is also the OutputTile + using SizeM = decltype(get<0>(TileShape_MNK{})); // head_size + using SizeN = decltype(get<1>(TileShape_MNK{})); // seqlen + + constexpr static bool is_m_major_O = cutlass::epilogue::collective::detail::is_m_major(); + +#if 0 + // NOTE: the following derived layout is a bit slower than the manual one, will evaluate it later + using SmemLayoutAtom = decltype(cutlass::epilogue::collective::detail::sm90_get_epilogue_smem_swizzle_layout_atom< + StrideO, ElementO, TileShape_MN>()); +#else + static_assert(sizeof(SmemElementO) == 2); + using SmemLayoutAtom = GMMA::Layout_MN_SW32_Atom; +#endif + + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtom{}, make_shape(SizeM{}, SizeN{}, Int{}), + cute::conditional_t, Step<_1, _2, _3>>{})); + + constexpr static uint32_t TmaTransactionBytes = + (size(take<0, 2>(SmemLayoutO{})) * static_cast(sizeof_bits::value)) / + 8; + + using CopyOpR2S = + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator< + StrideO, ElementO, TileShape_MN>()); + using CopyAtomR2S = Copy_Atom; + + using CopyOpS2G = SM90_TMA_STORE; + + using SharedStorage = cute::array_aligned, + alignment_for_swizzle(SmemLayoutO{})>; + using Pipeline = cutlass::PipelineAsync; // NOT PipelineTmaStore! 
+ using PipelineState = cutlass::PipelineState; + + struct Arguments { + ElementO* ptr_O; + StrideO dO; + }; + + struct Params { + using TMA_O = decltype(make_tma_copy(CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideO{}, int32_t(0)), StrideO{}), + take<0, 2>(SmemLayoutO{}), TileShape_MN{}, _1{})); + + TMA_O tma_store_o; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + void* tensormaps; + }; + using TMA = typename Params::TMA_O; + + CUTE_DEVICE + CollectiveStoreTma(TMA const& tma_store, Pipeline& pipeline, SharedStorage& storage, + void* tensormaps) + : tma_store_(tma_store), pipeline_(pipeline), storage_(storage), tensormaps_(tensormaps) {} + + template + static Params to_underlying_arguments(ProblemSize const& problem_size, Arguments const& args, + void* workspace) { + auto problem_size_MNKL = append<4>(problem_size, 1); + auto [M, N, K, L] = problem_size_MNKL; + + Tensor tensor_o = + make_tensor(make_gmem_ptr(args.ptr_O), make_layout(make_shape(M, N, L), args.dO)); + TMA tma_store_o = + make_tma_copy_C_sm90(CopyOpS2G{}, tensor_o, take<0, 2>(SmemLayoutO{}), TileShape_MN{}); + + return { + .tma_store_o = tma_store_o, + .tma_transaction_bytes = TmaTransactionBytes, + .tensormaps = workspace, + }; + } + + static size_t get_workspace_size(/*Arguments const& args,*/ int sm_count) { + // only use additional TMA desc for output tail tiles + size_t num_bytes = sizeof(cute::TmaDescriptor) * sm_count; + DPRINTF("workspace num_bytes:%zu\n", num_bytes); + return num_bytes; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, + /*Arguments const& args,*/ void* workspace, + cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + CUTE_DEVICE static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTE_DEVICE auto partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, + 
WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + constexpr auto HeadSize = decltype(get<2>(tile_shape))::value; + + Tensor g = [&] { + DPRINTF0_W("slice view GMEM O: seq_idx:%d head_idx:%d tok_offset:%lld\n", work_desc.seq_idx, + work_desc.o_head_idx(), work_desc.tok_offset); + Tensor m_varlen_head = tma_store_.get_tma_tensor( + make_shape(problem_size.head_size, problem_size.total_seqlen, + problem_size.num_o_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx + Tensor m_offset = domain_offset(make_coord(_0{}, work_desc.tok_offset), + m_varlen); // offset to start of the current sequence + Tensor g_full = local_tile(m_offset, make_tile(HeadSize, BlkSeqQ), + make_coord(_0{}, _)); // (d, blk, iter_blk) + return g_full; + }(); + Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayoutO{}); + + auto block_tma = tma_store_.get_slice(_0{}); // do not support cluster + return make_tuple(block_tma.partition_S(s), block_tma.partition_D(g)); + } + + template + CUTE_DEVICE static bool can_process(ProblemSize const& problem_size, WorkDesc const& work_desc, + int blk, int num_blocks) { + if (blk < num_blocks - 1) { + // intermediate full tiles, always use TMA + return true; + } else if (work_desc.seq_len % SizeN{} == 0 || work_desc.seq_idx == problem_size.num_seqs - 1) { + // 1. last tile but full, also use TMA + // 2. 
last tile but last seq, oob can be handled by TMA + return true; + } else { + return false; + } + } + + template + CUTE_DEVICE void step(ProblemSize const& problem_size, WorkDesc const& work_desc, + SrcDst const& src_dst, PipelineState& src_pipe, int dst_iter, int num_iters, + uint32_t lane_predicate) { + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + if (dst_iter == 0) { + bool can_process_tail = can_process(problem_size, work_desc, num_iters - 1, num_iters); + if (!can_process_tail) { + create_tensormap_for_tail(work_desc, lane_predicate); + } + } + + DPRINTF0_WG("pipeline.producer_acquire smem_pipe_read:%d\n", src_pipe.index()); + if constexpr (kAcquireBarrier) { + pipeline_.consumer_wait(src_pipe); + } + + if (can_process(problem_size, work_desc, dst_iter, num_iters)) { + DPRINTF0_W("store src_pipe:%d -> blk:%d\n", src_pipe.index(), dst_iter); + if (lane_predicate == 1) { + copy(tma_store_, src(_, _, _, src_pipe.index()), dst(_, _, _, dst_iter)); + } + } else { + cute::TmaDescriptor* tensormap = acquire_tensormap_for_tail(); + DPRINTF0_W("store tail with tensormap:%p src_pipe:%d -> blk:%d\n", tensormap, + src_pipe.index(), dst_iter); + if (lane_predicate == 1) { + copy(tma_store_.with(tensormap), src(_, _, _, src_pipe.index()), dst(_, _, _, dst_iter)); + } + } + + if constexpr (kAcquireBarrier) { + pipeline_.consumer_release(src_pipe); + } + ++src_pipe; + } + + template + CUTE_DEVICE void create_tensormap_for_tail(WorkDesc const& work_desc, uint32_t lane_predicate) { + namespace ptx = cuda::ptx; + constexpr int num_of_16B = sizeof(cute::TmaDescriptor) / sizeof(uint128_t); + + cute::TmaDescriptor* tensormap = static_cast(tensormaps_) + smid(); + + auto lane_idx = cutlass::canonical_lane_idx(); + if (lane_idx < num_of_16B) { + auto src = reinterpret_cast(tma_store_.get_tma_descriptor()); + auto dst = reinterpret_cast(tensormap); + + dst[lane_idx] = src[lane_idx]; + } + __syncwarp(); + + if (lane_predicate == 1) { + uint32_t new_total_seqlen = 
work_desc.tok_offset + work_desc.seq_len; + ptx::tensormap_replace_global_dim(ptx::space_global, tensormap, /*ord=*/ptx::n32_t<1>{}, + new_total_seqlen); + } + __syncwarp(); + + ptx::fence_proxy_tensormap_generic(ptx::sem_release, ptx::scope_cta); + } + + CUTE_DEVICE cute::TmaDescriptor* acquire_tensormap_for_tail() { + namespace ptx = cuda::ptx; + cute::TmaDescriptor* tensormap = static_cast(tensormaps_) + smid(); + ptx::fence_proxy_tensormap_generic(ptx::sem_acquire, ptx::scope_cta, tensormap, + /*size=*/ptx::n32_t<128>{}); + return tensormap; + } + + private: + TMA const& tma_store_; + Pipeline& pipeline_; + SharedStorage& storage_; + void* tensormaps_; +}; + +} // namespace flat::collective diff --git a/csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp b/csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp new file mode 100644 index 0000000000..49f499511a --- /dev/null +++ b/csrc/flat/hopper/collective/flat_collective_tma_warpspecialized_delta_rule.hpp @@ -0,0 +1,1239 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "../../cute_ext.hpp" +#include "../../math_order_barrier.hpp" +#include "../../unused.hpp" +#include "../collective/flat_collective_load.hpp" +#include "../collective/flat_collective_store.hpp" +#include "../collective/flat_common.hpp" +#include "../collective/flat_named_barriers.hpp" +#include "../kernel/flat_options.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "flat/ampere/collective/flat_collective_inverse.hpp" +#include "flat/ampere/collective/flat_collective_load.hpp" + +// #define INLINE_LAMBDA [[gnu::always_inline]] +#define INLINE_LAMBDA __attribute__((always_inline)) +// #define INLINE_LAMBDA [[msvc::forceinline]] + +#define WORKAROUND_WGMMA_PERFORMANCE_LOSS() \ + if (thread_idx > 8192) { \ + __syncwarp(); \ + } + +namespace flat::collective { + +struct DeltaRuleNamedBarriers : FlatSharedNamedBarriers { + static constexpr int KKLaunched = FlatSharedNamedBarriers::NumBarriersUsed + 0; + static constexpr int AuxMath = FlatSharedNamedBarriers::NumBarriersUsed + 1; +}; + +using namespace cute; +using flat::kernel::find_option_t; +using flat::kernel::Tag; + +template +struct FlatMainloopTmaWarpSpecializedDeltaRule { + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorO = ElementAccumulatorQK; + using ElementAccumulatorKV = ElementAccumulatorKV_; + using ElementO = Element; + + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; // (seqlen_q, d, h) + using LayoutK = LayoutK_; // (seqlen_k, d, h) + using LayoutV = LayoutV_; // (seqlen_k, d, h) + using LayoutO = LayoutO_; // (seqlen_k, d, h) + + // Options + static constexpr bool kIsPersistent = + find_option_t::value; + + static constexpr bool kInitStateFromInput = + find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumStateMmaWarpGroups = 2; + static constexpr int NumAuxMmaWarpGroups = 1; + + static constexpr int 
StageCountQ = find_option_t, Options>::value; + static constexpr int StageCountK = find_option_t, Options>::value; + static constexpr int StageCountV = find_option_t, Options>::value; + + static constexpr int NeedsAlpha = + find_option_t::value; + static constexpr int NeedsBeta = find_option_t::value; + + static constexpr int NeedsDecay = + find_option_t::value; + static_assert(!NeedsDecay, "DeltaRule does not supports decay"); + + static constexpr int NumLoadThreads = NumLoadWarpGroups * 128; + static constexpr int NumStateMmaThreads = NumStateMmaWarpGroups * 128; + static constexpr int NumAuxMmaThreads = NumAuxMmaWarpGroups * 128; + + static constexpr uint32_t OrderedBarrierId0 = + uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + static constexpr uint32_t OrderedBarrierId1 = + uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier1); + + using OrderedMathBarriers = std::conditional_t< + NumStateMmaWarpGroups == 2, + OrderedNamedBarriers, + OrderedNamedBarriers>; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesK = cutlass::gemm::collective::StageCount; + using StagesV = cutlass::gemm::collective::StageCount; + using StagesO = cutlass::gemm::collective::StageCount<2>; + using ClusterShape = Shape<_1, _1, _1>; + + using StagesQK = cutlass::gemm::collective::StageCount<2>; + using StagesKK = cutlass::gemm::collective::StageCount<2>; + + using StagesAlphaBeta = cutlass::gemm::collective::StageCount<5>; + + static constexpr int Alignment = 16 / sizeof(Element); + + static constexpr auto BlkSeqQ = get<0>(TileShape{}); // Blk_Q + static constexpr auto BlkSeqKV = get<1>(TileShape{}); // Blk_K/V + static constexpr auto HeadSize = get<2>(TileShape{}); // D (Dq, Dk, Dv all equal) + static constexpr auto HeadSizeQK = HeadSize; + static constexpr auto HeadSizeV = HeadSize; + + using TileShapeQK = decltype(make_shape(BlkSeqQ, BlkSeqKV, HeadSizeQK)); + using TileShapeKK = decltype(make_shape(BlkSeqKV, BlkSeqKV, HeadSizeQK)); + 
using TileShapeKV = decltype(make_shape(HeadSizeV, HeadSizeQK, BlkSeqKV)); + static_assert(std::is_same_v); + + using TileShapeO2 = decltype(make_shape(HeadSizeV, BlkSeqQ, BlkSeqKV)); + using TileShapeO1 = decltype(make_shape(HeadSizeV, BlkSeqQ, HeadSizeQK)); + + static_assert(BlkSeqQ % 64 == 0); + static_assert(BlkSeqQ == 64 || BlkSeqQ == 128); + static constexpr bool IsQKCooperative = BlkSeqQ == 128; + static constexpr bool IsKKCooperative = IsQKCooperative; + + using DummyStages = cutlass::gemm::collective::StageCount<2>; + ; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, LayoutQ, Alignment, Element, + LayoutK, Alignment, ElementAccumulatorQK, TileShapeQK, ClusterShape, DummyStages, + std::conditional_t>::CollectiveOp; + + using CollectiveMmaKV_G2S = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, + decltype(select<1, 0, 2>(LayoutV{})), Alignment, // direct TMA copy for GMEM -> SMEM + Element, decltype(select<1, 0, 2>(LayoutK{})), Alignment, ElementAccumulatorKV, TileShapeKV, + ClusterShape, DummyStages, cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // raw layout for copy + using SmemLayoutQ_SD = + decltype(unstage_smem_layout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK_DS = decltype(unstage_smem_layout(typename CollectiveMmaKV_G2S::SmemLayoutB{}, + Int{})); + using SmemLayoutV_DS = decltype(unstage_smem_layout(typename CollectiveMmaKV_G2S::SmemLayoutA{}, + Int{})); + + using RefLayoutV = decltype(make_layout(select<0, 2>(TileShapeKV{}), LayoutRight{})); + using CollectiveMmaKV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, RefLayoutV, + Alignment, // needs a S2R transposition for MMA + Element, decltype(select<1, 0, 2>(LayoutK{})), Alignment, ElementAccumulatorKV, 
TileShapeKV, + ClusterShape, DummyStages, cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using RefLayoutKV = + decltype(make_layout(select<0, 1>(TileShapeKV{}), LayoutRight{})); // (dv, dk) + using CollectiveMmaO1 = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, RefLayoutKV, Alignment, Element, + LayoutQ, Alignment, ElementAccumulatorO, TileShapeO1, ClusterShape, DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // (blk_q,blk_k) to align with O2 mma, LayoutRight to align with QK mma output + using DesiredLayoutQK = decltype(make_layout(select<0, 1>(TileShapeQK{}), LayoutRight{})); + using CollectiveMmaO2 = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, RefLayoutV, Alignment, Element, + DesiredLayoutQK, Alignment, ElementAccumulatorO, TileShapeO2, ClusterShape, DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; // Q@K^t + using TiledMmaKV = decltype(convert_to_gmma_rs(typename CollectiveMmaKV::TiledMma{})); + using TiledMmaO1 = decltype(convert_to_gmma_rs(typename CollectiveMmaO1::TiledMma{})); + using TiledMmaO2 = decltype(convert_to_gmma_rs(typename CollectiveMmaO2::TiledMma{})); + + static constexpr int TiledMmaQKNumThreads = size(TiledMmaQK{}); + static_assert(size(TiledMmaQK{}) == NumAuxMmaThreads); + + static_assert(size(TiledMmaKV{}) == NumStateMmaThreads); + static_assert(size(TiledMmaO1{}) == NumStateMmaThreads); + static_assert(size(TiledMmaO2{}) == NumStateMmaThreads); + + using CollectiveStoreO = + CollectiveStoreTma(LayoutO{})), StagesO::value>; + + // layout for compute + using QKSmemLayoutQ = SmemLayoutQ_SD; + using QKSmemLayoutK = decltype(select_layout<1, 0, 2>(SmemLayoutK_DS{})); + + using KVSmemLayoutK = SmemLayoutK_DS; + using KVSmemLayoutV = SmemLayoutV_DS; 
+ + // layout for compute output + using SmemLayoutQK = decltype(tile_to_shape( + GMMA::Layout_K_INTER_Atom{}, + flatten(make_shape(select<0, 1>(TileShapeQK{}), Int{})), + Step<_1, _2, _3>{})); + using SmemLayoutO = typename CollectiveStoreO::SmemLayoutO; + + using SmemLayoutKK = decltype(tile_to_shape( + GMMA::Layout_K_INTER_Atom{}, + flatten(make_shape(select<0, 1>(TileShapeQK{}), Int{})), + Step<_1, _2, _3>{})); + + using InverseType = cutlass::half_t; + using CollectiveInverse = flat::collective::CollectiveInverse; + + using ElementAccumulatorSK = float; + using TileShapeSK = decltype(make_shape(HeadSizeV, BlkSeqKV, HeadSizeQK)); + using CollectiveMmaSK = + typename cutlass::gemm::collective::CollectiveBuilder< // basically the same as O1 + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, RefLayoutKV, Alignment, + Element, LayoutK, Alignment, ElementAccumulatorSK, TileShapeSK, ClusterShape, DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using ElementAccumulatorNewV = float; + using TileShapeNewV = decltype(make_shape(HeadSizeV, BlkSeqKV, BlkSeqKV)); + using RefLayoutSK = + decltype(make_layout(select<0, 2>(TileShapeNewV{}), LayoutRight{})); // (dv, Blk) + using DesiredLayoutKK = decltype(make_layout(select<1, 2>(TileShapeNewV{}), LayoutRight{})); // + using CollectiveMmaNewV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Element, RefLayoutSK, Alignment, Element, + DesiredLayoutKK, Alignment, ElementAccumulatorKV, TileShapeNewV, ClusterShape, DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // FIXME: K@K^t are not exactly the same as Q@K^t, but similar enough + using TiledMmaKK = + typename CollectiveMmaQK::TiledMma; // T = inv(I + strict_lower_triangular(K@K^t)) + using TiledMmaSK = + decltype(convert_to_gmma_rs(typename CollectiveMmaSK::TiledMma{})); // ?? 
= -S@K^t + V^t + using TiledMmaNewV = + decltype(convert_to_gmma_rs(typename CollectiveMmaNewV::TiledMma{})); // NewV = ??@T^t + + static constexpr int TiledMmaKKNumThreads = size(TiledMmaKK{}); + static_assert(size(TiledMmaKK{}) == NumAuxMmaThreads); + + using GmemStrideAlphaBeta = Stride; + using GmemLayoutAlphaBeta = Layout, GmemStrideAlphaBeta>; // (seq, head) + + // (blk, pipe, cumsum_log/cumprod), + // 0 for cumsum(log(alpha)) aka log(cumprod(alpha)) + // 1 for cumprod(alpha) + // 2 for cumprod(alpha) * scale + using AlphaCumSumLogIdx = _0; + using AlphaCumProdIdx = _1; + using AlphaCumProdScaleIdx = _2; + + using SmemLayoutAlpha = + decltype(make_layout(make_shape(BlkSeqQ, Int<3>{}, Int{}))); + using SmemLayoutBeta = decltype(make_layout(make_shape(BlkSeqQ, Int{}))); + + using MainloopQPipeline = cutlass::PipelineTmaAsync; + using MainloopKPipeline = cutlass::PipelineTmaAsync; + using MainloopVPipeline = cutlass::PipelineTmaAsync; + using MainloopOPipeline = typename CollectiveStoreO::Pipeline; + + using MainloopQKPipeline = cutlass::PipelineAsync; + using MainloopKKPipeline = cutlass::PipelineAsync; + + using MainloopAlphaPipeline = + std::conditional_t, Unused>; + using MainloopBetaPipeline = + std::conditional_t, Unused>; + + using QPipelineState = typename cutlass::PipelineState; + using KPipelineState = typename cutlass::PipelineState; + using VPipelineState = typename cutlass::PipelineState; + using OPipelineState = typename CollectiveStoreO::PipelineState; + + using QKPipelineState = cutlass::PipelineState; + using KKPipelineState = cutlass::PipelineState; + + using AlphaPipelineState = + std::conditional_t, Unused>; + using BetaPipelineState = + std::conditional_t, Unused>; + + struct AlphaProcessor { + CUTE_DEVICE + AlphaProcessor(float scale) : scale_(scale) {} + + template + CUTE_DEVICE void operator()(T&& vecs) { + constexpr int WarpSize = cutlass::NumThreadsPerWarp; + int lane_id = cutlass::canonical_lane_idx(); + + Tensor vecs_32 = flat_divide( 
+ std::forward(vecs), + make_tile(Int{})); // ((32), iter, cumsum_log/cumprod/cumprod_scale) + Tensor vec_cumsum_log = vecs_32(make_coord(_), _, AlphaCumSumLogIdx{}); + Tensor vec_cumprod = vecs_32(make_coord(_), _, AlphaCumProdIdx{}); + Tensor vec_cumprod_s = vecs_32(make_coord(_), _, AlphaCumProdScaleIdx{}); // cumprod * scale + Tensor frag = make_tensor(size<1>(vec_cumprod)); + + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + frag(iter) = log2f(vec_cumsum_log(lane_id, iter) + 1e-10f); + } + + CUTE_UNROLL + for (int offset = 1; offset < WarpSize; offset *= 2) { + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + auto v = __shfl_up_sync(0xFFFFFFFF, frag(iter), offset); + if (lane_id >= offset) { + frag(iter) += v; + } + } + } + + float sum = 0.0f; + CUTE_UNROLL + for (int iter = 1; iter < size(frag); ++iter) { + sum += __shfl_sync(0xFFFFFFFF, frag(iter - 1), 31); + frag(iter) += sum; + } + + CUTE_UNROLL + for (int iter = 0; iter < size(frag); ++iter) { + vec_cumsum_log(lane_id, iter) = frag(iter); + float cumprod = exp2f(frag(iter)); + vec_cumprod(lane_id, iter) = cumprod; + vec_cumprod_s(lane_id, iter) = cumprod * scale_; + } + } + + float scale_ = 1.0f; + }; + + using BetaProcessor = Unused; + // struct BetaProcessor { + // template + // CUTE_DEVICE + // void operator()(T&& vec) { + // int lane_id = cutlass::canonical_lane_idx(); + // int warp_size = cutlass::NumThreadsPerWarp; + // for (int i = lane_id; i < size(vec); i += warp_size) { + // auto val = vec(i); + // val = max(val, 1e-10f); // clamp due to fusion with IKK before matrix inverse + // vec(i) = 1.0f / val; + // } + // } + // }; + + static constexpr int LoadQBytes = size(QKSmemLayoutQ{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadKBytes = size(KVSmemLayoutK{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadVBytes = size(KVSmemLayoutV{}(_, _, _0{})) * sizeof(Element); + static constexpr int StoreOBytes = CollectiveStoreO::TmaTransactionBytes; + + 
using SharedStorageO = typename CollectiveStoreO::SharedStorage; + + struct SharedStorage { + alignas(alignment_for_swizzle( + QKSmemLayoutQ{})) cute::array_aligned> smem_q; + alignas(alignment_for_swizzle( + KVSmemLayoutK{})) cute::array_aligned> smem_k; + alignas(alignment_for_swizzle( + KVSmemLayoutV{})) cute::array_aligned> smem_v; + alignas(alignment_for_swizzle( + SmemLayoutQK{})) cute::array_aligned> smem_qk; + alignas(alignment_for_swizzle( + SmemLayoutKK{})) cute::array_aligned> smem_kk; + + SharedStorageO smem_o; + // TODO: make optional + cute::array_aligned> smem_beta; + cute::array_aligned> smem_alpha; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaKV_G2S::Params::TMA_B; + using TMA_V = typename CollectiveMmaKV_G2S::Params::TMA_A; + using TMA_O = typename CollectiveStoreO::Params::TMA_O; + + using LoadQ = CollectiveLoadTma; + using LoadK = CollectiveLoadTma; + using LoadV = CollectiveLoadTma; + + using LoadAlpha = + CollectiveLoadVector; + using LoadBeta = CollectiveLoadVector; + + struct Arguments { // clang-format off + Element const* ptr_Q; LayoutQ dQ; + Element const* ptr_K; LayoutK dK; + Element const* ptr_V; LayoutV dV; + Element* ptr_O; LayoutO dO; + float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{} + float const* ptr_input_state; + float scale; + float const* alpha_ptr; GmemStrideAlphaBeta alpha_stride; + float const* beta_ptr; GmemStrideAlphaBeta beta_stride; + }; // clang-format on + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_O tma_store_o; + void* tensormaps; + float scale; + + float* ptr_output_state; + float const* ptr_input_state; + + float const* alpha_ptr; + GmemLayoutAlphaBeta alpha_layout; + float const* beta_ptr; + GmemLayoutAlphaBeta beta_layout; + }; + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + auto ratio = problem_size.num_q_heads > 
problem_size.num_v_heads + ? problem_size.num_q_heads / problem_size.num_v_heads + : problem_size.num_v_heads / problem_size.num_q_heads; + + constexpr bool IsGVAEnabled = find_option_t::value; + + bool is_gqa_like = (problem_size.num_k_heads == problem_size.num_v_heads) && + (problem_size.num_q_heads == ratio * problem_size.num_k_heads) && + (problem_size.num_q_heads == ratio * problem_size.num_v_heads); + + bool is_gva_like = (problem_size.num_q_heads == problem_size.num_k_heads) && + (problem_size.num_v_heads == ratio * problem_size.num_q_heads) && + (problem_size.num_v_heads == ratio * problem_size.num_k_heads); + return true && ((!IsGVAEnabled && is_gqa_like) || (IsGVAEnabled && is_gva_like)) && + (problem_size.head_size <= get<2>(TileShape{})) && + ((problem_size.head_size % Alignment) == 0); + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, + void* workspace) { + int64_t s = problem_size.total_seqlen; + int64_t t = problem_size.total_seqlen; + int32_t d = problem_size.head_size; + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + make_shape(s, t, d, problem_size.num_q_heads), + typename CollectiveMmaQK::Arguments{ + args.ptr_Q, args.dQ, args.ptr_K, args.dK, // never used, dummy + }, + /*workspace=*/nullptr); + + auto params_kv_k = CollectiveMmaKV_G2S::to_underlying_arguments( + make_shape(d, d, s, problem_size.num_k_heads), + typename CollectiveMmaKV_G2S::Arguments{ + args.ptr_V, select<1, 0, 2>(args.dV), // not used + args.ptr_K, select<1, 0, 2>(args.dK), // used as G2S for K + }, + /*workspace=*/nullptr); + + auto params_kv_v = CollectiveMmaKV_G2S::to_underlying_arguments( + make_shape(d, d, s, problem_size.num_v_heads), + typename CollectiveMmaKV_G2S::Arguments{ + args.ptr_V, select<1, 0, 2>(args.dV), // used as G2S for V + args.ptr_K, select<1, 0, 2>(args.dK), // not used + }, + /*workspace=*/nullptr); + + auto params_o = CollectiveStoreO::to_underlying_arguments( + 
make_shape(d, s, d, problem_size.num_o_heads), // in O1 + // make_shape(d, s, s, problem_size.num_o_heads), // in O2 + typename CollectiveStoreO::Arguments{args.ptr_O, select<1, 0, 2>(args.dO)}, workspace); + + return Params{ + .tma_load_q = params_qk.tma_load_a, + .tma_load_k = params_kv_k.tma_load_b, + .tma_load_v = params_kv_v.tma_load_a, + .tma_store_o = params_o.tma_store_o, + .tensormaps = params_o.tensormaps, + .scale = args.scale, + + .ptr_output_state = args.ptr_output_state, + .ptr_input_state = args.ptr_input_state, + + // TODO: refactor all name to varname_vartype + .alpha_ptr = args.alpha_ptr, + .alpha_layout = make_layout(make_shape(s, problem_size.num_sab_heads), args.alpha_stride), + .beta_ptr = args.beta_ptr, + .beta_layout = make_layout(make_shape(s, problem_size.num_sab_heads), args.beta_stride), + }; + } + + static size_t get_workspace_size(Arguments const& args, int sm_count) { + return CollectiveStoreO::get_workspace_size(sm_count); + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, + Arguments const& args, void* workspace, + cudaStream_t stream) { + return CollectiveStoreO::initialize_workspace(problem_shape, workspace, stream); + } + + CUTE_DEVICE static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTE_DEVICE void load_qkv(Params const& params, ProblemShape const& problem_size, + LoadTileShape const& load_tile_shape, WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, QPipelineState& q_smem_pipe_write, + MainloopKPipeline& k_pipeline, KPipelineState& k_smem_pipe_write, + MainloopVPipeline& v_pipeline, VPipelineState& v_smem_pipe_write, + SharedStorage& storage) { + 
int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto q_collective_load = LoadQ(params.tma_load_q, q_pipeline, storage.smem_q); + auto k_collective_load = LoadK(params.tma_load_k, k_pipeline, storage.smem_k); + auto v_collective_load = LoadV(params.tma_load_v, v_pipeline, storage.smem_v); + + auto q_src_dst = q_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto k_src_dst = k_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto v_src_dst = v_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + k_collective_load.step(k_src_dst, blk, k_smem_pipe_write, lane_predicate); + q_collective_load.step(q_src_dst, blk, q_smem_pipe_write, lane_predicate); + v_collective_load.step(v_src_dst, blk, v_smem_pipe_write, lane_predicate); + } + } + + template + CUTE_DEVICE void load_beta(Params const& params, ProblemShape const& problem_size, + TileShape const& tile_shape, WorkDesc const& work_desc, + MainloopBetaPipeline& pipeline, BetaPipelineState& smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + + // fuse post inverse diag(beta) into diagonal of IKK + // auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/1.0f, + // pipeline, storage.smem_beta}; + auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, + pipeline, storage.smem_beta}; + auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + collective_load.step(src_dst, blk, smem_pipe_write, num_blocks); + } + collective_load.step(src_dst, num_blocks - 1, smem_pipe_write, num_blocks); + } + + template + CUTE_DEVICE void load_alpha(Params const& params, ProblemShape const& problem_size, + 
TileShape const& tile_shape, WorkDesc const& work_desc, + MainloopAlphaPipeline& pipeline, AlphaPipelineState& smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + + auto collective_load = LoadAlpha{params.alpha_ptr, params.alpha_layout, /*oob_value=*/1.0f, + pipeline, storage.smem_alpha}; + auto src_dst = collective_load.partition_SD(problem_size, tile_shape, work_desc); + + typename LoadAlpha::VectorProcessor processor{params.scale}; + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + collective_load.step(src_dst, blk, smem_pipe_write, num_blocks, processor); + } + collective_load.step(src_dst, num_blocks - 1, smem_pipe_write, num_blocks, + processor); + } + + template + CUTE_DEVICE void store(TMA_O const& tma_store, void* tensormaps, ProblemSize const& problem_size, + StoreTileShape const& store_tile_shape, WorkDesc const& work_desc, + MainloopOPipeline& pipeline, PipelineState& smem_pipe_read, + SharedStorageO& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto collective_store = CollectiveStoreO{tma_store, pipeline, storage, tensormaps}; + auto src_dst = collective_store.partition_SD(problem_size, store_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + DPRINTF0_W("O collective_store.step smem_pipe_read:%d -> blk_idx:%d, num_blocks:%d\n", + smem_pipe_read.index(), blk, num_blocks); + collective_store.step(problem_size, work_desc, src_dst, smem_pipe_read, blk, num_blocks, + lane_predicate); + } + } + + template + CUTE_DEVICE void compute( + Params const& params, ProblemShape const& problem_size, WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, QPipelineState& q_smem_pipe_read, + MainloopKPipeline& k_pipeline, KPipelineState& k_smem_pipe_read, + MainloopVPipeline& v_pipeline, VPipelineState& v_smem_pipe_read, + MainloopOPipeline& 
o_pipeline, OPipelineState& o_smem_pipe_write, + MainloopQKPipeline& qk_pipeline, QKPipelineState& qk_smem_pipe_read, + MainloopKKPipeline& kk_pipeline, KKPipelineState& kk_smem_pipe_read, + MainloopAlphaPipeline& alpha_pipeline, AlphaPipelineState& alpha_smem_pipe_read, + // MainloopBetaPipeline& beta_pipeline, BetaPipelineState& beta_smem_pipe_read, + OrderedMathBarriers& math_barriers, SharedStorage& storage) { + // MAKE NVCC HAPPY! + constexpr auto zero = Element{}; + + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + DPRINTF0_WG("num_blocks: %d\n", num_blocks); + + int thread_idx = int(threadIdx.x) - NumLoadThreads; + int warpgroup_idx = thread_idx / cutlass::NumThreadsPerWarpGroup; + + float scale = params.scale; + + // Tensor Beta = make_tensor(make_smem_ptr(storage.smem_beta.data()), SmemLayoutBeta{}); + Tensor Alpha = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + + Tensor sQqk = make_tensor(make_smem_ptr(storage.smem_q.data()), QKSmemLayoutQ{}); + Tensor sKqk = make_tensor(make_smem_ptr(storage.smem_k.data()), QKSmemLayoutK{}); + Tensor sKkv = make_tensor(make_smem_ptr(storage.smem_k.data()), KVSmemLayoutK{}); + Tensor sVkv = make_tensor(make_smem_ptr(storage.smem_v.data()), KVSmemLayoutV{}); + Tensor sQK = make_tensor(make_smem_ptr(storage.smem_qk.data()), SmemLayoutQK{}); + Tensor sO = make_tensor(make_smem_ptr(storage.smem_o.data()), SmemLayoutO{}); + + static_assert(sizeof(InverseType) == sizeof(Element)); + Tensor sKK_inv = make_tensor(make_smem_ptr(storage.smem_kk.data()), SmemLayoutKK{}); + Tensor sKK_opd = make_tensor(make_smem_ptr(reinterpret_cast(storage.smem_kk.data())), + SmemLayoutKK{}); + + /////////////////////////////////////////////////////////////////////////// + // S@K (-S K^T + V^T) + auto sk_tiled_mma = TiledMmaSK{}; + auto sk_thr_mma = sk_tiled_mma.get_thread_slice(thread_idx); + + auto layout_SKAlpha = flatten(make_layout( // broadcast Alpha vector to SK size + 
make_layout(select<0, 1>(TileShapeSK{}), Stride<_0, _1>{}), // (D, Blk_KV) + select<1, 2>(SmemLayoutAlpha{}) // (Idx, pipe) + )); // (D, Blk_KV, Idx, pipe) + + auto tSKrAlpha = sk_thr_mma.partition_C(Alpha.compose(layout_SKAlpha))( + _, _, _, AlphaCumProdIdx{}, _); // (frag, iter_D, iter_Blk_Q, pipe) + + // tSKrV adds to tSKrSK (acc) + using SK_V_S2R = Copy_Atom; + auto tSKrV_tiled_copy = make_tiled_copy_C(SK_V_S2R{}, sk_tiled_mma); + auto tSKrV_thr_copy = tSKrV_tiled_copy.get_thread_slice(thread_idx); + + Tensor tSKsK = sk_thr_mma.partition_B(sKqk); + Tensor tSKrK = sk_thr_mma.make_fragment_B(tSKsK); + + /////////////////////////////////////////////////////////////////////////// + // NewV = (S@K result) @ T^t + auto newv_tiled_mma = TiledMmaNewV{}; + auto newv_thr_mma = newv_tiled_mma.get_thread_slice(thread_idx); + + Tensor tNewVsB = newv_thr_mma.partition_B(sKK_opd); + Tensor tNewVrB = newv_thr_mma.make_fragment_B(tNewVsB); + + /////////////////////////////////////////////////////////////////////////// + // K@V + auto kv_tiled_mma = TiledMmaKV{}; + auto kv_thr_mma = kv_tiled_mma.get_thread_slice(thread_idx); + + Tensor tKVrKV = partition_fragment_C(kv_thr_mma, select<0, 1>(TileShapeKV{})); + + // Tensor tKVrV = kv_thr_mma.partition_fragment_A(sVkv(_, _, _0{})); // mma src + // Tensor tKVrV_cv = tKVrV_thr_copy.retile_D(tKVrV); // copy view dst + // Tensor tKVsV = tKVrV_thr_copy.partition_S(sVkv); // copy view src + + Tensor tKVsK = kv_thr_mma.partition_B(sKkv); + Tensor tKVrK = kv_thr_mma.make_fragment_B(tKVsK); + + auto const cV = make_identity_tensor(Shape, Int>{}); + Tensor tKVcV = kv_thr_mma.partition_A(cV); + + /////////////////////////////////////////////////////////////////////////// + // Q@K@V + auto o1_tiled_mma = TiledMmaO1{}; + auto o1_thr_mma = o1_tiled_mma.get_thread_slice(thread_idx); + auto o2_tiled_mma = TiledMmaO2{}; + auto o2_thr_mma = o2_tiled_mma.get_thread_slice(thread_idx); + + // A1 for Q@(KV) + // Tensor tOrKV = make_acc_into_op(tKVrKV, 
typename TiledMmaO1::LayoutA_TV{}); + // B1 for Q@(KV) + Tensor tOsQ = o1_thr_mma.partition_B(sQqk); + Tensor tOrQ = o1_thr_mma.make_fragment_B(tOsQ); + + // A2 for QK@V + // Tensor tOsV = o2_thr_mma.partition_A(sVkv); + // Tensor tOrV = o2_thr_mma.make_fragment_A(tOsV); + // B2 for QK@V + Tensor tOsQK = o2_thr_mma.partition_B(sQK); + Tensor tOrQK = o2_thr_mma.make_fragment_B(tOsQK); + + using O_R2S = typename CollectiveStoreO::CopyAtomR2S; + auto tiled_copy_o = make_tiled_copy_C(O_R2S{}, o1_tiled_mma); + auto thr_copy_o = tiled_copy_o.get_thread_slice(thread_idx); + auto tOsO = thr_copy_o.partition_D(sO); + + auto const cO = make_identity_tensor(Shape, Int>{}); + Tensor tOcO = o1_thr_mma.partition_C(cO); + + auto layout_OAlpha = flatten(make_layout( // broadcast Alpha vector to O size + make_layout(select<0, 1>(TileShapeO1{}), Stride<_0, _1>{}), // (D, Blk_Q) + select<1, 2>(SmemLayoutAlpha{}) // (Idx, pipe) + )); // (D, Blk_Q, Idx, pipe) + + auto tOrAlphaScale = o1_thr_mma.partition_C(Alpha.compose(layout_OAlpha))( + _, _, _, AlphaCumProdScaleIdx{}, _); // (frag, iter_D, iter_Blk_Q, pipe) + + auto const seq_idx = work_desc.seq_idx; + auto const q_head_idx = work_desc.q_head_idx(); + auto const k_head_idx = work_desc.k_head_idx(); + auto const v_head_idx = work_desc.v_head_idx(); + + auto sk_epi = [&](auto& tSKrSK, auto const& alpha_smem_pipe_read) INLINE_LAMBDA { + if constexpr (NeedsAlpha) { + transform(tSKrSK, tSKrAlpha(_, _, _, alpha_smem_pipe_read.index()), tSKrSK, + [&](auto sk, auto coeff) { return sk * coeff; }); + } + }; + + auto sk_load_v = [&](int pipe_idx) INLINE_LAMBDA { + Tensor tSKrV = make_fragment_like( + partition_fragment_C(sk_thr_mma, sVkv(_, _, _0{}))); // mma acc + Tensor tSKrV_cv = tSKrV_thr_copy.retile_D(tSKrV); // copy view dst + Tensor tSKsV = tSKrV_thr_copy.partition_S(sVkv); // copy view src + copy(tSKrV_tiled_copy, tSKsV(_, _, _, pipe_idx), tSKrV_cv); + return tSKrV; + }; + + auto kv_decay_v = [&](auto& tKVrV, auto const& 
alpha_smem_pipe_read, auto is_final_block_, + auto B) INLINE_LAMBDA { + constexpr bool is_final_block = decltype(is_final_block_)::value; + if constexpr (NeedsAlpha) { + Tensor Alpha_cumsum_log = Alpha(_, AlphaCumSumLogIdx{}, alpha_smem_pipe_read.index()); + float block_coeff_log = Alpha_cumsum_log(B - 1); + cute::transform(tKVrV, tKVcV, tKVrV, [&](auto val, auto coord) { + auto tok = get<1>(coord); + float coeff = [&] { + if constexpr (!is_final_block) { + return exp2f(block_coeff_log - Alpha_cumsum_log(tok)); + } else { + return tok < B ? exp2f(block_coeff_log - Alpha_cumsum_log(tok)) : 0.0f; + } + }(); + return decltype(val)(val * coeff); + }); + } + if constexpr (is_final_block) { + if constexpr (!NeedsAlpha) { + cute::transform(tKVrV, tKVcV, tKVrV, [&](auto val, auto coord) { + auto tok = get<1>(coord); + return tok < B ? val : zero; // mask v of tail oob values + }); + } + } + }; + + auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA { + DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + int num_state_heads = problem_size.num_sab_heads; + int state_head_idx = work_desc.o_head_idx(); + auto gKV = make_tensor(make_gmem_ptr(params.ptr_input_state), + make_layout(make_shape(Int{}, Int{}, + num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + + auto tiled_copy_kv = + make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_S(select_tensor<1, 0>(gKV)); + copy(tiled_copy_kv, tKVgKV, tKVrKV); + }; + + auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop + DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + int num_state_heads = problem_size.num_sab_heads; + int state_head_idx = work_desc.o_head_idx(); // num_o_heads == num_sab_heads + auto gKV = 
make_tensor(make_gmem_ptr(params.ptr_output_state), + make_layout(make_shape(Int{}, Int{}, + num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + + auto tiled_copy_kv = + make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV)); + copy(tiled_copy_kv, tKVrKV, tKVgKV); + }; + + auto o1_epi = [&](auto& tOrO1, auto const& alpha_smem_pipe_read) INLINE_LAMBDA { + if constexpr (NeedsAlpha) { + auto tOrAlphaScale_ = tOrAlphaScale(_, _, _, alpha_smem_pipe_read.index()); + CUTE_UNROLL + for (int i = 0; i < size(tOrO1); ++i) { + tOrO1(i) = tOrAlphaScale_(i) * tOrO1(i); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size(tOrO1); ++i) { + tOrO1(i) = scale * tOrO1(i); + } + } + }; + + auto o_store = [&](auto tOrO) INLINE_LAMBDA { + auto tOrO_cvt = make_fragment_like(tOrO); + copy(tOrO, tOrO_cvt); + + DPRINTF0_WG("compute: o_pipeline.producer_wait: smem_pipe_write:%d\n", + o_smem_pipe_write.index()); + o_pipeline.producer_acquire(o_smem_pipe_write); + Tensor tOrO_cvt_cv = thr_copy_o.retile_S(tOrO_cvt); + cutlass::arch::fence_view_async_shared(); + copy(tiled_copy_o, tOrO_cvt_cv, tOsO(_, _, _, o_smem_pipe_write.index())); + cutlass::arch::fence_view_async_shared(); + o_pipeline.producer_commit(o_smem_pipe_write); + ++o_smem_pipe_write; + }; + + auto compute_loop_body = [&](int blk, auto is_first_block_, + auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_first_block = decltype(is_first_block_)::value; + constexpr bool is_final_block = decltype(is_final_block_)::value; + int B = is_final_block ? 
valid_seq_len(work_desc, blk) : BlkSeqKV; + + // 2.1 Q @ KV, NOTE: use old KV here + DPRINTF0_WG("compute: q_pipeline.consumer_wait: smem_pipe_read:%d\n", + q_smem_pipe_read.index()); + q_pipeline.consumer_wait(q_smem_pipe_read); + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + } + + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch O WGMMA\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + auto tOrO = partition_fragment_C(o1_thr_mma, select<0, 1>(TileShapeO1{})); + if constexpr (is_first_block) { + DPRINTF0_WG("compute: q_pipeline.consumer_release: smem_pipe_read:%d\n", + q_smem_pipe_read.index()); + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + } else { + Tensor tOrKV = make_acc_into_op(tKVrKV, typename TiledMmaO1::LayoutA_TV{}); + warpgroup_fence_operand(tOrKV); + warpgroup_fence_operand(tOrO); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm_zero_acc(o1_thr_mma, tOrKV, tOrQ(_, _, _, q_smem_pipe_read.index()), tOrO); + warpgroup_commit_batch(); // q@kv batch + math_barriers.notify_next_blocked(warpgroup_idx); + } + if constexpr (!is_first_block) { + warpgroup_wait<0>(); // q@kv batch + DPRINTF0_WG("compute: q_pipeline.consumer_release: smem_pipe_read:%d\n", + q_smem_pipe_read.index()); + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + o1_epi(tOrO, alpha_smem_pipe_read); + } + + DPRINTF0_WG("compute: k_pipeline.consumer_wait: smem_pipe_read:%d\n", + k_smem_pipe_read.index()); + k_pipeline.consumer_wait(k_smem_pipe_read); + + auto tSKrSK = partition_fragment_C(sk_thr_mma, sVkv(_, _, _0{})); + if constexpr (!is_first_block) { + auto tSKrS = make_acc_into_op(tKVrKV, typename TiledMmaSK::LayoutA_TV{}); + warpgroup_fence_operand(tSKrSK); + warpgroup_fence_operand(tSKrS); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm_zero_acc(sk_tiled_mma, tSKrS, tSKrK(_, _, _, k_smem_pipe_read.index()), tSKrSK); + warpgroup_commit_batch(); + 
math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); + } + + DPRINTF0_WG("compute: v_pipeline.consumer_wait: smem_pipe_read:%d\n", + v_smem_pipe_read.index()); + v_pipeline.consumer_wait(v_smem_pipe_read); + auto tSKrV = sk_load_v(v_smem_pipe_read.index()); + if constexpr (!is_first_block) { + sk_epi(tSKrSK, alpha_smem_pipe_read); + transform(tSKrV, tSKrSK, tSKrV, [](auto v, auto sk) { return v - Element(sk); }); + } + + kk_pipeline.consumer_wait(kk_smem_pipe_read); + auto tNewVrA = make_acc_into_op(tSKrV, typename TiledMmaNewV::LayoutA_TV{}); + auto tNewVrC = partition_fragment_C(newv_thr_mma, select<0, 1>(TileShapeNewV{})); + warpgroup_fence_operand(tNewVrA); + warpgroup_fence_operand(tNewVrC); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm_zero_acc(o1_thr_mma, tNewVrA, tNewVrB(_, _, _, kk_smem_pipe_read.index()), tNewVrC); + warpgroup_commit_batch(); // new_v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); // new_v batch + DPRINTF0_WG("compute: v_pipeline.consumer_release: smem_pipe_read:%d\n", + v_smem_pipe_read.index()); + ++v_smem_pipe_read; // NOTE: if we delay this increment after consumer_release, race + // condition happens, why? + v_pipeline.consumer_release(v_smem_pipe_read); + + kk_pipeline.consumer_release(kk_smem_pipe_read); + ++kk_smem_pipe_read; + + ///////////////////////////////////////////////////////////////////////// + // 2. 
compute qkv + // 2.2 QK @ V, NOTE: use old KV here and QK is scaled + qk_pipeline.consumer_wait(qk_smem_pipe_read); + auto tOrV_or_tKVrV = make_acc_into_op(tNewVrC, typename TiledMmaKV::LayoutA_TV{}); + warpgroup_fence_operand(tOrV_or_tKVrV); + warpgroup_fence_operand(tOrO); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + if constexpr (is_first_block) { + gemm_zero_acc(o2_tiled_mma, tOrV_or_tKVrV, tOrQK(_, _, _, qk_smem_pipe_read.index()), tOrO); + } else { + gemm(o2_tiled_mma, tOrV_or_tKVrV, tOrQK(_, _, _, qk_smem_pipe_read.index()), tOrO); + } + warpgroup_commit_batch(); // qk@v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); // qk@v batch + qk_pipeline.consumer_release(qk_smem_pipe_read); + ++qk_smem_pipe_read; + o_store(tOrO); + + ///////////////////////////////////////////////////////////////////////// + // 3. update KV + float block_coeff = 1.0f; + if constexpr (NeedsAlpha) { + block_coeff = Alpha(B - 1, AlphaCumProdIdx{}, alpha_smem_pipe_read.index()); + } + + cute::transform(tKVrKV, [&](auto kv) { return block_coeff * kv; }); + kv_decay_v(tOrV_or_tKVrV, alpha_smem_pipe_read, is_final_block_, B); + + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch KV WGMMA\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + warpgroup_fence_operand(tOrV_or_tKVrV); + warpgroup_fence_operand(tKVrKV); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm(kv_tiled_mma, tOrV_or_tKVrV, tKVrK(_, _, _, k_smem_pipe_read.index()), tKVrKV); + warpgroup_commit_batch(); // k@v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); + + DPRINTF0_WG("compute: k_pipeline.consumer_release: smem_pipe_read:%d\n", + k_smem_pipe_read.index()); + k_pipeline.consumer_release(k_smem_pipe_read); + ++k_smem_pipe_read; + + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_release(alpha_smem_pipe_read); + ++alpha_smem_pipe_read; + } + }; + + if constexpr (!kInitStateFromInput) { + clear(tKVrKV); + 
compute_loop_body(0, /*is_first_block_=*/cute::true_type{}, + /*is_final_block_=*/cute::true_type{}); + } else { + kv_load(tKVrKV); + compute_loop_body(0, /*is_first_block_=*/cute::false_type{}, + /*is_final_block_=*/cute::true_type{}); + } + CUTE_NO_UNROLL + for (int blk = 1; blk < num_blocks - 1; ++blk) { + compute_loop_body(blk, /*is_first_block_=*/cute::false_type{}, + /*is_final_block_=*/cute::false_type{}); + } + if (num_blocks != 1) { + compute_loop_body(num_blocks - 1, /*is_first_block_=*/cute::false_type{}, + /*is_final_block_=*/cute::true_type{}); + } + kv_store(); + } + + template + CUTE_DEVICE void compute_aux(Params const& params, ProblemShape const& problem_size, + WorkDesc const& work_desc, MainloopQPipeline& q_pipeline, + QPipelineState& q_smem_pipe_read, MainloopKPipeline& k_pipeline, + KPipelineState& k_smem_pipe_read, MainloopQKPipeline& qk_pipeline, + QKPipelineState& qk_smem_pipe_write, MainloopKKPipeline& kk_pipeline, + KKPipelineState& kk_smem_pipe_write, + MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_read, + MainloopBetaPipeline& beta_pipeline, + BetaPipelineState& beta_smem_pipe_read, SharedStorage& storage) { + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + float scale = params.scale; + + Tensor Beta = make_tensor(make_smem_ptr(storage.smem_beta.data()), SmemLayoutBeta{}); + Tensor Alpha = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + + Tensor sQqk = make_tensor(make_smem_ptr(storage.smem_q.data()), QKSmemLayoutQ{}); + Tensor sKqk = make_tensor(make_smem_ptr(storage.smem_k.data()), QKSmemLayoutK{}); + Tensor sKkv = make_tensor(make_smem_ptr(storage.smem_k.data()), KVSmemLayoutK{}); + Tensor sVkv = make_tensor(make_smem_ptr(storage.smem_v.data()), KVSmemLayoutV{}); + Tensor sQK = make_tensor(make_smem_ptr(storage.smem_qk.data()), SmemLayoutQK{}); + Tensor sO = make_tensor(make_smem_ptr(storage.smem_o.data()), SmemLayoutO{}); + + 
static_assert(sizeof(InverseType) == sizeof(Element)); + Tensor sKK_inv = make_tensor(make_smem_ptr(storage.smem_kk.data()), SmemLayoutKK{}); + Tensor sKK_opd = make_tensor(make_smem_ptr(reinterpret_cast(storage.smem_kk.data())), + SmemLayoutKK{}); + + /////////////////////////////////////////////////////////////////////////// + // Q@K + auto qk_tiled_mma = TiledMmaQK{}; + auto qk_thr_mma = qk_tiled_mma.get_thread_slice(thread_idx); + + Tensor tQKsQ = qk_thr_mma.partition_A(sQqk); + Tensor tQKsK = qk_thr_mma.partition_B(sKqk); + Tensor tQKrQ = qk_thr_mma.make_fragment_A(tQKsQ); + Tensor tQKrK = qk_thr_mma.make_fragment_B(tQKsK); + + auto cMqk = make_identity_tensor(select<0, 1>(TileShapeQK{})); // (QTok, KTok) + auto tQKcMqk = qk_thr_mma.partition_C(cMqk); // (idx) -> (tok_q, tok_k) + + /////////////////////////////////////////////////////////////////////////// + // K@K (basically I + strict_lower_triangular(K K^T) + auto kk_tiled_mma = TiledMmaKK{}; + auto kk_thr_mma = kk_tiled_mma.get_thread_slice(thread_idx); + Tensor tKKsK = kk_thr_mma.partition_B(sKqk); + Tensor tKKrA = kk_thr_mma.make_fragment_A(tKKsK); + Tensor tKKrB = kk_thr_mma.make_fragment_B(tKKsK); + + auto const& cMkk = cMqk; + auto tKKcMkk = kk_thr_mma.partition_C(cMkk); + + auto const seq_idx = work_desc.seq_idx; + auto const q_head_idx = work_desc.q_head_idx(); + auto const k_head_idx = work_desc.k_head_idx(); + auto const v_head_idx = work_desc.v_head_idx(); + + auto qk_and_kk_epi = [&](auto& tQKrQK, auto& tKKrKK, auto const& alpha_smem_pipe_read, + auto const& beta_smem_pipe_read, auto is_final_block_, + auto B /*valid seqlen*/) { + if constexpr (NeedsAlpha) { + Tensor Alpha_cumsum_log = Alpha(_, AlphaCumSumLogIdx{}, alpha_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + float alpha = exp2f(Alpha_cumsum_log(s) - Alpha_cumsum_log(t)); + tQKrQK(i) *= alpha * scale; + tKKrKK(i) *= alpha; + }); + } else { + transform(tQKrQK, 
[scale](auto v) { return v * scale; }); + } + + if constexpr (NeedsBeta) { + Tensor Beta_ = Beta(_, beta_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + tKKrKK(i) *= Beta_(s); + }); + } + + constexpr bool is_final_block = decltype(is_final_block_)::value; + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + bool pred = s >= t; + tQKrQK(i) = pred ? tQKrQK(i) : 0.0f; + tKKrKK(i) = + pred ? tKKrKK(i) : 0.0f; // diagonal is garbage filled, will process during inversion + if constexpr (is_final_block) { + bool pred = s < B || t < B; + tQKrQK(i) = pred ? tQKrQK(i) : 0.0f; + tKKrKK(i) = pred ? tKKrKK(i) : 0.0f; + } + }); + }; + + auto qk_store = [&](auto tQKrQK, auto const& qk_smem_pipe_write) { + auto sQK_pipe_slice = sQK(_, _, qk_smem_pipe_write.index()); + + static_assert(sizeof(Element) == 2); + using CopyOpR2S = SM90_U32x4_STSM_N; + auto tiled_copy_qk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma); + auto thr_copy_qk = tiled_copy_qk.get_thread_slice(thread_idx); + auto tQKsQK = thr_copy_qk.partition_D(sQK_pipe_slice); + auto tQKrQK_cv = thr_copy_qk.retile_S(tQKrQK); + auto tQKrQK_cvt_cv = make_fragment_like(tQKrQK_cv); + cute::transform(tQKrQK_cv, tQKrQK_cvt_cv, [](auto v) { return Element(v); }); + copy(tiled_copy_qk, tQKrQK_cvt_cv, tQKsQK); + }; + + auto kk_store_and_inv = [&](auto tKKrKK, auto const& kk_smem_pipe_write) INLINE_LAMBDA { + auto sKK_inv_pipe_slice = sKK_inv(_, _, kk_smem_pipe_write.index()); + + static_assert(sizeof(Element) == 2); + using CopyOpR2S = SM90_U32x4_STSM_N; + auto tiled_store_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_store_kk = tiled_store_kk.get_thread_slice(thread_idx); + auto tKKsKK = thr_store_kk.partition_D(sKK_inv_pipe_slice); + auto tKKrKK_cv = thr_store_kk.retile_S(tKKrKK); + auto tKKrKK_cvt_cv = make_fragment_like(tKKrKK_cv); + cute::transform(tKKrKK_cv, tKKrKK_cvt_cv, [](auto v) { return 
InverseType(v); }); + copy(tiled_store_kk, tKKrKK_cvt_cv, tKKsKK); + + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + DeltaRuleNamedBarriers::AuxMath); + + auto collective_inverse = CollectiveInverse(DeltaRuleNamedBarriers::AuxMath); + collective_inverse.compute(sKK_inv_pipe_slice); + + // FIXME: we can ignore core matrices above diagonal + if constexpr (NeedsBeta || !std::is_same_v) { + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, + DeltaRuleNamedBarriers::AuxMath); + using CopyOpS2R = SM75_U32x4_LDSM_N; + auto tiled_load_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_load_kk = tiled_load_kk.get_thread_slice(thread_idx); + auto tKKrKK_cpy = make_fragment_like(tKKrKK_cvt_cv); + auto tKKrKK_cvt = make_fragment_like(tKKrKK_cvt_cv); + auto tKKcMkk_cv = thr_load_kk.retile_D(tKKcMkk); + copy(tiled_load_kk, thr_load_kk.partition_S(sKK_inv_pipe_slice), tKKrKK_cpy); + cute::transform(tKKrKK_cpy, tKKcMkk_cv, tKKrKK_cvt, [&](auto val, auto coord) { + auto [_, t] = coord; + if constexpr (NeedsBeta) { + return Element(float(val) * Beta(t, beta_smem_pipe_read.index())); + } else { + return Element(val); + } + }); + copy(tiled_store_kk, tKKrKK_cvt, recast(tKKsKK)); + } + }; + + auto compute_aux_loop_body = [&](int blk, auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_final_block = decltype(is_final_block_)::value; + + int B = is_final_block ? 
valid_seq_len(work_desc, blk) : BlkSeqKV; + + Tensor tKKrKK = partition_fragment_C(TiledMmaKK{}, select<0, 1>(TileShapeKK{})); + Tensor tQKrQK = partition_fragment_C(TiledMmaQK{}, select<0, 1>(TileShapeQK{})); + + k_pipeline.consumer_wait(k_smem_pipe_read); + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch KK WGMMA\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + warpgroup_fence_operand(tKKrKK); + warpgroup_arrive(); + gemm_zero_acc(kk_tiled_mma, tKKrA(_, _, _, k_smem_pipe_read.index()), + tKKrB(_, _, _, k_smem_pipe_read.index()), tKKrKK); + warpgroup_commit_batch(); // K@Kt batch + + q_pipeline.consumer_wait(q_smem_pipe_read); + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch QK WGMMA\n", seq_idx, q_head_idx, k_head_idx, + v_head_idx); + warpgroup_fence_operand(tQKrQK); + warpgroup_arrive(); + gemm_zero_acc(qk_tiled_mma, tQKrQ(_, _, _, q_smem_pipe_read.index()), + tQKrK(_, _, _, k_smem_pipe_read.index()), tQKrQK); + warpgroup_commit_batch(); // Q@Kt batch + + // K@Kt and Q@Kt batch finished, we fused masking logic for qk and kk so wait for all of them + warpgroup_wait<0>(); + + k_pipeline.consumer_release(k_smem_pipe_read); + ++k_smem_pipe_read; + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + } + if constexpr (NeedsBeta) { + beta_pipeline.consumer_wait(beta_smem_pipe_read); + } + cutlass::arch::fence_view_async_shared(); + + qk_and_kk_epi(tQKrQK, tKKrKK, alpha_smem_pipe_read, beta_smem_pipe_read, is_final_block_, B); + + kk_pipeline.producer_acquire(kk_smem_pipe_write); + kk_store_and_inv(tKKrKK, kk_smem_pipe_write); + cutlass::arch::fence_view_async_shared(); + kk_pipeline.producer_commit(kk_smem_pipe_write); + ++kk_smem_pipe_write; + + qk_pipeline.producer_acquire(qk_smem_pipe_write); + qk_store(tQKrQK, qk_smem_pipe_write); + cutlass::arch::fence_view_async_shared(); + qk_pipeline.producer_commit(qk_smem_pipe_write); + ++qk_smem_pipe_write; + + if constexpr 
(NeedsAlpha) {
+      alpha_pipeline.consumer_release(alpha_smem_pipe_read);
+      ++alpha_smem_pipe_read;
+    }
+    if constexpr (NeedsBeta) {
+      beta_pipeline.consumer_release(beta_smem_pipe_read);
+      ++beta_smem_pipe_read;
+    }
+  };
+
+  // Main block loop: all blocks except the last run with compile-time
+  // cute::false_type (no tail handling); the final block runs with
+  // cute::true_type so the epilogue can mask positions past the valid
+  // sequence length (see the is_final_block branch in qk_and_kk_epi).
+  int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{}));
+  CUTE_NO_UNROLL
+  for (int blk = 0; blk < num_blocks - 1; ++blk) {
+    compute_aux_loop_body(blk, /*is_final_block_=*/cute::false_type{});
+  }
+  compute_aux_loop_body(num_blocks - 1, /*is_final_block_=*/cute::true_type{});
+  }
+
+  // Number of valid tokens in KV block `blk_idx`: full blocks contribute
+  // BlkSeqKV tokens; the tail block contributes only the remainder.
+  // Equivalent to min(work_desc.seq_len - BlkSeqKV * blk_idx, BlkSeqKV).
+  // NOTE(review): the template parameter list appears to have been lost in
+  // extraction (presumably the WorkDesc type) -- confirm against upstream.
+  template
+  CUTE_DEVICE int valid_seq_len(WorkDesc work_desc, int blk_idx) {
+    int remain_len = work_desc.seq_len - BlkSeqKV * blk_idx;
+    return remain_len <= BlkSeqKV ? remain_len : BlkSeqKV;
+  }
+};
+
+} // namespace flat::collective
diff --git a/csrc/flat/hopper/collective/flat_common.hpp b/csrc/flat/hopper/collective/flat_common.hpp
new file mode 100644
index 0000000000..df3f66ce54
--- /dev/null
+++ b/csrc/flat/hopper/collective/flat_common.hpp
@@ -0,0 +1,166 @@
+/*
+ * Copyright (c) 2025 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "cute/tensor.hpp"
+#include "cutlass/kernel_hardware_info.h"
+
+namespace flat::collective {
+
+using namespace cute;
+
+// Run a K-blocked GEMM with the given MMA atom, accumulating every k-block
+// into tC under whatever accumulation mode `atom` currently carries. After
+// each cute::gemm call the atom is switched to GMMA::ScaleOut::One so the
+// remaining k-blocks add onto the accumulator rather than overwrite it.
+// Two operand shapes are supported, selected at compile time by rank:
+//   - rank(tA)==rank(tB)==2, rank(tC)==1: atom-level fragments, k at mode 1;
+//   - rank(tA)==rank(tB)==rank(tC)==3: thread-partitioned fragments,
+//     k at mode 2.
+// NOTE(review): the template parameter list appears to have been lost in
+// extraction (presumably the Atom/TA/TB/TC types) -- confirm against
+// upstream.
+template
+CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
+  constexpr int rA = decltype(rank(tA))::value;
+  constexpr int rB = decltype(rank(tB))::value;
+  constexpr int rC = decltype(rank(tC))::value;
+  if constexpr (rA == 2 && rB == 2 && rC == 1) {
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < size<1>(tA); k_block++) {
+      cute::gemm(atom, tA(_, k_block), tB(_, k_block), tC);
+      // Subsequent k-blocks must accumulate, not overwrite.
+      atom.accumulate_ = GMMA::ScaleOut::One;
+    }
+  } else {
+    static_assert(rA == 3 && rB == 3 && rC == 3);
+    CUTE_UNROLL
+    for (int k_block = 0; k_block < size<2>(tA); k_block++) {
+      cute::gemm(atom, tA(_, _, k_block), tB(_, _, k_block), tC);
+      // Subsequent k-blocks must accumulate, not overwrite.
+      atom.accumulate_ = GMMA::ScaleOut::One;
+    }
+  }
+}
+
+// Same K-blocked GEMM, but with a zero-initialized accumulator: the atom is
+// set to GMMA::ScaleOut::Zero for the first k-block, after which
+// gemm_reset_zero_acc flips it to ScaleOut::One for the remaining blocks.
+// NOTE(review): template parameter list stripped in extraction -- confirm
+// against upstream.
+template
+CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) {
+  atom.accumulate_ = GMMA::ScaleOut::Zero;
+  gemm_reset_zero_acc(atom, tA, tB, tC);
+}
+
+template