diff --git a/README.md b/README.md
index b79757c..2e1ebff 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
 
 Currently released:
-- BF16, FP16
+- BF16, FP16, E4M3
 - Paged kvcache with block size of 64
 
 ## Quick start
diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index d2567fe..d6f8108 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -68,7 +68,9 @@ mha_fwd_kvcache_mla(
     const float softmax_scale,
     bool is_causal,
     const at::Tensor &tile_scheduler_metadata,  // num_sm_parts x TileSchedulerMetaDataSize
-    const at::Tensor &num_splits                // batch_size + 1
+    const at::Tensor &num_splits,               // batch_size + 1
+    c10::optional<const at::Tensor> &descale_q_,  // batch_size
+    c10::optional<const at::Tensor> &descale_k_   // batch_size
 ) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
@@ -76,8 +78,9 @@ mha_fwd_kvcache_mla(
 
     at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
 
-    auto q_dtype = q.dtype();
-    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf || q_dtype == torch::kFloat8_e4m3fn);
+    TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype");
 
     CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
@@ -104,6 +107,20 @@ mha_fwd_kvcache_mla(
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8");
+        auto descale_q = descale_q_.value();
+        auto descale_k = descale_k_.value();
+        CHECK_DEVICE(descale_q);
+        CHECK_DEVICE(descale_k);
+        TORCH_CHECK(descale_q.stride(-1) == 1);
+        TORCH_CHECK(descale_k.stride(-1) == 1);
+        TORCH_CHECK(descale_q.dtype() == torch::kFloat);
+        TORCH_CHECK(descale_k.dtype() == torch::kFloat);
+        CHECK_SHAPE(descale_q, 1);
+        CHECK_SHAPE(descale_k, 1);
+    }
+
     if (seqlen_q_ori == 1) { is_causal = false; }
 
     const int ngroups = num_heads_ori / num_heads_k;
@@ -127,7 +144,8 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};
 
     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype;
+    at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
 
     Flash_fwd_mla_params params = {};
@@ -167,6 +185,11 @@ mha_fwd_kvcache_mla(
     params.block_table_batch_stride = block_table.stride(0);
     params.page_block_size = page_block_size;
 
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        params.descale_q_ptr = reinterpret_cast<float *>(descale_q_.value().data_ptr());
+        params.descale_k_ptr = reinterpret_cast<float *>(descale_k_.value().data_ptr());
+    }
+
     TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
     TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
     CHECK_DEVICE(tile_scheduler_metadata);
@@ -186,12 +209,18 @@ mha_fwd_kvcache_mla(
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     TORCH_CHECK(head_size == 576);
+
     if (q_dtype == torch::kBFloat16) {
-        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, stream);
+        run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(params, stream);
     }
 #ifndef FLASH_MLA_DISABLE_FP16
     else if (q_dtype == torch::kHalf) {
-        run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, stream);
+        run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(params, stream);
+    }
+#endif
+#ifndef FLASH_MLA_DISABLE_FP8
+    else if (q_dtype == torch::kFloat8_e4m3fn) {
+        run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(params, stream);
     }
 #endif
     else {
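For orientation before the kernel changes: a minimal sketch (not part of this PR) of how the extended entry point is driven on the FP8 path. Shapes follow `tests/test_flash_mla.py`; the unit descales and tensor names are illustrative only.

```python
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, s_k, h_q, h_kv, d, dv = 2, 1, 4096, 128, 1, 576, 512
dev = torch.device("cuda")

cache_seqlens = torch.full((b,), s_k, dtype=torch.int32, device=dev)
max_pad = (s_k + 255) // 256 * 256  # pad the paged cache like the test does
q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device=dev)
kv = torch.randn(b * max_pad // 64, 64, h_kv, d, dtype=torch.bfloat16, device=dev)
block_table = torch.arange(b * max_pad // 64, dtype=torch.int32, device=dev).view(b, -1)

# Quantize the activations; keep float32 per-tensor dequantization scales.
q_fp8, kv_fp8 = q.to(torch.float8_e4m3fn), kv.to(torch.float8_e4m3fn)
descale_q = torch.ones(1, dtype=torch.float32, device=dev)
descale_k = torch.ones(1, dtype=torch.float32, device=dev)

meta, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
out, lse = flash_mla_with_kvcache(
    q_fp8, kv_fp8, block_table, cache_seqlens, dv, meta, num_splits,
    causal=False, descale_q=descale_q, descale_k=descale_k,
)
assert out.dtype == torch.bfloat16  # FP8 inputs produce a BF16 output tensor
```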
diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu
index 35691f2..4990c48 100644
--- a/csrc/flash_fwd_mla_bf16_sm90.cu
+++ b/csrc/flash_fwd_mla_bf16_sm90.cu
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"
 
-template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_fp16_sm90.cu b/csrc/flash_fwd_mla_fp16_sm90.cu
index abdaf7b..a7f09b8 100644
--- a/csrc/flash_fwd_mla_fp16_sm90.cu
+++ b/csrc/flash_fwd_mla_fp16_sm90.cu
@@ -1,3 +1,3 @@
 #include "flash_fwd_mla_kernel.h"
 
-template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_mha_fwd_splitkv_mla<cutlass::half_t, cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
diff --git a/csrc/flash_fwd_mla_fp8_sm90.cu b/csrc/flash_fwd_mla_fp8_sm90.cu
new file mode 100644
index 0000000..b678962
--- /dev/null
+++ b/csrc/flash_fwd_mla_fp8_sm90.cu
@@ -0,0 +1,3 @@
+#include "flash_fwd_mla_kernel.h"
+
+template void run_mha_fwd_splitkv_mla<cutlass::float_e4m3_t, cutlass::bfloat16_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
\ No newline at end of file
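The three translation units above fix the (compute element, output element) pairs the dispatcher in `flash_api.cpp` selects from. Restated for reference (the FP8 pairing is as reconstructed in this diff: it computes in e4m3 but always materializes a BF16 output):

```python
# Input torch dtype -> (kernel compute element, kernel output element)
INSTANTIATIONS = {
    "torch.bfloat16":      ("cutlass::bfloat16_t",   "cutlass::bfloat16_t"),
    "torch.float16":       ("cutlass::half_t",       "cutlass::half_t"),
    "torch.float8_e4m3fn": ("cutlass::float_e4m3_t", "cutlass::bfloat16_t"),
}
```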
diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h
index d96acd8..6e92b5e 100644
--- a/csrc/flash_fwd_mla_kernel.h
+++ b/csrc/flash_fwd_mla_kernel.h
@@ -12,27 +12,42 @@ using namespace cute;
 #include "softmax.h"
 #include "static_switch.h"
 #include "flash_mla.h"
+#include "fp8_transpose_v.h"
 
-template<typename PrecType, int DIM, int DIM2 = DIM>
+template<typename PrecType, int DIM, int DIM2 = DIM, GMMA::Major major = GMMA::Major::K>
 constexpr auto getSmemLayoutK() {
     constexpr int headSizeBytes = sizeof(PrecType) * DIM;
     constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
 
-    if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
-        return GMMA::Layout_K_SW128_Atom<PrecType>{};
-    } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
-        return GMMA::Layout_K_SW64_Atom<PrecType>{};
+    if constexpr (major == GMMA::Major::K) {
+        if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
+            return GMMA::Layout_K_SW128_Atom<PrecType>{};
+        } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
+            return GMMA::Layout_K_SW64_Atom<PrecType>{};
+        } else {
+            return GMMA::Layout_K_SW32_Atom<PrecType>{};
+        }
     } else {
-        return GMMA::Layout_K_SW32_Atom<PrecType>{};
+        if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
+            return GMMA::Layout_MN_SW128_Atom<PrecType>{};
+        } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
+            return GMMA::Layout_MN_SW64_Atom<PrecType>{};
+        } else {
+            return GMMA::Layout_MN_SW32_Atom<PrecType>{};
+        }
     }
+
 }
 
-template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
+template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, typename elem_type_o=elem_type, int kHeadDimV_ = 0>
 struct Flash_fwd_kernel_traits_mla {
     using Element = elem_type;
+    using ElementO = elem_type_o;
     using ElementAccum = float;
     using index_t = int64_t;
 
+    static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t>;
+
     static constexpr int kNWarps = kNWarps_;
     static constexpr int kNThreads = kNWarps * 32;
@@ -46,8 +61,12 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
     static_assert(kHeadDimV % 32 == 0);
     static_assert(kHeadDimV <= kHeadDim);
-    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
-    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
+
+    static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32);
+    static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32;
+    static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3;
+
+    static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? GMMA::Major::MN : GMMA::Major::K;
 
     using TiledMma = decltype(make_tiled_mma(
             cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
@@ -57,7 +76,7 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
     using TiledMmaO = decltype(make_tiled_mma(
             cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
-                    GMMA::Major::K, GMMA::Major::MN>(),
+                    GMMA::Major::K, MmaMajorV>(),
             Layout<Shape<_1, Int<AtomLayoutNO>, _1>>{}));
 
     using SmemLayoutQ = decltype(tile_to_shape(
@@ -73,16 +92,20 @@ struct Flash_fwd_kernel_traits_mla {
             Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
     using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
 
-    using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
+    using SmemLayoutP = std::conditional_t<
+        Is_FP8,
+        Layout<Shape<Shape<_4, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 32>>>,
+        Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, _2, Int<kBlockN / 16>>>
+    >;
     using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
 
     using SmemLayoutAtomO = decltype(composition(
-            Swizzle<kSwizzle, 3, 3>{},
-            Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
+            Swizzle<kSwizzleO, 3, 3>{},
+            Layout<Shape<Int<8>, Int<kBlockKSmemO>>, Stride<Int<kBlockKSmemO>, _1>>{}));
     using SmemLayoutO = decltype(tile_to_shape(
             SmemLayoutAtomO{},
             Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
-    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
+    using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, ElementO>;
     using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
 
     static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
@@ -92,31 +115,43 @@ struct Flash_fwd_kernel_traits_mla {
     static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
     static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
 
+    static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO);
+    static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO");
+    static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO;
+    static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO");
+
     using GmemLayoutAtom = Layout<
             Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
             Stride<Int<kGmemThreadsPerRow>, _1>>;
+
     using GmemTiledCopy = decltype(make_tiled_copy(
             Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>{},
             GmemLayoutAtom{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
+            Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{}));  // Val layout, 8 vals per read
 
     using GmemLayoutAtomO = Layout<
-            Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
-            Stride<Int<kGmemThreadsPerRow>, _1>>;
+            Shape<Int<kNThreadsS / kGmemThreadsPerRowO>, Int<kGmemThreadsPerRowO>>,
+            Stride<Int<kGmemThreadsPerRowO>, _1>>;
     using GmemTiledCopyO = decltype(make_tiled_copy(
-            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
+            Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementO>{},
             GmemLayoutAtomO{},
-            Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadO>>>{}));  // Val layout, 8 vals per store
 
     static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
-    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
+    static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum;
     using GmemLayoutAtomOaccum = Layout<
             Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
             Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
     using GmemTiledCopyOaccum = decltype(make_tiled_copy(
             Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
             GmemLayoutAtomOaccum{},
-            Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store
+            Layout<Shape<_1, Int<kGmemElemsPerLoadAccum>>>{}));  // Val layout, 4 vals per store
+
+    // ------ for fp8 ------
+    using SmemFp8Tranpose = SmemTransposeFp8_64x64<kBlockN, kHeadDimV>;
+    using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt;
 };
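Rough shared-memory arithmetic behind the extra FP8 buffer introduced below (illustrative only; assumes kBlockM = kBlockN = 64, kHeadDim = 576, kHeadDimV = 512, and ignores the smaller P/scale/row buffers):

```python
# Approximate per-CTA smem for the main tiles; padding/alignment ignored.
def smem_bytes(elem_size: int, fp8: bool) -> int:
    sq = 64 * 576 * elem_size                  # sQ: kBlockM x kHeadDim
    sk = 2 * 64 * 576 * elem_size              # sK: kBlockN x kHeadDim, double buffered
    svt = 64 * 512 * elem_size if fp8 else 0   # explicit transposed-V copy (FP8 only)
    return sq + sk + svt

print(smem_bytes(2, fp8=False) // 1024, "KiB")  # 216 KiB for bf16/fp16 (V aliases K)
print(smem_bytes(1, fp8=True) // 1024, "KiB")   # 140 KiB for e4m3, even with sVt
```

Halving the element size leaves enough headroom in Hopper's shared memory to afford a dedicated transposed copy of V, which the FP8 path needs because Hopper's FP8 GMMA only consumes K-major operands (hence also `MmaMajorV` above).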
 
 namespace flash {
 
@@ -125,10 +160,14 @@ using namespace cute;
 
 template<typename Kernel_traits>
 struct SharedStorageMLA {
+    using SmemV_t = std::conditional_t<Kernel_traits::Is_FP8,
+            cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutVtMMa>>,
+            cute::array_aligned<typename Kernel_traits::Element, 0>>;
     union {
         struct {
             cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
             cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k;  // Double buffer
+            SmemV_t smem_vt;
             cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
             cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
         };
@@ -144,11 +183,11 @@ struct SharedStorageMLA {
 
 template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
 __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
-                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
+                                      SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
     constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
-    using Element = typename Kernel_traits::Element;
+    using Element = typename Kernel_traits::ElementO;
     using ElementAccum = typename Kernel_traits::ElementAccum;
     using index_t = typename Kernel_traits::index_t;
 
@@ -161,7 +200,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const
 
     const int split_offset = __ldg(params.num_splits_ptr + bidb);
 
-    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, scale_softmax, descale_k);
 
     using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{});  // (SMEM_M,SMEM_N)
@@ -233,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
                                                                    const int bidb, const int bidh, const int m_block,
                                                                    const int n_split_idx, const int seqlen_k,
                                                                    const int n_block_min, const int n_block_max, const bool NoSplit,
-                                                                   SharedStorage &shared_storage) {
+                                                                   SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) {
     constexpr int kBlockM = Kernel_traits::kBlockM;
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
@@ -249,11 +288,18 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
 
     Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
     Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
-    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
-    Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+
+    auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
+    auto sVt = [&]() {
+        if constexpr (Kernel_traits::Is_FP8) {
+            return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{});
+        } else {
+            return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
+        }
+    }();
 
     Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
-    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
+    Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _);
     Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
     Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
     Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
@@ -279,8 +325,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
     if (n_block % 2 == 1) {
         // Double buffer for sK
        constexpr int sK_offset = size(sK);
-        tSrK.data() = tSrK.data() + sK_offset / 8;
-        tOrVt.data() = tOrVt.data() + sK_offset / 8;
+
+        if constexpr (Kernel_traits::Is_FP8) {
+            tSrK.data() = tSrK.data() + sK_offset / 16;
+        } else {
+            tSrK.data() = tSrK.data() + sK_offset / 8;
+            tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        }
     }
 
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
@@ -318,26 +369,37 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
         // We have key_padding_mask so we'll need to Check_inf
         Tensor scale_o = is_first_masking_step
                          ? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
                          : is_masking_step ?
-                           softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
-                           : softmax.template softmax</*Is_first=*/false, /*Check_inf=*/false>(tSrS, scale_softmax_log2);
-
-        Tensor rP = flash::convert_type<Element>(tSrS);
-        cute::copy(rP, tPsP);
+                           softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, scale_softmax_log2)
+                           : softmax.template softmax</*Is_first=*/false, /*Check_inf=*/false>(tSrS, scale_softmax_log2);
+
+        if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); }
+        Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(tSrS.layout()));
+        Tensor tOrP = make_tensor_like<Element>(tOrP_acc);
+        convert_type_out(tOrP_acc, tOrP);
+
+        cute::copy(tOrP, tPsP);  // send Aregs of MMA1 instead of Cregs of MMA0
         cute::copy(scale_o, tScale_osScale_o);
 
         cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
 
         flash::rescale_o(tOrO, scale_o);
 
-        Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(rP.layout()));
+        if constexpr (Kernel_traits::Is_FP8) {
+            cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+            __syncthreads();
+        }
         flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO);
 
         // Double buffer for sK
         const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
-        tSrK.data() = tSrK.data() + sK_offset / 8;
-        tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        if constexpr (Kernel_traits::Is_FP8) {
+            tSrK.data() = tSrK.data() + sK_offset / 16;
+        } else {
+            tSrK.data() = tSrK.data() + sK_offset / 8;
+            tOrVt.data() = tOrVt.data() + sK_offset / 8;
+        }
     }
 
     cute::copy(softmax.row_max, tRow_maxsRow_max);
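A note on the P-staging change above: the softmax warpgroup used to ship MMA0's C-fragments through shared memory and let the GEMM warpgroup regroup them; it now regroups (and, for FP8, byte-permutes) before the copy, so `sP` holds ready-to-use A-fragments. A sketch of the per-thread bookkeeping for the 64x64 S tile, under the fragment layouts assumed in this reconstruction (illustrative, not kernel code):

```python
# 64x64 S tile distributed over the 128 softmax threads:
elems_per_thread = 64 * 64 // 128        # 32 accumulator values each
cregs = (2 * 2) * 1 * 8                  # MMA0 C-fragment ((2,2), 1, 8), fp32
aregs = (2 * 2 * 2) * 1 * 4              # MMA1 A-fragment ((2,2,2), 1, 4), 16-bit K=16 atoms
assert cregs == aregs == elems_per_thread
# The FP8 MMA consumes K=32 atoms instead, which is why permute_Cregs_fp8 swaps
# register halves first: after conversion, each 32-bit register then holds the
# four e4m3 values its MMA lane expects.
```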
@@ -379,7 +441,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
             // Double buffer for sK
             constexpr int sK_offset = size(sK);
             tKsK.data() = tKsK.data() + sK_offset;
-            tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            if constexpr (!Kernel_traits::Is_FP8) {
+                tOrVt.data() = tOrVt.data() + sK_offset / 8;
+            }
         }
 
         // We need to clear the sK smem tiles because K is V.
@@ -411,6 +475,32 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
         cute::cp_async_fence();
     }
 
+    if constexpr (Kernel_traits::Is_FP8) {
+        auto TransV = [&]() {
+            using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose;
+            SmemFp8Tranpose smem_transpose_V;
+            Tensor sV_divide = as_position_independent_swizzle_tensor(
+                    make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{}));
+            Tensor sVt_divide = as_position_independent_swizzle_tensor(
+                    make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{}));
+
+            if (n_block % 2 == 1) {
+                sV_divide.data() = sV_divide.data() + size(sK);
+            }
+
+            CUTLASS_PRAGMA_UNROLL
+            for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) {
+                CUTLASS_PRAGMA_UNROLL
+                for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) {
+                    smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j)));
+                }
+            }
+        };
+
+        TransV();
+        cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::TransVReady));
+    }
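A quick sanity check of the transpose loop's trip count for this kernel's tile shape (kBlockN = 64, kHeadDimV = 512; illustrative):

```python
kBlockN, kHeadDimV, tile = 64, 512, 64
calls = (kBlockN // tile) * (kHeadDimV // tile)
assert calls == 8  # the i/j loops above issue eight 64x64 transpose() calls per K-block
```

The producer side arrives on `TransVReady` once the copy is done; the consumer path shown earlier syncs on the same barrier before its P@V GEMM.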
 
     cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
 
     if (n_block - 2 >= n_block_min) {
@@ -418,20 +508,22 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
     }
 
     typename Kernel_traits::TiledMma tiled_mma;
-    auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
-    Tensor rP = make_tensor<Element>(tSrS_layout);
+    auto tSrS_layout = flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout());
+    Tensor tOrP = make_tensor<Element>(tSrS_layout);
     Tensor scale_o = make_tensor<float>(Shape<_2>{});
     cute::copy(tScale_osScale_o, scale_o);
-    cute::copy(tPsP, rP);
+    cute::copy(tPsP, tOrP);
 
     flash::rescale_o(tOrO, scale_o);
 
-    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaO>(rP.layout()));
+    if constexpr (Kernel_traits::Is_FP8) __syncthreads();
     flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO);
 
-    // Double buffer for sK
-    const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
-    tOrVt.data() = tOrVt.data() + sK_offset / 8;
+    if constexpr (!Kernel_traits::Is_FP8) {
+        // Double buffer for sK
+        const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
+        tOrVt.data() = tOrVt.data() + sK_offset / 8;
+    }
 }
 
     cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
@@ -440,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
     }
 
     if (NoSplit)
-        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
     else
-        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
+        store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax);
 }
 
 template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
@@ -465,6 +557,16 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
     if (begin_idx >= params.b) return;
     int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
 
+    float descale_k = 1.f;
+    float scale_softmax = params.scale_softmax;
+    float scale_softmax_log2 = params.scale_softmax_log2;
+    if constexpr (Kernel_traits::Is_FP8) {
+        float descale_q = __ldg(params.descale_q_ptr);
+        descale_k = __ldg(params.descale_k_ptr);
+        scale_softmax = scale_softmax * descale_q * descale_k;
+        scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k;
+    }
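The folding above is plain algebra: per-tensor scales commute with softmax's max-subtraction, so multiplying the logit scale by `descale_q * descale_k`, and the output normalizer by `descale_k` (V shares K's storage and therefore its scale; see the `k_ptr == v_ptr` assert below), reproduces attention over fully dequantized inputs. A numeric check of the identity in plain PyTorch (not kernel code):

```python
import torch

torch.manual_seed(0)
s, dq, dk = 0.13, 1.7, 0.9           # softmax scale and per-tensor descales
q = torch.randn(8, 16)               # stand-ins for already-decoded fp8 values
k = torch.randn(32, 16)
v = torch.randn(32, 4)

ref = torch.softmax((q * dq) @ (k * dk).T * s, dim=-1) @ (v * dk)
folded = torch.softmax(q @ k.T * (s * dq * dk), dim=-1) @ v * dk
assert torch.allclose(ref, folded, atol=1e-6)
```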
 #pragma unroll 1
     for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
         const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
@@ -475,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
         if (batch_id > begin_idx) {
             __syncthreads();  // Barrier between two tiles.
         }
-        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
+        flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2);
     }
 }
@@ -587,17 +689,17 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
     dim3 grid_combine(params.b * params.h * params.seqlen_q);
     MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
         auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
-                typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
+                typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
         combine_kernel<<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
     });
     CHECK_CUDA_KERNEL_LAUNCH();
 }
 
-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
     static_assert(Headdim == 576);
     FLASH_ASSERT(params.d_v == 512);
     FLASH_ASSERT(params.k_ptr == params.v_ptr);  // Shared_KV
-    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
+    using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>;
     run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
 }
diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h
index 2994cb7..b7e2fed 100644
--- a/csrc/flash_mla.h
+++ b/csrc/flash_mla.h
@@ -17,6 +17,9 @@ struct Flash_fwd_mla_params {
     void *__restrict__ o_ptr;
     void *__restrict__ softmax_lse_ptr;
 
+    float* __restrict__ descale_q_ptr = nullptr;
+    float* __restrict__ descale_k_ptr = nullptr;
+
     index_t q_batch_stride;
     index_t k_batch_stride;
     index_t v_batch_stride;
@@ -47,7 +50,7 @@ static constexpr int TileSchedulerMetaDataSize = 8;
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename T, int Headdim>
+template<typename T, typename To, int Headdim>
 void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
 
 struct Mla_metadata_params {
diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h
new file mode 100644
index 0000000..40bb4d5
--- /dev/null
+++ b/csrc/fp8_transpose_v.h
@@ -0,0 +1,83 @@
+/**
+ * Reference: FA3's SmemTranspose64x64:
+ * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26
+*/
+
+#pragma once
+
+template <int kBlockN, int kHeadDim>
+struct SmemTransposeFp8_64x64 {
+    static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
+
+    using Element = cutlass::float_e4m3_t;
+    using TransposeShapeAtomV = Shape<_64, _64>;
+    using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+    using SmemLayoutV =
+        decltype(tile_to_shape(SmemLayoutAtomV{},
+                               Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+    // for fp8 in-kernel transpose -- src layout
+    using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
+    using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
+    using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
+    using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
+
+    // For fp8, this is the memory transpose.
+    using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+    using SmemLayoutVt =
+        decltype(tile_to_shape(SmemLayoutAtomVt{},
+                               Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+
+    // for fp8 in-kernel transpose -- dst layout
+    using SmemLayoutVtTrans = decltype(composition(
+        SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
+    using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
+    using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
+    using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
+    using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
+
+    using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
+    using ldsm_value_shape = Shape<_2, _8, _2, _1>;
+    using ldsm_value_stride = Stride<_2, _4, _1, _0>;
+    using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
+                                                   Layout<ldsm_value_shape, ldsm_value_stride>{}));
+    TiledCopyLDSM tiled_copy_ldsm;
+
+    using stsm_thread_shape = Shape<_4, _1, _8, _4>;
+    // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
+    using stsm_value_shape = Shape<_4, _4, _2, _1>;
+    using stsm_value_stride = Stride<_1, _8, _4, _0>;
+
+    using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
+                                                   Layout<stsm_value_shape, stsm_value_stride>{}));
+    TiledCopySTSM tiled_copy_stsm;
+
+    template <class SmemTensor, class SmemTensorOut>
+    CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
+        using namespace cute;
+
+        auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
+        auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
+        auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
+
+        auto tXsX = thr_copy_ldsm.partition_S(s_in);
+        auto tXrX = make_tensor<Element>(shape(tXsX));
+        auto tXsX_out = thr_copy_stsm.partition_D(s_out);
+
+        cute::copy(tiled_copy_ldsm, tXsX, tXrX);
+
+        auto data = tXrX.data();
+        CUTLASS_PRAGMA_UNROLL
+        for (int n = 0; n < size(tXrX); n += 8) {
+            uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+            auto upper = data_32bit[0];
+            auto lower = data_32bit[1];
+            data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+            data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+        }
+
+        cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
+    }
+};
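What the two `__byte_perm` selectors accomplish, simulated in plain Python (the helper mirrors PTX `prmt` semantics for these selectors): each pair of 32-bit registers holds eight e4m3 values, and the even/odd gather finishes the 8-bit transpose that the 16-bit LDSM/STSM path cannot do on its own.

```python
def byte_perm(a: int, b: int, sel: int) -> int:
    """Simulate CUDA __byte_perm: pick 4 bytes out of the 8 bytes of (a, b)."""
    byts = [(a >> (8 * i)) & 0xFF for i in range(4)] + [(b >> (8 * i)) & 0xFF for i in range(4)]
    return sum(byts[(sel >> (4 * i)) & 0xF] << (8 * i) for i in range(4))

upper, lower = 0x03020100, 0x07060504                  # bytes 0..3 and 4..7
assert byte_perm(upper, lower, 0x6420) == 0x06040200   # gathers the even bytes
assert byte_perm(upper, lower, 0x7531) == 0x07050301   # gathers the odd bytes
```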
diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h
index cefa936..940c934 100644
--- a/csrc/named_barrier.h
+++ b/csrc/named_barrier.h
@@ -10,6 +10,7 @@ namespace flash {
 enum class NamedBarriers {
     SReady = 1,
     SoftmaxReady = 2,
+    TransVReady = 3,
 };
 
 }  // flash
diff --git a/csrc/softmax.h b/csrc/softmax.h
index 4ab6ae9..bcb8cac 100644
--- a/csrc/softmax.h
+++ b/csrc/softmax.h
@@ -174,7 +174,7 @@ struct Softmax {
     };
 
     template<bool Is_dropout=false, bool Split=false, typename Tensor0>
-    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+    __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) {
         SumOp<float> sum_op;
         quad_allreduce_(row_sum, row_sum, sum_op);
         TensorT lse = make_fragment_like(row_sum);
@@ -184,7 +184,7 @@ struct Softmax {
 #pragma unroll
         for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
             float sum = row_sum(mi);
-            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum;
             lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
             float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
 #pragma unroll
diff --git a/csrc/utils.h b/csrc/utils.h
index 3b8dd52..716c50c 100644
--- a/csrc/utils.h
+++ b/csrc/utils.h
@@ -235,4 +235,40 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
     }
 }
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Fragment>
+CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) {
+    // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits
+    static_assert(decltype(size<0, 0>(frag))::value == 2);
+    static_assert(decltype(size<0, 1>(frag))::value == 2);
+    static_assert(decltype(size<0, 2>(frag))::value % 2 == 0);
+    static_assert(decltype(stride<0, 0>(frag))::value == 1);
+    static_assert(sizeof(typename Fragment::value_type) == 4);
+    Tensor frag_64b = group_modes<1, 3>(recast<uint2>(frag));  // ((1, 2, N / 8), (MMA_M, MMA_N))
+    #pragma unroll
+    for (int mi = 0; mi < size<1>(frag_64b); ++mi) {
+        #pragma unroll
+        for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) {
+            cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi));
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine, typename Layout, typename EngineOut>
+CUTLASS_DEVICE void convert_type_out(Tensor<Engine, Layout> const &tensor, Tensor<EngineOut, Layout> &out) {
+    // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong.
+    using From_type = typename Engine::value_type;
+    using To_type = typename EngineOut::value_type;
+    static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type));
+    static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly");
+    Tensor frag = recast<cutlass::Array<From_type, FragmentSize> const>(tensor);
+    Tensor out_frg = recast<cutlass::Array<To_type, FragmentSize>>(out);
+    static_assert(size(frag) == size(out_frg));
+    cutlass::NumericArrayConverter<To_type, From_type, FragmentSize> convert_op;
+    #pragma unroll
+    for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 }  // namespace flash
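`convert_type_out` converts whole vectors at a time; `FragmentSize` is the wider-to-narrower size ratio, so both conversions this kernel performs vectorize cleanly:

```python
def fragment_size(from_bytes: int, to_bytes: int) -> int:
    return max(from_bytes // to_bytes, to_bytes // from_bytes)

assert fragment_size(4, 2) == 2  # fp32 accumulators -> bf16/fp16 P fragments
assert fragment_size(4, 1) == 4  # fp32 accumulators -> e4m3 P fragments (FP8 path)
```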
diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py
index b2922af..736ac69 100644
--- a/flash_mla/flash_mla_interface.py
+++ b/flash_mla/flash_mla_interface.py
@@ -33,6 +33,8 @@ def flash_mla_with_kvcache(
     num_splits: torch.Tensor,
     softmax_scale: Optional[float] = None,
     causal: bool = False,
+    descale_q: Optional[torch.Tensor] = None,
+    descale_k: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Arguments:
@@ -45,8 +47,10 @@ def flash_mla_with_kvcache(
         num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
         softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
         causal: bool. Whether to apply causal attention mask.
+        descale_q: (batch_size), torch.float. Dequantization scale for the query (FP8 only).
+        descale_k: (batch_size), torch.float. Dequantization scale for the key (FP8 only).
 
     Returns:
         out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
         softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
     """
@@ -63,5 +67,6 @@ def flash_mla_with_kvcache(
         causal,
         tile_scheduler_metadata,
         num_splits,
+        descale_q, descale_k,
     )
     return out, softmax_lse
diff --git a/setup.py b/setup.py
index cd311f2..0b971c4 100644
--- a/setup.py
+++ b/setup.py
@@ -12,6 +12,7 @@
 )
 
 DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE"
 
 
 def append_nvcc_threads(nvcc_extra_args):
@@ -28,6 +29,8 @@ def get_sources():
     if not DISABLE_FP16:
         sources.append("csrc/flash_fwd_mla_fp16_sm90.cu")
+    if not DISABLE_FP8:
+        sources.append("csrc/flash_fwd_mla_fp8_sm90.cu")
     return sources
 
 
@@ -36,6 +39,8 @@ def get_features_args():
     features_args = []
     if DISABLE_FP16:
         features_args.append("-DFLASH_MLA_DISABLE_FP16")
+    if DISABLE_FP8:
+        features_args.append("-DFLASH_MLA_DISABLE_FP8")
     return features_args
 
 
@@ -73,7 +78,8 @@ def get_features_args():
                     "--expt-relaxed-constexpr",
                     "--expt-extended-lambda",
                     "--use_fast_math",
-                    "--ptxas-options=-v,--register-usage-level=10"
+                    "--ptxas-options=-v,--register-usage-level=10",
+                    "--ftemplate-backtrace-limit=0"
                 ]
                 + cc_flag
             ) + get_features_args(),
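The test changes below drive the kernel with unit descales, which keeps the reference comparison simple. A realistic caller would derive the scales from each tensor's dynamic range; a hypothetical per-tensor recipe (`quantize_e4m3` is not part of this PR; 448 is `torch.finfo(torch.float8_e4m3fn).max`):

```python
import torch

def quantize_e4m3(x: torch.Tensor):
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = 448.0 / amax                    # quantization scale
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    descale = (1.0 / scale).reshape(1)      # shape (1) float32, as the API expects
    return x_fp8, descale
```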
diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py
index 0abe9d2..0cd173c 100644
--- a/tests/test_flash_mla.py
+++ b/tests/test_flash_mla.py
@@ -28,21 +28,26 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
     return attn_weight @ value, lse
 
 
-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     RMSE = ((x - y) * (x - y)).mean().sqrt().item()
     cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
     amax_diff = (x - y).abs().max().item()
     # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
-    assert cos_diff < 1e-5
+
+    if use_fp8:
+        assert cos_diff < 1e-2
+    else:
+        assert cos_diff < 1e-5
 
 
 @torch.inference_mode()
-def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
+def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
     print(
-        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
+        f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
     )
+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -68,7 +73,31 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
     tile_scheduler_metadata, num_splits = get_mla_metadata(
         cache_seqlens, s_q * h_q // h_kv, h_kv
     )
-
+
+    init_dtype = q.dtype
+
+    def prepare_fp8_input():
+        q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
+
+        if use_fp8:
+            nonlocal q, blocked_k, blocked_v
+            fp8_dtype = torch.float8_e4m3fn
+            descale_q = torch.ones((1), dtype=torch.float32)
+            descale_k = torch.ones((1), dtype=torch.float32)
+
+            q_fp8 = q.to(fp8_dtype)
+            blocked_k_fp8 = blocked_k.to(fp8_dtype)
+            blocked_v_fp8 = blocked_v.to(fp8_dtype)
+
+        return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
+
+    q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
+    if use_fp8:
+        q = q_fp8
+        blocked_k = blocked_k_fp8
+        blocked_v = blocked_v_fp8
+
     def flash_mla():
         return flash_mla_with_kvcache(
             q,
@@ -79,18 +108,23 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
             tile_scheduler_metadata,
             num_splits,
             causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )
 
     def ref_mla():
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
             O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 h_q=h_q,
                 h_kv=h_kv,
                 is_causal=causal,
@@ -101,14 +135,12 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype):
 
     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")
 
     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
-        torch.finfo(q.dtype).bits // 8
-    )
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
     print(
         f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
     )
@@ -116,7 +148,8 @@ def ref_mla():
 
 def main(torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(torch_dtype)
+    init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
@@ -124,14 +157,14 @@ def main(torch_dtype):
     h_kv = 1
     d, dv = 576, 512
-    causal = True
+    causal = False
 
     for b in [128]:
         for s in [4096, 8192]:
             for h_q in [16, 32, 64, 128]:  # TP = 8, 4, 2, 1
                 for s_q in [1, 2]:  # MTP = 1, 2
                     for varlen in [False, True]:
-                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
+                        test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype)
 
 
 if __name__ == "__main__":
@@ -139,9 +172,9 @@ def main(torch_dtype):
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--dtype",
         type=str,
-        choices=["bf16", "fp16"],
+        choices=["bf16", "fp16", "e4m3"],
         default="bf16",
-        help="Data type to use for testing (bf16 or fp16)",
+        help="Data type to use for testing (bf16/fp16/e4m3)",
     )
 
     args = parser.parse_args()
@@ -149,5 +182,7 @@ def main(torch_dtype):
     torch_dtype = torch.bfloat16
     if args.dtype == "fp16":
         torch_dtype = torch.float16
+    elif args.dtype == "e4m3":
+        torch_dtype = torch.float8_e4m3fn
 
     main(torch_dtype)
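For completeness, how the new path is exercised and how it can be compiled out (assumes an SM90 GPU and a source build of this branch):

```python
# Run the test matrix with FP8 inputs:
#   python tests/test_flash_mla.py --dtype e4m3
#
# Skip compiling the FP8 kernel, mirroring the existing FP16 switch:
#   FLASH_MLA_DISABLE_FP8=TRUE python setup.py install
```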