From eeb4fd3371d9c048940449499605dfcc83ecb3e2 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 04:06:24 +0000
Subject: [PATCH 1/7] fp8 kernel modification

---
 .../sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu  |  10 +
 .../sm90/dense_fp8/flash_fwd_mla_kernel.h     | 709 ++++++++++++++++++
 csrc/extension/sm90/dense_fp8/flash_mla.h     |  22 +
 .../sm90/dense_fp8/flash_mla_fp8_ops.h        | 186 +++++
 .../sm90/dense_fp8/fp8_transpose_v.h          |  88 +++
 csrc/extension/sm90/dense_fp8/named_barrier.h |  21 +
 csrc/extension/sm90/dense_fp8/softmax.h       | 202 +++
 csrc/extension/sm90/dense_fp8/static_switch.h |  70 ++
 csrc/extension/sm90/dense_fp8/utils.h         | 279 +++
 csrc/extension/torch_api.cc                   |  30 +
 csrc/sm100/decode/sparse_fp8/splitkv_mla.cu   |   2 +-
 csrc/sm90/decode/dense/splitkv_mla.cu         |   2 +-
 csrc/sm90/decode/sparse_fp8/splitkv_mla.cu    |   2 +-
 csrc/smxx/get_mla_metadata.cu                 |   2 +-
 csrc/smxx/mla_combine.cu                      |   2 +-
 15 files changed, 1622 insertions(+), 5 deletions(-)
 create mode 100644 csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
 create mode 100644 csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h
 create mode 100644 csrc/extension/sm90/dense_fp8/flash_mla.h
 create mode 100644 csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
 create mode 100644 csrc/extension/sm90/dense_fp8/fp8_transpose_v.h
 create mode 100644 csrc/extension/sm90/dense_fp8/named_barrier.h
 create mode 100644 csrc/extension/sm90/dense_fp8/softmax.h
 create mode 100644 csrc/extension/sm90/dense_fp8/static_switch.h
 create mode 100644 csrc/extension/sm90/dense_fp8/utils.h
 create mode 100644 csrc/extension/torch_api.cc

diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
new file mode 100644
index 0000000..b87902c
--- /dev/null
+++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
@@ -0,0 +1,10 @@
+/*
+ * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54
+ * originally authored by @endurehero
+ */
+
+#include "flash_fwd_mla_kernel.h"
+
+#ifndef FLASH_MLA_DISABLE_FP8
+template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(DecodingParams_fp8 &params, cudaStream_t stream);
+#endif
\ No newline at end of file
diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h
new file mode 100644
index 0000000..7aefe62
--- /dev/null
+++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_kernel.h
@@ -0,0 +1,709 @@
+/*
+ * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54
+ * originally authored by @endurehero
+ */
+
+#pragma once
+
+#include <cute/tensor.hpp>
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
+using namespace cute;
+
+#include "named_barrier.h"
+#include "utils.h"
+#include "softmax.h"
+#include "static_switch.h"
+#include "flash_mla.h"
+#include "fp8_transpose_v.h"
+
+
+template <typename PrecType, int DIM, int DIM2 = DIM, GMMA::Major major = GMMA::Major::K>
+constexpr auto getSmemLayoutK() {
+    constexpr int headSizeBytes = sizeof(PrecType) * DIM;
+    constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
+
+    if constexpr (major == GMMA::Major::K) {
+        if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
+            return GMMA::Layout_K_SW128_Atom<PrecType>{};
+        } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
+            return GMMA::Layout_K_SW64_Atom<PrecType>{};
+        } else {
+            return GMMA::Layout_K_SW32_Atom<PrecType>{};
+        }
+    } else {
+        if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
+            return GMMA::Layout_MN_SW128_Atom<PrecType>{};
+        } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
+            return GMMA::Layout_MN_SW64_Atom<PrecType>{};
+        } else {
+            return
GMMA::Layout_MN_SW32_Atom{}; + } + } + +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementO = elem_type_o; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr bool Is_FP8 = cute::is_same_v; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + + static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3; + + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, MmaMajorV>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = std::conditional_t< + Is_FP8, + Layout, Int, _1, _2, Int>>, + Layout, Int, _1, _2, Int>> + >; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO); + static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO"); + static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO; + static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + + + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, ElementO>{}, + GmemLayoutAtomO{}, + Layout>>{})); // 
Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>>{})); // Val layout, 4 vals per store + + + // ------ for f8 ------ + using SmemFp8Tranpose = SmemTransposeFp8_64x64; + using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; +}; + +namespace flash { + +using namespace cute; + +template +struct SharedStorageMLA { + using SmemV_t = std::conditional_t>, + cute::array_aligned>; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + SmemV_t smem_vt; + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const DecodingParams_fp8 ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::ElementO; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, scale_softmax, descale_k); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? 
params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.q_seq_per_hk - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.q_seq_per_hk - m_block * kBlockM + ); +} + +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const DecodingParams_fp8 ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + + auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + auto sVt = [&](){ + if constexpr(Kernel_traits::Is_FP8){ + return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + } + }(); + + Tensor sP = 
make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / q_head_per_hk + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, scale_softmax_log2) + : is_masking_step ? 
+ softmax.template softmax(tSrS, scale_softmax_log2) + : softmax.template softmax(tSrS, scale_softmax_log2); + + if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + + cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0 + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + __syncthreads(); + } + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.q_seq_per_hk - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + if constexpr (!Kernel_traits::Is_FP8) { + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + // We need to clear the sK smem tiles because K is V. 
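+        // Why this matters (illustrative numbers): with kBlockN = 64 and seqlen_k = 100,
+        // the last tile covers rows [64, 128) but only rows [64, 100) hold real data.
+        // Masked scores do exponentiate to zero, but 0 * NaN = NaN, so stale smem bytes
+        // that happen to decode to NaN/Inf would still poison the P @ V accumulator
+        // (sV aliases these sK tiles). Hence the copy below takes the Clear_OOB_MN path
+        // in flash::copy and zeroes the out-of-bounds rows instead of leaving them stale.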
+ const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + if constexpr (Kernel_traits::Is_FP8) { + auto TransV = [&]() { + using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose; + SmemFp8Tranpose smem_transpose_V; + Tensor sV_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{})); + Tensor sVt_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{})); + + if (n_block % 2 == 1) { + sV_divide.data() = sV_divide.data() + size(sK); + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) { + smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j))); + } + } + }; + + TransV(); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::TransVReady)); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = flash::convert_layout_acc_Aregs(partition_fragment_C(tiled_mma, Shape, Int>{}).layout()); + Tensor tOrP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, tOrP); + + flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) __syncthreads(); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + if constexpr (!Kernel_traits::Is_FP8) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams_fp8 params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b || begin_idx < 0) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + + float descale_k = 1.f; + float scale_softmax = params.scale_softmax; + float scale_softmax_log2 = params.scale_softmax_log2; + if constexpr (Kernel_traits::Is_FP8) { + float descale_q = __ldg(params.descale_q_ptr); + descale_k = __ldg(params.descale_k_ptr); + scale_softmax = scale_softmax * descale_q * descale_k; + scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; + } + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.seqlens_k_ptr + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. 
+ } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const DecodingParams_fp8 params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h_k * params.q_seq_per_hk; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits <= 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? 
INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.q_seq_per_hk; + const int row = bidx - batch_idx * hs - head_idx * params.q_seq_per_hk; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.q_seq_per_hk, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h_k * params.q_seq_per_hk); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_mla.h b/csrc/extension/sm90/dense_fp8/flash_mla.h new file mode 100644 index 0000000..4e5e202 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_mla.h @@ -0,0 +1,22 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#include "../../../params.h" + 
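+// Usage sketch (illustrative only, not part of the original PR; the FP8-in /
+// BF16-out template arguments below are inferred from the ops wrapper):
+//
+//   DecodingParams_fp8 params = {};
+//   /* ...fill the inherited DecodingParams fields: shapes, strides, pointers... */
+//   params.h_h_k_ratio   = num_heads_q / num_heads_k;    // 1 after MLA head folding
+//   params.descale_q_ptr = descale_q.data_ptr<float>();  // 1-element fp32 device tensors
+//   params.descale_k_ptr = descale_k.data_ptr<float>();
+//   run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);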
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// FP8-specific extension of the original DecodingParams +struct DecodingParams_fp8 : public DecodingParams { + int h_h_k_ratio; + float* __restrict__ descale_q_ptr = nullptr; + float* __restrict__ descale_k_ptr = nullptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(DecodingParams_fp8 ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h new file mode 100644 index 0000000..0a660e3 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h @@ -0,0 +1,186 @@ +#include +#include +#include +#include +#include + +#include "flash_mla.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::vector +fwd_kvcache_mla_fp8( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True) + const int head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const double softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits, // batch_size + 1 + const std::optional &descale_q, // None or batch_size + const std::optional &descale_k // None or batch_size +) { + // Check the architecture + auto dprops = at::cuda::getCurrentDeviceProperties(); + TORCH_CHECK(dprops->major == 9 && dprops->minor == 0, "Dense FP8 MLA is only supported on SM90"); + + // Check data types + TORCH_CHECK(q.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(kcache.dtype() == q.dtype(), "query and key must have the same dtype"); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + + // Check device + CHECK_DEVICE(q); + CHECK_DEVICE(kcache); + CHECK_DEVICE(seqlens_k); + CHECK_DEVICE(block_table); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_DEVICE(num_splits); + if (descale_q.has_value()) CHECK_DEVICE(descale_q.value()); + if (descale_k.has_value()) CHECK_DEVICE(descale_k.value()); + + // Check layout + TORCH_CHECK(q.stride(-1) == 1, "q must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "kcache must have contiguous last dimension"); + CHECK_CONTIGUOUS(seqlens_k); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + CHECK_CONTIGUOUS(num_splits); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_q = sizes[2]; + const int head_size_k = sizes[3]; + TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported"); + TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 
is supported"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(page_block_size == 64, "Currently page_block_size must be 64"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8"); + auto descale_q_ = descale_q.value(); + auto descale_k_ = descale_k.value(); + CHECK_DEVICE(descale_q_); + CHECK_DEVICE(descale_k_); + TORCH_CHECK(descale_q_.stride(-1) == 1); + TORCH_CHECK(descale_k_.stride(-1) == 1); + TORCH_CHECK(descale_q_.dtype() == torch::kFloat); + TORCH_CHECK(descale_k_.dtype() == torch::kFloat); + CHECK_SHAPE(descale_q_, 1); + CHECK_SHAPE(descale_k_, 1); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int num_q_heads_per_hk = num_heads_q / num_heads_k; + const int q_seq_per_hk = seqlen_q_ori * num_q_heads_per_hk; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, num_q_heads_per_hk, head_size_k}).transpose(2, 3) + .reshape({batch_size, q_seq_per_hk, num_heads, head_size_k}); + + CHECK_SHAPE(q, batch_size, q_seq_per_hk, num_heads, head_size_k); + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + CHECK_SHAPE(seqlens_k, batch_size); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_SHAPE(num_splits, batch_size+1); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + caffe2::TypeMeta out_type; + out_type = torch::kBFloat16; + at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type)); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse); + + // Set up parameters for the dense FP8 kernel + DecodingParams_fp8 params = {}; + // Set the sizes. + params.b = batch_size; + params.s_q = seqlen_q_ori; + params.q_seq_per_hk = q_seq_per_hk; + params.seqlens_k_ptr = seqlens_k.data_ptr(); + params.h_q = num_heads_q; + params.h_k = num_heads_k; + params.num_blocks = num_blocks; + params.q_head_per_hk = num_q_heads_per_hk; + params.is_causal = is_causal; + params.d = head_size_k; + params.d_v = head_size_v; + params.scale_softmax = static_cast(softmax_scale); + params.scale_softmax_log2 = float(static_cast(softmax_scale) * M_LOG2E); + params.topk = -1; // Dense attention + + // FP8-specific parameters + params.h_h_k_ratio = 1; + params.descale_q_ptr = reinterpret_cast(descale_q.value().data_ptr()); + params.descale_k_ptr = reinterpret_cast(descale_k.value().data_ptr()); + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.indices_ptr = nullptr; + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // All stride are in elements, not bytes. 
+ params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(1); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(2); + params.o_head_stride = out.stride(-2); + params.indices_batch_stride = 0; + params.indices_row_stride = 0; + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + params.num_splits_ptr = num_splits.data_ptr(); + + // Set up accumulation tensors + const int total_num_splits = batch_size + params.num_sm_parts; + at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat)); + CHECK_CONTIGUOUS(softmax_lse_accum); + CHECK_CONTIGUOUS(out_accum); + params.total_num_splits = total_num_splits; + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // Call the actual kernel implementation +#ifdef FLASH_MLA_DISABLE_FP8 + TORCH_CHECK(false, "FlashMLA is compiled with -DFLASH_MLA_DISABLE_FP8. Please remove this flag from your environment and re-compile FlashMLA."); +#else + run_mha_fwd_splitkv_mla(params, stream); +#endif + + // Reshape outputs back to original format + out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3) + .reshape({batch_size, num_heads_q, seqlen_q_ori}); + + return {out, softmax_lse}; +} \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h b/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h new file mode 100644 index 0000000..9001e1f --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/fp8_transpose_v.h @@ -0,0 +1,88 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + + +/** + * ref to Fa3's SmemTranspose64x64: + * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 +*/ + +#pragma once + +template +struct SmemTransposeFp8_64x64 { + static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); + + using Element = cutlass::float_e4m3_t; + using TransposeShapeAtomV = Shape<_64, _64>; + using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- src layout + using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}))); + using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + + // For fp8, this is the memory transpose. 
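+  // Data-flow sketch for one 64x64 tile (orientation only):
+  //   smem V --(ldmatrix.trans via TiledCopyLDSM)--> registers
+  //          --(__byte_perm byte interleave)-------> transposed registers
+  //          --(stmatrix via TiledCopySTSM)--------> smem Vt
+  // FP8 WGMMA cannot consume an MN-major B operand directly, so the P @ V GEMM
+  // needs this explicit K-major copy of V^T.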
+ using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- dst layout + using SmemLayoutVtTrans = decltype(composition( + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); + using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); + using SmemShapeSTSM = Shape, Shape<_16, _4>>; + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); + using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; + + using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; diff --git a/csrc/extension/sm90/dense_fp8/named_barrier.h b/csrc/extension/sm90/dense_fp8/named_barrier.h new file mode 100644 index 0000000..8f2e546 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/named_barrier.h @@ -0,0 +1,21 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, + TransVReady = 3, +}; + +} // flash \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/softmax.h b/csrc/extension/sm90/dense_fp8/softmax.h new file mode 100644 index 0000000..1996e85 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/softmax.h @@ -0,0 +1,202 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + 
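+// Numerics note: everything below works in base 2. With
+// scale_log2 = softmax_scale * log2(e), each probability is computed as
+//   p = exp2f(s * scale_log2 - max_scaled),  max_scaled = row_max * scale_log2,
+// which fuses to a single FFMA + EX2 per element; subtracting the running row
+// max keeps every exponent <= 0, so exp2f cannot overflow.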
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). 
+ // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/static_switch.h b/csrc/extension/sm90/dense_fp8/static_switch.h new file mode 100644 index 0000000..c0f7331 --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/static_switch.h @@ -0,0 +1,70 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() \ No newline at end of file diff --git a/csrc/extension/sm90/dense_fp8/utils.h b/csrc/extension/sm90/dense_fp8/utils.h new file mode 100644 index 0000000..cd6f95b --- /dev/null +++ b/csrc/extension/sm90/dense_fp8/utils.h @@ -0,0 +1,279 @@ +/* + * Taken from FlashMLA PR https://github.com/deepseek-ai/FlashMLA/pull/54 + * originally authored by @endurehero + */ + +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + 
+template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true,
+          typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
+__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
+    constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
+    // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
+    if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+    warpgroup_fence_operand(tCrC);
+    if constexpr (arrive) {
+        warpgroup_arrive();
+    }
+    if constexpr (zero_init) {
+        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+        // Unroll the K mode manually to set scale D to 1
+        CUTLASS_PRAGMA_UNROLL
+        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+        }
+    } else {
+        // cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
+        // Unroll the K mode manually to set scale D to 1
+        CUTLASS_PRAGMA_UNROLL
+        for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+            cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+            tiled_mma.accumulate_ = GMMA::ScaleOut::One;
+        }
+    }
+    if constexpr (commit) {
+        warpgroup_commit_batch();
+    }
+    if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
+    warpgroup_fence_operand(tCrC);
+    if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
+template <bool Transposed=false, typename Layout0>
+__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
+    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
+        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+        static_assert(decltype(rank(acc_layout))::value == 3);
+        auto l = acc_layout;
+        if constexpr (!Transposed) {
+            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
+        } else {
+            return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
+        }
+    } else {  // SM80
+        static_assert(decltype(size<0>(acc_layout))::value == 4);
+        static_assert(decltype(rank(acc_layout))::value == 3);
+        auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
+        if constexpr (!Transposed) {
+            return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+        } else {
+            return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
+        }
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
+// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
+// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
+// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
+template <typename MMA_Traits, typename Layout0>
+__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
+    using X = Underscore;
+    if constexpr (decltype(rank<0>(acc_layout))::value == 3) {  // SM90
+        static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
+        static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
+        static_assert(decltype(rank(acc_layout))::value == 3);
+        static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
+        if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
+            auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{});  // ((2, N / 16))
+            return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
+        } else {
+            static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
+            static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
+            static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
+            auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{});  // (((2, 2), N / 32))
+            // This combines the first two modes (<0, 0> and <0, 1>) into one mode.
+            // Will require register shuffling later to be correct.
+            return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
+                               get<1>(acc_layout),
+                               coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));  // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
+            // This combination is right but doesn't work with register shuffling.
+            // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
+            //                    get<1>(acc_layout),
+            //                    coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
+        }
+    } else {  // SM80
+        static_assert(decltype(size<0>(acc_layout))::value == 4);
+        static_assert(decltype(rank(acc_layout))::value == 3);
+        constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
+        static_assert(mma_shape_K == 8 || mma_shape_K == 16);
+        if constexpr (mma_shape_K == 8) {
+            return acc_layout;
+        } else {
+            auto l = logical_divide(acc_layout, Shape<X, X, _2>{});  // (4, MMA_M, (2, MMA_N / 2)))
+            return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
+        }
+    }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename To_type, typename Engine, typename Layout>
+__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
+    using From_type = typename Engine::value_type;
+    constexpr int numel = decltype(size(tensor))::value;
+    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
+    // HACK: this requires tensor to be "contiguous"
+    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
+    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Blocks until all but N previous cp.async.commit_group operations have committed.
+// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
+// (which is equivalent to commit_group then wait_group 0).
+// Instead we just call cp.async.wait_group 0, which is slightly faster.
+// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
+template <int N>
+CUTE_HOST_DEVICE
+void cp_async_wait() {
+#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
+    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
+          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
+          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
+__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
+                                     Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
+                                     Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
+    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
+    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));  // MMA
+    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));  // MMA_M
+    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));  // MMA_K
+    // There's no case where !Clear_OOB_K && Clear_OOB_MN
+    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
+    #pragma unroll
+    for (int m = 0; m < size<1>(S); ++m) {
+        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
+            #pragma unroll
+            for (int k = 0; k < size<2>(S); ++k) {
+                if (Is_even_K || predicate_K(k)) {
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
+                } else if (Clear_OOB_K) {
+                    cute::clear(D(_, m, k));
+                }
+            }
+        } else if (Clear_OOB_MN) {
+            cute::clear(D(_, m, _));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
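+// Pair-swaps 32-bit accumulator registers; presumably this realigns the P fragment with
+// the register layout that convert_layout_acc_Aregs produces for the FP8 WGMMA A operand.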
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag_64b = group_modes<1, 3>(recast<uint64_t>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
+        #pragma unroll
+        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
+            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout, typename EngineOut>
+CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
+    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
+    using From_type = typename Engine::value_type;
+    using To_type = typename EngineOut::value_type;
+    static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
+    static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
+    Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
+    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
+    static_assert(size(frag) == size(out_frg));
+    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
+    #pragma unroll
+    for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+}  // namespace flash
diff --git a/csrc/extension/torch_api.cc b/csrc/extension/torch_api.cc
new file mode 100644
index 0000000..6ebe1f4
--- /dev/null
+++ b/csrc/extension/torch_api.cc
@@ -0,0 +1,30 @@
+#include <torch/extension.h>
+
+#include <vector>
+
+extern
+std::vector<at::Tensor>
+fwd_kvcache_mla_fp8(
+    at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
+    const int head_size_v,
+    const at::Tensor &seqlens_k,                 // batch_size
+    const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
+    const double softmax_scale,
+    bool is_causal,
+    const at::Tensor &tile_scheduler_metadata,   // num_sm_parts x TileSchedulerMetaDataSize
+    const at::Tensor &num_splits,                // batch_size + 1
+    const std::optional<at::Tensor> &descale_q,  // None or batch_size
+    const std::optional<at::Tensor> &descale_k   // None or batch_size
+);
+
+TORCH_LIBRARY(_flashmla_extension_C, m) {
+    m.def("fwd_kvcache_mla_fp8", &fwd_kvcache_mla_fp8);
+    m.impl("fwd_kvcache_mla_fp8", torch::kCUDA, &fwd_kvcache_mla_fp8);
+}
+
+PyMODINIT_FUNC PyInit__flashmla_extension_C() {
+    static struct PyModuleDef module = {
+        PyModuleDef_HEAD_INIT, "_flashmla_extension_C", nullptr, 0, nullptr};
+    return PyModule_Create(&module);
+}
\ No newline at end of file
diff --git a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
index 5077961..e78dcd5 100644
--- a/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
+++ b/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
@@ -163,7 +163,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
     int sched_begin_block_idx = tile_scheduler_metadata.y;
     int end_idx = tile_scheduler_metadata.z;
     int sched_end_block_idx = tile_scheduler_metadata.w;
-    if (begin_idx >= params.b) {
+    if (begin_idx >= params.b || begin_idx < 0) {
         if (warp_idx == 0) {
             cute::TMEM::Allocator1Sm().free(0, 512);
         }
diff --git a/csrc/sm90/decode/dense/splitkv_mla.cu b/csrc/sm90/decode/dense/splitkv_mla.cu
index 70e28f4..56e8f12 100644
--- a/csrc/sm90/decode/dense/splitkv_mla.cu
+++ b/csrc/sm90/decode/dense/splitkv_mla.cu
@@ -1026,7 +1026,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const DecodingParams params, __gr
     int sched_begin_block_idx = tile_scheduler_metadata.y;
     int end_idx = tile_scheduler_metadata.z;
     int sched_end_block_idx = tile_scheduler_metadata.w;
-    if (begin_idx >= params.b) return;
+    if (begin_idx >= params.b || begin_idx < 0) return;
     int begin_n_split_idx = *(tile_scheduler_metadata_ptr + 4);
 
     // Copy the first Q
diff --git a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
index fd7b28c..550451b 100644
--- a/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
+++ b/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
@@ -174,7 +174,7 @@ flash_fwd_splitkv_mla_fp8_sparse_kernel(__grid_constant__ const DecodingParams p
     int sched_begin_block_idx = tile_scheduler_metadata.y;
     int end_idx = tile_scheduler_metadata.z;
     int sched_end_block_idx = tile_scheduler_metadata.w;
-    if (begin_idx >= params.b) return;
+    if (begin_idx >= params.b || begin_idx < 0) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
 
     if (warp_idx == 0 && elect_one_sync()) {
diff --git a/csrc/smxx/get_mla_metadata.cu b/csrc/smxx/get_mla_metadata.cu
index e6b5e63..d46fe53 100644
--- a/csrc/smxx/get_mla_metadata.cu
+++ b/csrc/smxx/get_mla_metadata.cu
@@ -51,7 +51,7 @@ get_mla_metadata_kernel(__grid_constant__ const GetDecodingMetadataParams params
     num_splits_shared[0] = 0;
     for (int i = 0; i < num_sm_parts; ++i) {
         int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
-        tile_scheduler_metadata0[0] = now_idx;
+        tile_scheduler_metadata0[0] = (now_idx >= batch_size ? -1 : now_idx);
         tile_scheduler_metadata0[1] = now_block + first_block_idx_shared[now_idx];
         tile_scheduler_metadata1 = now_n_split_idx;
         int remain_payload = payload;
diff --git a/csrc/smxx/mla_combine.cu b/csrc/smxx/mla_combine.cu
index 3cbaa47..2027432 100644
--- a/csrc/smxx/mla_combine.cu
+++ b/csrc/smxx/mla_combine.cu
@@ -25,7 +25,7 @@ flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
     const int end_split_idx = __ldg(params.num_splits_ptr + batch_idx + 1);
     const int my_num_splits = end_split_idx - start_split_idx;
     FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
-    if (my_num_splits == 1) {
+    if (my_num_splits <= 1) {
         return;
     }
 

From 04f8312989807d4bd513956c7bfd927403ace95f Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 04:16:41 +0000
Subject: [PATCH 2/7] fp8 kernel modification

---
 csrc/extension/{torch_api.cc => python_api.cpp} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename csrc/extension/{torch_api.cc => python_api.cpp} (100%)

diff --git a/csrc/extension/torch_api.cc b/csrc/extension/python_api.cpp
similarity index 100%
rename from csrc/extension/torch_api.cc
rename to csrc/extension/python_api.cpp

From b2174aa1661eec728ec1b8d30994afd6a19e6dc3 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 07:36:59 +0000
Subject: [PATCH 3/7] fp8 kernel modification

---
 csrc/extension/python_api.cpp                     |  2 +-
 .../extension/sm90/dense_fp8/flash_mla_fp8_ops.h  | 16 +++++++++-------
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/csrc/extension/python_api.cpp b/csrc/extension/python_api.cpp
index 6ebe1f4..a8e045b 100644
--- a/csrc/extension/python_api.cpp
+++ b/csrc/extension/python_api.cpp
@@ -7,7 +7,7 @@ std::vector<at::Tensor>
 fwd_kvcache_mla_fp8(
     at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &kcache,                    // num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
-    const int head_size_v,
+    const int64_t head_size_v,
     const at::Tensor &seqlens_k,                 // batch_size
     const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
     const double softmax_scale,
diff --git a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
index 0a660e3..1ba0fb4 100644
--- a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
+++ b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
@@ -14,7 +14,7 @@ std::vector<at::Tensor>
 fwd_kvcache_mla_fp8(
     at::Tensor &q,                               // batch_size x seqlen_q x num_heads x head_size
     const at::Tensor &kcache,                    // num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
-    const int head_size_v,
+    const int64_t head_size_v,
     const at::Tensor &seqlens_k,                 // batch_size
     const at::Tensor &block_table,               // batch_size x max_num_blocks_per_seq
     const double softmax_scale,
@@ -24,6 +24,8 @@ fwd_kvcache_mla_fp8(
     const std::optional<at::Tensor> &descale_q,  // None or batch_size
     const std::optional<at::Tensor> &descale_k   // None or batch_size
 ) {
+    int head_size_v_int = static_cast<int>(head_size_v)
+
     // Check the architecture
     auto dprops = at::cuda::getCurrentDeviceProperties();
     TORCH_CHECK(dprops->major == 9 && dprops->minor == 0, "Dense FP8 MLA is only supported on SM90");
@@ -60,7 +62,7 @@ fwd_kvcache_mla_fp8(
     const int num_heads_q = sizes[2];
     const int head_size_k = sizes[3];
     TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
-    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");
+    TORCH_CHECK(head_size_v_int == 512, "Only head_size_v == 512 is supported");
 
     const int max_num_blocks_per_seq = block_table.size(1);
     const int num_blocks = kcache.size(0);
@@ -102,7 +104,7 @@ fwd_kvcache_mla_fp8(
     auto opts = q.options();
     caffe2::TypeMeta out_type;
     out_type = torch::kBFloat16;
-    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
+    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v_int}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
     CHECK_CONTIGUOUS(softmax_lse);
 
@@ -119,7 +121,7 @@ fwd_kvcache_mla_fp8(
     params.q_head_per_hk = num_q_heads_per_hk;
     params.is_causal = is_causal;
     params.d = head_size_k;
-    params.d_v = head_size_v;
+    params.d_v = head_size_v_int;
     params.scale_softmax = static_cast<float>(softmax_scale);
     params.scale_softmax_log2 = float(static_cast<float>(softmax_scale) * M_LOG2E);
     params.topk = -1;  // Dense attention
@@ -160,7 +162,7 @@ fwd_kvcache_mla_fp8(
     // Set up accumulation tensors
     const int total_num_splits = batch_size + params.num_sm_parts;
     at::Tensor softmax_lse_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
-    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v}, opts.dtype(at::kFloat));
+    at::Tensor out_accum = torch::empty({total_num_splits, num_heads, q_seq_per_hk, head_size_v_int}, opts.dtype(at::kFloat));
     CHECK_CONTIGUOUS(softmax_lse_accum);
     CHECK_CONTIGUOUS(out_accum);
     params.total_num_splits = total_num_splits;
@@ -177,8 +179,8 @@ fwd_kvcache_mla_fp8(
 #endif
 
     // Reshape outputs back to original format
-    out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v}).transpose(2, 3)
-            .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v});
+    out = out.view({batch_size, seqlen_q_ori, num_q_heads_per_hk, num_heads_k, head_size_v_int}).transpose(2, 3)
+            .reshape({batch_size, seqlen_q_ori, num_heads_q, head_size_v_int});
     softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, num_q_heads_per_hk}).transpose(2, 3)
             .reshape({batch_size, num_heads_q, seqlen_q_ori});
 

From b9aad20e49e6139d8198e7b5ddce1607f27142e7 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 08:42:46 +0000
Subject: [PATCH 4/7] fix some error

---
 csrc/extension/python_api.cpp                    | 17 ++--
 .../sm90/dense_fp8/flash_fwd_mla_metadata.cu     | 77 +++++++++++++++++++
 csrc/extension/sm90/dense_fp8/flash_mla.h        | 13 +++-
 .../sm90/dense_fp8/flash_mla_fp8_ops.h           | 38 +++++++++
 csrc/smxx/mla_combine.cu                         |  8 +-
 5 files changed, 138 insertions(+), 15 deletions(-)
 create mode 100644 csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu

diff --git a/csrc/extension/python_api.cpp b/csrc/extension/python_api.cpp
index a8e045b..d5d9908 100644
--- a/csrc/extension/python_api.cpp
+++ b/csrc/extension/python_api.cpp
@@ -18,13 +18,10 @@ fwd_kvcache_mla_fp8(
     const std::optional<at::Tensor> &descale_k   // None or batch_size
 );
 
-TORCH_LIBRARY(_flashmla_extension_C, m) {
-    m.def("fwd_kvcache_mla_fp8", &fwd_kvcache_mla_fp8);
-    m.impl("fwd_kvcache_mla_fp8", torch::kCUDA, &fwd_kvcache_mla_fp8);
-}
-
-PyMODINIT_FUNC PyInit__flashmla_extension_C() {
-    static struct PyModuleDef module = {
-        PyModuleDef_HEAD_INIT, "_flashmla_extension_C", nullptr, 0, nullptr};
-    return PyModule_Create(&module);
-}
\ No newline at end of file
+extern
+std::vector<at::Tensor>
+get_mla_decoding_metadata_dense_fp8(
+    at::Tensor &seqlens_k,
+    const int num_heads_per_head_k,
+    const int num_heads_k
+);
diff --git a/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
new file mode 100644
index 0000000..96c2bd3
--- /dev/null
+++ b/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
@@ -0,0 +1,77 @@
+#include "flash_fwd_mla_kernel.h"
+
+static constexpr int MaxBatchSize = 4096;
+
+__global__ void __launch_bounds__(256, 1, 1)
+get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
+    int *seqlens_k_ptr = params.seqlens_k_ptr;
+    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
+    int *num_splits_ptr = params.num_splits_ptr;
+    int batch_size = params.batch_size;
+    int block_size_n = params.block_size_n;
+    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
+    int num_sm_parts = params.num_sm_parts;
+
+    __shared__ int num_blocks_shared[MaxBatchSize];
+    __shared__ int num_splits_shared[MaxBatchSize];
+
+    int total_num_blocks = 0;
+    for (int i = threadIdx.x; i < batch_size; i += 32) {
+        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
+        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
+        num_blocks_shared[i] = num_blocks;
+    }
+    for (int offset = 16; offset >= 1; offset /= 2) {
+        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
+    }
+    __syncwarp();
+
+    if (threadIdx.x == 0) {
+        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
+
+        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
+        num_splits_shared[0] = 0;
+        for (int i = 0; i < num_sm_parts; ++i) {
+            int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
+            tile_scheduler_metadata0[0] = now_idx;
+            tile_scheduler_metadata0[1] = now_block * block_size_n;
+            tile_scheduler_metadata1 = now_n_split_idx;
+            int remain_payload = payload;
+            while (now_idx < batch_size) {
+                int num_blocks = num_blocks_shared[now_idx];
+                int now_remain_blocks = num_blocks - now_block;
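+                // Greedy packing: if the remaining payload budget covers the rest of this
+                // batch element (plus its fixed overhead), take all of it and advance;
+                // otherwise spend the rest of the budget on a partial split and stop.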
+                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
+                    cum_num_splits += now_n_split_idx + 1;
+                    num_splits_shared[now_idx + 1] = cum_num_splits;
+                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
+                    ++now_idx;
+                    now_block = 0;
+                    now_n_split_idx = 0;
+                } else {
+                    if (remain_payload - fixed_overhead_num_blocks > 0) {
+                        now_block += remain_payload - fixed_overhead_num_blocks;
+                        ++now_n_split_idx;
+                        remain_payload = 0;
+                    }
+                    break;
+                }
+            }
+            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
+            tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
+            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
+            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
+        }
+        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
+    }
+    __syncwarp();
+
+    for (int i = threadIdx.x; i <= batch_size; i += 32) {
+        num_splits_ptr[i] = num_splits_shared[i];
+    }
+}
+
+void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
+    FLASH_ASSERT(params.batch_size < MaxBatchSize);
+    get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
+    CHECK_CUDA_KERNEL_LAUNCH();
+}
\ No newline at end of file
diff --git a/csrc/extension/sm90/dense_fp8/flash_mla.h b/csrc/extension/sm90/dense_fp8/flash_mla.h
index 4e5e202..64c8b53 100644
--- a/csrc/extension/sm90/dense_fp8/flash_mla.h
+++ b/csrc/extension/sm90/dense_fp8/flash_mla.h
@@ -19,4 +19,15 @@ struct DecodingParams_fp8 : public DecodingParams {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template
-void run_mha_fwd_splitkv_mla(DecodingParams_fp8 &params, cudaStream_t stream);
\ No newline at end of file
+void run_mha_fwd_splitkv_mla(DecodingParams_fp8 &params, cudaStream_t stream);
+
+struct Mla_metadata_params {
+    int *__restrict__ seqlens_k_ptr;
+    int *__restrict__ tile_scheduler_metadata_ptr;
+    int *__restrict__ num_splits_ptr;
+    int batch_size;
+    int block_size_n;
+    int fixed_overhead_num_blocks;
+    int num_sm_parts;
+};
+void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);
\ No newline at end of file
diff --git a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
index 1ba0fb4..e9a89f7 100644
--- a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
+++ b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include <cutlass/fast_math.h>
 #include
 
 #include "flash_mla.h"
@@ -185,4 +186,41 @@ fwd_kvcache_mla_fp8(
             .reshape({batch_size, num_heads_q, seqlen_q_ori});
 
     return {out, softmax_lse};
+}
+
+std::vector<at::Tensor>
+get_mla_decoding_metadata_dense_fp8(
+    at::Tensor &seqlens_k,
+    const int num_heads_per_head_k,
+    const int num_heads_k
+) {
+    // This should match the logic in the MLA kernel.
+    static constexpr int block_size_m = 64;
+    static constexpr int block_size_n = 64;
+    static constexpr int fixed_overhead_num_blocks = 5;
+    CHECK_DEVICE(seqlens_k);
+    TORCH_CHECK(seqlens_k.is_contiguous());
+    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
+    int batch_size = seqlens_k.size(0);
+    int *seqlens_k_ptr = seqlens_k.data_ptr<int>();
+    auto options = seqlens_k.options();
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    int sm_count = dprops->multiProcessorCount;
+    int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
+    auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
+    auto num_splits = torch::empty({batch_size + 1}, options);
+    int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();
+    int *num_splits_ptr = num_splits.data_ptr<int>();
+    at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    Mla_metadata_params params = {};
+    params.seqlens_k_ptr = seqlens_k_ptr;
+    params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr;
+    params.num_splits_ptr = num_splits_ptr;
+    params.batch_size = batch_size;
+    params.block_size_n = block_size_n;
+    params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
+    params.num_sm_parts = num_sm_parts;
+    get_mla_metadata_func(params, stream);
+    return {tile_scheduler_metadata, num_splits};
+}
\ No newline at end of file
diff --git a/csrc/smxx/mla_combine.cu b/csrc/smxx/mla_combine.cu
index 2027432..0e4a652 100644
--- a/csrc/smxx/mla_combine.cu
+++ b/csrc/smxx/mla_combine.cu
@@ -13,11 +13,11 @@ using namespace cute;
 template
 __global__ void __launch_bounds__(NUM_THREADS)
 flash_fwd_mla_combine_kernel(__grid_constant__ const DecodingParams params) {
-    // grid_shape: [batch_size, num_q_heads*s_q / BLOCK_SIZE_M]
+    // grid_shape: [num_q_heads*s_q / BLOCK_SIZE_M, batch_size]
     // Each CTA gathers the activation of some heads from one batch, do scaling & accumulation, and save the result
     static_assert(NUM_THREADS/32 == BLOCK_SIZE_M);   // The number of warps == block_size_m
-    const int batch_idx = blockIdx.x;
-    const int m_block_idx = blockIdx.y;
+    const int batch_idx = blockIdx.y;
+    const int m_block_idx = blockIdx.x;
     const int warp_idx = threadIdx.x / 32;
     const int lane_idx = threadIdx.x % 32;
 
@@ -189,7 +189,7 @@ void run_flash_mla_combine_kernel(DecodingParams &params, cudaStream_t stream) {
     attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
     attribute[0].val.programmaticStreamSerializationAllowed = 1;
     cudaLaunchConfig_t combine_kernel_config = {
-        dim3(params.b, cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), 1),
+        dim3(cute::ceil_div(params.h_k*params.q_seq_per_hk, BLOCK_SIZE_M), params.b, 1),
         dim3(NUM_THREADS, 1, 1),
         smem_size,
         stream,

From fcb8f3efbfe034961a063486c46b859fd3a24e85 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 08:48:58 +0000
Subject: [PATCH 5/7] fix some errors

---
 csrc/extension/python_api.cpp                     | 4 ++--
 csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h | 8 +++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/csrc/extension/python_api.cpp b/csrc/extension/python_api.cpp
index d5d9908..2ef29bb 100644
--- a/csrc/extension/python_api.cpp
+++ b/csrc/extension/python_api.cpp
@@ -22,6 +22,6 @@ extern
 std::vector<at::Tensor>
 get_mla_decoding_metadata_dense_fp8(
     at::Tensor &seqlens_k,
-    const int num_heads_per_head_k,
-    const int num_heads_k
+    const int64_t num_heads_per_head_k,
+    const int64_t num_heads_k
 );
diff --git a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
index e9a89f7..59e8cd4 100644
--- a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
+++ b/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
@@ -191,9 +191,11 @@ fwd_kvcache_mla_fp8(
 std::vector<at::Tensor>
 get_mla_decoding_metadata_dense_fp8(
     at::Tensor &seqlens_k,
-    const int num_heads_per_head_k,
-    const int num_heads_k
+    const int64_t num_heads_per_head_k,
+    const int64_t num_heads_k
 ) {
+    int num_heads_per_head_k_int = static_cast<int>(num_heads_per_head_k);
+    int num_heads_k_int = static_cast<int>(num_heads_k);
     // This should match the logic in the MLA kernel.
     static constexpr int block_size_m = 64;
     static constexpr int block_size_n = 64;
@@ -206,7 +208,7 @@ get_mla_decoding_metadata_dense_fp8(
     auto options = seqlens_k.options();
     auto dprops = at::cuda::getCurrentDeviceProperties();
     int sm_count = dprops->multiProcessorCount;
-    int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
+    int num_sm_parts = sm_count / num_heads_k_int / cutlass::ceil_div(num_heads_per_head_k_int, block_size_m);
     auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options);
     auto num_splits = torch::empty({batch_size + 1}, options);
     int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr<int>();

From 0f3fbb688debb12a6e35a58eae2c7fe956264b96 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 15:43:18 +0000
Subject: [PATCH 6/7] fix some errors

---
 .../dense_fp8/{flash_mla_fp8_ops.h => dense_fp8_python_api.cpp} | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
 rename csrc/extension/sm90/dense_fp8/{flash_mla_fp8_ops.h => dense_fp8_python_api.cpp} (99%)

diff --git a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
similarity index 99%
rename from csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
rename to csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
index 59e8cd4..d8159ec 100644
--- a/csrc/extension/sm90/dense_fp8/flash_mla_fp8_ops.h
+++ b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
@@ -25,7 +25,7 @@ fwd_kvcache_mla_fp8(
     const std::optional<at::Tensor> &descale_q,  // None or batch_size
     const std::optional<at::Tensor> &descale_k   // None or batch_size
 ) {
-    int head_size_v_int = static_cast<int>(head_size_v)
+    int head_size_v_int = static_cast<int>(head_size_v);
 
     // Check the architecture
     auto dprops = at::cuda::getCurrentDeviceProperties();

From 75fda5312d6c94803d588f6fd50b752df4b2e8f7 Mon Sep 17 00:00:00 2001
From: FlamingoPg <1106310035@qq.com>
Date: Tue, 11 Nov 2025 15:52:11 +0000
Subject: [PATCH 7/7] fix libtorch

---
 csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
index d8159ec..8ce97be 100644
--- a/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
+++ b/csrc/extension/sm90/dense_fp8/dense_fp8_python_api.cpp
@@ -1,4 +1,4 @@
-#include
+#include
 #include
 #include
 #include
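
---
A minimal usage sketch (illustration only, not part of the patch series) of how the two
functions exported above are intended to compose. The head counts below (128 query heads per
KV head, a single KV head) and is_causal=true are assumptions for the example; head_size_v
must be 512, as enforced by the TORCH_CHECK in fwd_kvcache_mla_fp8.

    // Hypothetical C++ driver; tensor shapes follow the comments on fwd_kvcache_mla_fp8.
    std::vector<at::Tensor> run_dense_fp8_decode(
            at::Tensor q, at::Tensor kcache_fp8,
            at::Tensor seqlens_k, at::Tensor block_table, double softmax_scale) {
        // 1. Plan the split-KV schedule: returns {tile_scheduler_metadata, num_splits}.
        auto meta = get_mla_decoding_metadata_dense_fp8(seqlens_k,
                                                        /*num_heads_per_head_k=*/128,  // assumed
                                                        /*num_heads_k=*/1);            // assumed
        // 2. Run the FP8 MLA decode kernel: returns {out, softmax_lse}.
        return fwd_kvcache_mla_fp8(q, kcache_fp8, /*head_size_v=*/512,
                                   seqlens_k, block_table, softmax_scale,
                                   /*is_causal=*/true, meta[0], meta[1],
                                   /*descale_q=*/std::nullopt, /*descale_k=*/std::nullopt);
    }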