From dae0690055f2b5d453404d239753978a9a6233cf Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Mon, 24 Feb 2025 21:12:36 +0800 Subject: [PATCH 01/30] init fp8 --- csrc/flash_fwd_mla_fp8_sm90.cu | 3 +++ setup.py | 1 + 2 files changed, 4 insertions(+) create mode 100644 csrc/flash_fwd_mla_fp8_sm90.cu diff --git a/csrc/flash_fwd_mla_fp8_sm90.cu b/csrc/flash_fwd_mla_fp8_sm90.cu new file mode 100644 index 0000000..2384a30 --- /dev/null +++ b/csrc/flash_fwd_mla_fp8_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/setup.py b/setup.py index 0a3bd17..c622b7c 100644 --- a/setup.py +++ b/setup.py @@ -37,6 +37,7 @@ def append_nvcc_threads(nvcc_extra_args): sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", + "csrc/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args, From d833dbd7111e44139dec9615bb544e7d956a856f Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 09:03:02 +0800 Subject: [PATCH 02/30] enable fp8 --- csrc/flash_fwd_mla_kernel.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 55f6811..9262632 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -33,6 +33,8 @@ struct Flash_fwd_kernel_traits_mla { using Element = elem_type; using ElementAccum = float; using index_t = int64_t; + + static constexpr bool Is_FP8 = cute::is_same_v; static constexpr int kNWarps = kNWarps_; static constexpr int kNThreads = kNWarps * 32; @@ -49,6 +51,8 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? 
GMMA::Major::MN : GMMA::Major::K; + using TiledMma = decltype(make_tiled_mma( cute::GMMA::ss_op_selector, Int, Int>, GMMA::Major::K, GMMA::Major::K>(), @@ -57,7 +61,7 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; using TiledMmaO = decltype(make_tiled_mma( cute::GMMA::rs_op_selector, Int, Int>, - GMMA::Major::K, GMMA::Major::MN>(), + GMMA::Major::K, MmaMajorV>(), Layout, Int, _1>>{})); using SmemLayoutQ = decltype(tile_to_shape( From b67a18f850387c5b68e8c8548f9d91c5960709f0 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 09:40:56 +0800 Subject: [PATCH 03/30] update gmem --- csrc/flash_api.cpp | 2 +- csrc/flash_fwd_mla_bf16_sm90.cu | 2 +- csrc/flash_fwd_mla_fp8_sm90.cu | 2 +- csrc/flash_fwd_mla_kernel.h | 40 ++++++++++++++++++++------------- csrc/flash_mla.h | 2 +- 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 5a1cb8e..1f44b68 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -186,7 +186,7 @@ mha_fwd_kvcache_mla( auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); - run_mha_fwd_splitkv_mla(params, stream); + run_mha_fwd_splitkv_mla(params, stream); out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/csrc/flash_fwd_mla_bf16_sm90.cu b/csrc/flash_fwd_mla_bf16_sm90.cu index 35691f2..4990c48 100644 --- a/csrc/flash_fwd_mla_bf16_sm90.cu +++ b/csrc/flash_fwd_mla_bf16_sm90.cu @@ -1,3 +1,3 @@ #include "flash_fwd_mla_kernel.h" -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_fwd_mla_fp8_sm90.cu b/csrc/flash_fwd_mla_fp8_sm90.cu index 2384a30..b678962 100644 --- a/csrc/flash_fwd_mla_fp8_sm90.cu +++ b/csrc/flash_fwd_mla_fp8_sm90.cu @@ -1,3 +1,3 @@ #include "flash_fwd_mla_kernel.h" -template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 9262632..e83e9cc 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -28,9 +28,10 @@ constexpr auto getSmemLayoutK() { } } -template +template struct Flash_fwd_kernel_traits_mla { using Element = elem_type; + using ElementO = elem_type_o; using ElementAccum = float; using index_t = int64_t; @@ -48,8 +49,10 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; static_assert(kHeadDimV % 32 == 0); static_assert(kHeadDimV <= kHeadDim); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int kBlockKSmem = Is_FP8 ? (kHeadDim % 128 == 0 ? 128 : 64) : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kBlockKSmemO = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzleO = kBlockKSmemO == 32 ? 2 : 3; static constexpr cute::GMMA::Major MmaMajorV = !Is_FP8 ? 
GMMA::Major::MN : GMMA::Major::K; @@ -81,12 +84,12 @@ struct Flash_fwd_kernel_traits_mla { using SmemLayoutRow = Layout>, Stride<_1, _2>>; using SmemLayoutAtomO = decltype(composition( - Swizzle{}, - Layout, Int>, Stride, _1>>{})); + Swizzle{}, + Layout, Int>, Stride, _1>>{})); using SmemLayoutO = decltype(tile_to_shape( SmemLayoutAtomO{}, Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -96,31 +99,38 @@ struct Flash_fwd_kernel_traits_mla { static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + static constexpr int kGmemElemsPerLoadO = sizeof(cute::uint128_t) / sizeof(ElementO); + static_assert(kHeadDim % kGmemElemsPerLoadO == 0, "kHeadDim must be a multiple of kGmemElemsPerLoadO"); + static constexpr int kGmemThreadsPerRowO = kBlockKSmemO / kGmemElemsPerLoadO; + static_assert(kNThreadsLoad % kGmemThreadsPerRowO == 0, "kNThreads must be a multiple of kGmemThreadsPerRowO"); + using GmemLayoutAtom = Layout< Shape, Int>, Stride, _1>>; + + using GmemTiledCopy = decltype(make_tiled_copy( Copy_Atom{}, GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read + Layout>>{})); // Val layout, 8 vals per read using GmemLayoutAtomO = Layout< - Shape, Int>, - Stride, _1>>; + Shape, Int>, + Stride, _1>>; using GmemTiledCopyO = decltype(make_tiled_copy( - Copy_Atom, Element>{}, + Copy_Atom, ElementO>{}, GmemLayoutAtomO{}, - Layout>{})); // Val layout, 8 vals per store + Layout>>{})); // Val layout, 8 vals per store static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); - static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmemO / kGmemElemsPerLoadAccum; using GmemLayoutAtomOaccum = Layout< Shape, Int>, Stride, _1>>; using GmemTiledCopyOaccum = decltype(make_tiled_copy( Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, - Layout>{})); // Val layout, 4 vals per store + Layout>>{})); // Val layout, 4 vals per store }; namespace flash { @@ -597,12 +607,12 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream CHECK_CUDA_KERNEL_LAUNCH(); } -template +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { static_assert(Headdim == 576); FLASH_ASSERT(params.d_v == 512); FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV - using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; run_flash_splitkv_fwd_mla>(params, stream); } diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h index 2994cb7..a2ef414 100644 --- a/csrc/flash_mla.h +++ b/csrc/flash_mla.h @@ -47,7 +47,7 @@ static constexpr int TileSchedulerMetaDataSize = 8; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); struct Mla_metadata_params { From fed0499301edaa47f8b74639b7018d7e1f694bf6 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 11:08:28 +0800 Subject: [PATCH 04/30] fp8 shared mem --- csrc/flash_fwd_mla_kernel.h | 8 ++++++++ setup.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git 
a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index e83e9cc..1af3eb7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -80,6 +80,10 @@ struct Flash_fwd_kernel_traits_mla { Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtMMa = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int >{})); + using SmemLayoutP = Layout, Int, _1, Int>>; using SmemLayoutRow = Layout>, Stride<_1, _2>>; @@ -139,10 +143,14 @@ using namespace cute; template struct SharedStorageMLA { + using SmemV_t = std::conditional_t * 2>, + cute::array_aligned>; union { struct { cute::array_aligned> smem_q; cute::array_aligned * 2> smem_k; // Double buffer + SmemV_t smem_vt; // Double buffer cute::array_aligned> smem_p; cute::array_aligned> smem_scale; }; diff --git a/setup.py b/setup.py index c622b7c..bfe931f 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def append_nvcc_threads(nvcc_extra_args): sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp8_sm90.cu", + #"csrc/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args, From 7409203f44dd54b8f51734e05c1fe7789c2d86a9 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 17:48:07 +0800 Subject: [PATCH 05/30] enable fp8 compile --- csrc/flash_fwd_mla_kernel.h | 11 +++++++++-- setup.py | 5 +++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 1af3eb7..fb53f79 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -135,6 +135,11 @@ struct Flash_fwd_kernel_traits_mla { Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>>{})); // Val layout, 4 vals per store + + + + // for fp8 trans-v + }; namespace flash { @@ -170,7 +175,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; constexpr int kNThreadsS = Kernel_traits::kNThreadsS; - using Element = typename Kernel_traits::Element; + using Element = typename Kernel_traits::ElementO; using ElementAccum = typename Kernel_traits::ElementAccum; using index_t = typename Kernel_traits::index_t; @@ -272,7 +277,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + auto sVt = cute::conditional_return( + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}), + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{})); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); diff --git a/setup.py b/setup.py index bfe931f..8b11e00 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def append_nvcc_threads(nvcc_extra_args): sources=[ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", - #"csrc/flash_fwd_mla_fp8_sm90.cu", + 
"csrc/flash_fwd_mla_fp8_sm90.cu", ], extra_compile_args={ "cxx": cxx_args, @@ -55,7 +55,8 @@ def append_nvcc_threads(nvcc_extra_args): "--expt-relaxed-constexpr", "--expt-extended-lambda", "--use_fast_math", - "--ptxas-options=-v,--register-usage-level=10" + "--ptxas-options=-v,--register-usage-level=10", + "--ftemplate-backtrace-limit=0" ] + cc_flag ), From c50d29d1702482fccc3a356892ea4c6f38be90de Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 21:30:39 +0800 Subject: [PATCH 06/30] fix compile --- csrc/flash_fwd_mla_kernel.h | 76 --------------------------------- csrc/flash_mla_utils.cu | 85 +++++++++++++++++++++++++++++++++++++ setup.py | 1 + 3 files changed, 86 insertions(+), 76 deletions(-) create mode 100644 csrc/flash_mla_utils.cu diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index fb53f79..d4940f1 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -630,79 +630,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, To, 512>; run_flash_splitkv_fwd_mla>(params, stream); } - -static constexpr int MaxBatchSize = 4096; - -__global__ void __launch_bounds__(256, 1, 1) -get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { - int *seqlens_k_ptr = params.seqlens_k_ptr; - int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; - int *num_splits_ptr = params.num_splits_ptr; - int batch_size = params.batch_size; - int block_size_n = params.block_size_n; - int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; - int num_sm_parts = params.num_sm_parts; - - __shared__ int num_blocks_shared[MaxBatchSize]; - __shared__ int num_splits_shared[MaxBatchSize]; - - int total_num_blocks = 0; - for (int i = threadIdx.x; i < batch_size; i += 32) { - int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); - total_num_blocks += num_blocks + fixed_overhead_num_blocks; - num_blocks_shared[i] = num_blocks; - } - for (int offset = 16; offset >= 1; offset /= 2) { - total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); - } - __syncwarp(); - - if (threadIdx.x == 0) { - int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; - - int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; - num_splits_shared[0] = 0; - for (int i = 0; i < num_sm_parts; ++i) { - int tile_scheduler_metadata0[4], tile_scheduler_metadata1; - tile_scheduler_metadata0[0] = now_idx; - tile_scheduler_metadata0[1] = now_block * block_size_n; - tile_scheduler_metadata1 = now_n_split_idx; - int remain_payload = payload; - while (now_idx < batch_size) { - int num_blocks = num_blocks_shared[now_idx]; - int now_remain_blocks = num_blocks - now_block; - if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { - cum_num_splits += now_n_split_idx + 1; - num_splits_shared[now_idx + 1] = cum_num_splits; - remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; - ++now_idx; - now_block = 0; - now_n_split_idx = 0; - } else { - if (remain_payload - fixed_overhead_num_blocks > 0) { - now_block += remain_payload - fixed_overhead_num_blocks; - ++now_n_split_idx; - remain_payload = 0; - } - break; - } - } - tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; - tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; - *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; - } - FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); - } - __syncwarp(); - - for (int i = threadIdx.x; i <= batch_size; i += 32) { - num_splits_ptr[i] = num_splits_shared[i]; - } -} - -void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { - FLASH_ASSERT(params.batch_size < MaxBatchSize); - get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); - CHECK_CUDA_KERNEL_LAUNCH(); -} diff --git a/csrc/flash_mla_utils.cu b/csrc/flash_mla_utils.cu new file mode 100644 index 0000000..38c74e4 --- /dev/null +++ b/csrc/flash_mla_utils.cu @@ -0,0 +1,85 @@ +#include +#include +#include + +using namespace cute; + +#include "flash_mla.h" +#include "static_switch.h" +#include "utils.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? 
now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/setup.py b/setup.py index 8b11e00..c184953 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ def append_nvcc_threads(nvcc_extra_args): name="flash_mla_cuda", sources=[ "csrc/flash_api.cpp", + "csrc/flash_mla_utils.cu", "csrc/flash_fwd_mla_bf16_sm90.cu", "csrc/flash_fwd_mla_fp8_sm90.cu", ], From dfe8ffc75abbd1ac2a1f3b342777ab7354099fc4 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 22:34:01 +0800 Subject: [PATCH 07/30] enable fp8 api --- csrc/flash_api.cpp | 21 +++++++++++++++------ csrc/flash_fwd_mla_kernel.h | 4 ++-- flash_mla/flash_mla_interface.py | 1 + 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 1f44b68..4be3c1c 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -68,7 +68,10 @@ mha_fwd_kvcache_mla( const float softmax_scale, bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize - const at::Tensor &num_splits // batch_size + 1 + const at::Tensor &num_splits, // batch_size + 1 + c10::optional &descale_q, // batch_size + c10::optional &descale_k, // batch_size + c10::optional &descale_v // batch_size ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -76,9 +79,9 @@ mha_fwd_kvcache_mla( at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; - auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == torch::kBFloat16); - TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + auto q_dtype = q.scalar_type(); + TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn); + TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype"); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -128,7 +131,8 @@ mha_fwd_kvcache_mla( at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); + auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? 
torch::kBFloat16 : q_dtype; + at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); Flash_fwd_mla_params params = {}; @@ -186,7 +190,12 @@ mha_fwd_kvcache_mla( auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); - run_mha_fwd_splitkv_mla(params, stream); + + if (q_dtype == torch::kFloat8_e4m3fn) { + run_mha_fwd_splitkv_mla(params, stream); + } else { + run_mha_fwd_splitkv_mla(params, stream); + } out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index d4940f1..261a275 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -278,8 +278,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); auto sVt = cute::conditional_return( - make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{}), - make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtransposed{})); + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 2f3aa46..33c0657 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -63,5 +63,6 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, + None, None, None, ) return out, softmax_lse From 870418802ab1a0de85dd691efa9ab1380541f339 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Tue, 25 Feb 2025 23:29:18 +0800 Subject: [PATCH 08/30] add fp8 ut --- csrc/flash_api.cpp | 3 +-- flash_mla/flash_mla_interface.py | 8 ++++++-- tests/test_flash_mla.py | 24 +++++++++++++++++++++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 4be3c1c..a20f408 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -70,8 +70,7 @@ mha_fwd_kvcache_mla( const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits, // batch_size + 1 c10::optional &descale_q, // batch_size - c10::optional &descale_k, // batch_size - c10::optional &descale_v // batch_size + c10::optional &descale_k // batch_size ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; diff --git a/flash_mla/flash_mla_interface.py b/flash_mla/flash_mla_interface.py index 33c0657..f249315 100644 --- a/flash_mla/flash_mla_interface.py +++ b/flash_mla/flash_mla_interface.py @@ -33,6 +33,8 @@ def flash_mla_with_kvcache( num_splits: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, + descale_q: Optional[torch.Tensor] = None, + descale_k: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, 
torch.Tensor]: """ Arguments: @@ -45,7 +47,9 @@ def flash_mla_with_kvcache( num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - + descale_q: (batch_size), torch.float. dequant scale for query + descale_k: (batch_size), torch.float. dequant scale for key + Return: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. @@ -63,6 +67,6 @@ def flash_mla_with_kvcache( causal, tile_scheduler_metadata, num_splits, - None, None, None, + descale_q, descale_k, ) return out, softmax_lse diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 8db5db0..37cbb10 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -37,7 +37,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False): print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) @@ -59,11 +59,28 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen): blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + + descale_q, descale_k = None, None + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((b), dtype=torch.float32) + descale_k = torch.ones((b), dtype=torch.float32) + + q_fp8 = q.to(fp8_dtype) + blocked_k_fp8 = blocked_k.to(fp8_dtype) + blocked_v_fp8 = blocked_v.to(fp8_dtype) + q = q_fp8.to(q.dtype) + blocked_k = blocked_k_fp8.to(blocked_k.dtype) + blocked_v = blocked_v_fp8.to(blocked_v.dtype) def flash_mla(): + q_ = q_fp8 if use_fp8 else q + blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k return flash_mla_with_kvcache( - q, blocked_k, block_table, cache_seqlens, dv, + q_, blocked_k_, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=causal, + descale_q=descale_q, descale_k=descale_k, ) def ref_mla(): @@ -107,10 +124,11 @@ def ref_mla(): h_kv = 1 d, dv = 576, 512 causal = True + use_fp8 = False for b in [128]: for s in [4096, 8192]: for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) + test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8) From ef644a56e0976cb74d2a57efabded73dfc05deeb Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 08:13:56 +0800 Subject: [PATCH 09/30] update ut --- tests/test_flash_mla.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 37cbb10..f700864 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -61,22 +61,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) - descale_q, descale_k = None, None - if use_fp8: - fp8_dtype = torch.float8_e4m3fn - descale_q = torch.ones((b), dtype=torch.float32) - descale_k = torch.ones((b), dtype=torch.float32) - - q_fp8 = q.to(fp8_dtype) - blocked_k_fp8 = blocked_k.to(fp8_dtype) - 
blocked_v_fp8 = blocked_v.to(fp8_dtype) - q = q_fp8.to(q.dtype) - blocked_k = blocked_k_fp8.to(blocked_k.dtype) - blocked_v = blocked_v_fp8.to(blocked_v.dtype) + def prepare_fp8_input(): + q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None + + if use_fp8: + nonlocal q, blocked_k, blocked_v + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((b), dtype=torch.float32) + descale_k = torch.ones((b), dtype=torch.float32) + + q_fp8 = q.to(fp8_dtype) + blocked_k_fp8 = blocked_k.to(fp8_dtype) + blocked_v_fp8 = blocked_v.to(fp8_dtype) + + q = q_fp8.to(q.dtype) * descale_q + blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k + blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k + return q_fp8, blocked_k_fp8, descale_q, descale_k + + + q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input() def flash_mla(): - q_ = q_fp8 if use_fp8 else q - blocked_k_ = blocked_k_fp8 if use_fp8 else blocked_k + q_ = q; blocked_k_ = blocked_k + if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8 return flash_mla_with_kvcache( q_, blocked_k_, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=causal, From 4b314cd655bdceb7c81d6573d3bcff8c8fb6a283 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 08:32:05 +0800 Subject: [PATCH 10/30] update fp8 api --- csrc/flash_api.cpp | 28 ++++++++++++++++++++++++---- csrc/flash_mla.h | 3 +++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index a20f408..5a0caa1 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -69,8 +69,8 @@ mha_fwd_kvcache_mla( bool is_causal, const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize const at::Tensor &num_splits, // batch_size + 1 - c10::optional &descale_q, // batch_size - c10::optional &descale_k // batch_size + c10::optional &descale_q_, // batch_size + c10::optional &descale_k_ // batch_size ) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm90 = dprops->major == 9 && dprops->minor == 0; @@ -81,6 +81,7 @@ mha_fwd_kvcache_mla( auto q_dtype = q.scalar_type(); TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn); TORCH_CHECK(kcache.scalar_type() == q_dtype, "query and key must have the same dtype"); + bool is_fp8 = q_dtype == torch::kFloat8_e4m3fn; CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); @@ -107,6 +108,20 @@ mha_fwd_kvcache_mla( TORCH_CHECK(batch_size > 0, "batch size must be postive"); TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (is_fp8) { + TORCH_CHECK(descale_q_.has_value() && descale_k_.has_value(), "descale is required when input dtype is fp8"); + auto descale_q = descale_q_.value(); + auto descale_k = descale_k_.value(); + CHECK_DEVICE(descale_q); + CHECK_DEVICE(descale_k); + TORCH_CHECK(descale_q.stride(-1) == 1); + TORCH_CHECK(descale_k.stride(-1) == 1); + TORCH_CHECK(descale_q.dtype() == torch::kFloat); + TORCH_CHECK(descale_k.dtype() == torch::kFloat); + CHECK_SHAPE(descale_q, batch_size); + CHECK_SHAPE(descale_k, batch_size); + } + if (seqlen_q_ori == 1) { is_causal = false; } const int ngroups = num_heads_ori / num_heads_k; @@ -130,7 +145,7 @@ mha_fwd_kvcache_mla( at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype; + auto out_type = is_fp8 ? 
torch::kBFloat16 : q_dtype; at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type)); at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -171,6 +186,11 @@ mha_fwd_kvcache_mla( params.block_table_batch_stride = block_table.stride(0); params.page_block_size = page_block_size; + if (is_fp8) { + params.descale_q_ptr = reinterpret_cast(descale_q_.value().data_ptr()); + params.descale_k_ptr = reinterpret_cast(descale_k_.value().data_ptr()); + } + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); CHECK_DEVICE(tile_scheduler_metadata); @@ -190,7 +210,7 @@ mha_fwd_kvcache_mla( auto stream = at::cuda::getCurrentCUDAStream().stream(); TORCH_CHECK(head_size == 576); - if (q_dtype == torch::kFloat8_e4m3fn) { + if (is_fp8) { run_mha_fwd_splitkv_mla(params, stream); } else { run_mha_fwd_splitkv_mla(params, stream); diff --git a/csrc/flash_mla.h b/csrc/flash_mla.h index a2ef414..b7e2fed 100644 --- a/csrc/flash_mla.h +++ b/csrc/flash_mla.h @@ -17,6 +17,9 @@ struct Flash_fwd_mla_params { void *__restrict__ o_ptr; void *__restrict__ softmax_lse_ptr; + float* __restrict__ descale_q_ptr = nullptr; + float* __restrict__ descale_k_ptr = nullptr; + index_t q_batch_stride; index_t k_batch_stride; index_t v_batch_stride; From f6fab1b915eb398ebc0ced017e2300b57c62eff3 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 10:17:29 +0800 Subject: [PATCH 11/30] change to use per_tensor --- csrc/flash_api.cpp | 4 ++-- csrc/flash_fwd_mla_kernel.h | 10 ++++++++-- tests/test_flash_mla.py | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 5a0caa1..9631b32 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -118,8 +118,8 @@ mha_fwd_kvcache_mla( TORCH_CHECK(descale_k.stride(-1) == 1); TORCH_CHECK(descale_q.dtype() == torch::kFloat); TORCH_CHECK(descale_k.dtype() == torch::kFloat); - CHECK_SHAPE(descale_q, batch_size); - CHECK_SHAPE(descale_k, batch_size); + CHECK_SHAPE(descale_q, 1); + CHECK_SHAPE(descale_k, 1); } if (seqlen_q_ori == 1) { is_causal = false; } diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 261a275..874aded 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -260,7 +260,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage) { + SharedStorage &shared_storage, const float descale_q, const float descale_k) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; @@ -494,6 +494,12 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (begin_idx >= params.b) return; int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + float descale_q, descale_k; + if constexpr (Kernel_traits::Is_FP8) { + descale_q = __ldg(params.descale_q_ptr); + descale_k = __ldg(params.descale_k_ptr); + } + #pragma unroll 1 for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { const int n_split_idx = batch_id == begin_idx ? 
begin_n_split_idx : 0; @@ -504,7 +510,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. } - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k); } } diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index f700864..5c68dba 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -67,8 +67,8 @@ def prepare_fp8_input(): if use_fp8: nonlocal q, blocked_k, blocked_v fp8_dtype = torch.float8_e4m3fn - descale_q = torch.ones((b), dtype=torch.float32) - descale_k = torch.ones((b), dtype=torch.float32) + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) q_fp8 = q.to(fp8_dtype) blocked_k_fp8 = blocked_k.to(fp8_dtype) From 29de9e0c79180b633cd57c8d2eb5b5ff8fb77475 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 16:03:17 +0800 Subject: [PATCH 12/30] debug mode --- tests/test_flash_mla.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 5c68dba..91bd6f1 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 + #assert cos_diff < 1e-5 @torch.inference_mode() @@ -131,12 +131,12 @@ def ref_mla(): h_kv = 1 d, dv = 576, 512 - causal = True - use_fp8 = False - - for b in [128]: - for s in [4096, 8192]: - for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 - for s_q in [1, 2]: # MTP = 1, 2 - for varlen in [False, True]: + causal = False + use_fp8 = True + + for b in [16]: + for s in [4096]: + for h_q in [128]: # TP = 8, 4, 2, 1 + for s_q in [2]: # MTP = 1, 2 + for varlen in [False]: test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8) From 59f691763e6fe64942cd5119d2fe642a714b41a2 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 17:25:52 +0800 Subject: [PATCH 13/30] fix Vt illegal --- csrc/flash_fwd_mla_kernel.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 874aded..7493f79 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -149,13 +149,13 @@ using namespace cute; template struct SharedStorageMLA { using SmemV_t = std::conditional_t * 2>, + cute::array_aligned>, cute::array_aligned>; union { struct { cute::array_aligned> smem_q; cute::array_aligned * 2> smem_k; // Double buffer - SmemV_t smem_vt; // Double buffer + SmemV_t smem_vt; cute::array_aligned> smem_p; cute::array_aligned> smem_scale; }; @@ -309,7 +309,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK constexpr int sK_offset = size(sK); tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; } // We need masking on S for the very last block when K and V has length 
not multiple of kBlockN. @@ -366,7 +366,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); tSrK.data() = tSrK.data() + sK_offset / 8; - tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; } cute::copy(softmax.row_max, tRow_maxsRow_max); @@ -408,7 +408,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK constexpr int sK_offset = size(sK); tKsK.data() = tKsK.data() + sK_offset; - tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; } // We need to clear the sK smem tiles because K is V. @@ -460,7 +460,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; } cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); From 6a4eb631e2b0b7b8986f1485dcf90236ac41120a Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 17:57:00 +0800 Subject: [PATCH 14/30] add transv barrier --- csrc/flash_fwd_mla_kernel.h | 8 ++++++++ csrc/named_barrier.h | 1 + 2 files changed, 9 insertions(+) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 7493f79..07a4a64 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -360,6 +360,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f flash::rescale_o(tOrO, scale_o); + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + } + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); @@ -440,6 +444,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::cp_async_fence(); } + if constexpr (Kernel_traits::Is_FP8) { + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::TransVReady)); + } + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); if (n_block - 2 >= n_block_min) { diff --git a/csrc/named_barrier.h b/csrc/named_barrier.h index cefa936..940c934 100644 --- a/csrc/named_barrier.h +++ b/csrc/named_barrier.h @@ -10,6 +10,7 @@ namespace flash { enum class NamedBarriers { SReady = 1, SoftmaxReady = 2, + TransVReady = 3, }; } // flash From 6dcea4952c6d1a1c7e7ae74c8f1018cbf0117d3f Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 18:37:36 +0800 Subject: [PATCH 15/30] add TransV --- csrc/flash_fwd_mla_kernel.h | 76 ++++++++++++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 07a4a64..79d6ba7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -276,10 +276,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - Tensor sV = 
make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + + auto sV = cute::conditional_return( + cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})), + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{})); + auto sVt = cute::conditional_return( - make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), - make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); @@ -377,6 +381,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::copy(softmax.row_sum, tRow_sumsRow_sum); cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); } else { + const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int cur_block_table = __ldg(&block_table[n_block]); @@ -412,7 +417,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK constexpr int sK_offset = size(sK); tKsK.data() = tKsK.data() + sK_offset; - if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) { + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + else { + sV.data() = sV.data() + sK_offset; + } } // We need to clear the sK smem tiles because K is V. 
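For the FP8 path the MMA that consumes V requires a K-major operand (patch 02 switches MmaMajorV to GMMA::Major::K when Is_FP8), so sVt can no longer be a simple transposed view of the K tiles the way it is for bf16. The hunk that follows therefore materializes a transposed copy of V in shared memory, following FA3's TransV: tiles of V are read with a transposed ldmatrix (SM75_U16x8_LDSM_T), the bytes are re-interleaved in registers with __byte_perm, and the result is stored back with stmatrix (SM90_U32x4_STSM_N). The selectors 0x6420 and 0x7531 split the eight FP8 bytes held in a register pair into even- and odd-indexed elements. A minimal host-side illustration of that byte shuffle (byte_perm_emu is a hand-rolled stand-in for CUDA's __byte_perm, not code from this patch):

#include <cstdint>
#include <cstdio>

// Emulates CUDA's __byte_perm(a, b, s): result byte i is chosen from the
// 8-byte value {b, a} (bytes 0..3 come from a, 4..7 from b) by nibble i of s.
static uint32_t byte_perm_emu(uint32_t a, uint32_t b, uint32_t s) {
    uint64_t bytes = (uint64_t(b) << 32) | a;
    uint32_t r = 0;
    for (int i = 0; i < 4; ++i) {
        uint32_t sel = (s >> (4 * i)) & 0x7;
        r |= uint32_t((bytes >> (8 * sel)) & 0xff) << (8 * i);
    }
    return r;
}

int main() {
    // "upper" holds FP8 elements e0..e3, "lower" holds e4..e7, one byte each.
    uint32_t upper = 0x03020100u, lower = 0x07060504u;
    uint32_t evens = byte_perm_emu(upper, lower, 0x6420);  // -> e0,e2,e4,e6
    uint32_t odds  = byte_perm_emu(upper, lower, 0x7531);  // -> e1,e3,e5,e7
    std::printf("evens=%08x odds=%08x\n", evens, odds);    // 06040200 07050301
    return 0;
}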
@@ -445,6 +455,58 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f } if constexpr (Kernel_traits::Is_FP8) { + auto TransV = [&]() { + // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697 + using LDSM_divide_shape = Shape<_64, _8>; + using S2RTiledCopyVt = decltype(make_tiled_copy( + Copy_Atom{}, + Layout, Stride<_4, _1, _0, _0>>{}, // thread layout + Layout, Stride<_1, _2, _16, _4>>{} // val layout + )); + + using STSM_divide_shape = Shape<_8, _16>; + using R2STiledCopyV = decltype(make_tiled_copy( + Copy_Atom{}, + Layout, Stride<_4, _1, _32, _0>>{}, // thread layout + Layout, Stride<_0, _1, _4, _8>>{} // val layout + )); + + S2RTiledCopyVt s2r_tiled_copy_vt; + R2STiledCopyV r2s_tiled_copy_v; + auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx); + auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx); + // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) + // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64)) + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) + CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_)); + CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_)); + + static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 2 : 1; + Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2)) + Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2)) + #pragma unroll + for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { + Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}))); + static_assert(size<0>(tTransrV) == 16); + Tensor tTransrV_64 = recast(tTransrV); + cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV); + #pragma unroll + for (int j = 0; j < size(tTransrV_64); ++j) { + uint32_t upper = tTransrV_64[j].x; + uint32_t lower = tTransrV_64[j].y; + tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); + tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); + } + cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i))); + } + }; + + TransV(); cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::TransVReady)); } @@ -468,7 +530,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK const int sK_offset = n_block % 2 == 0 ? 
size(sK) : -size(sK); - if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (!Kernel_traits::Is_FP8) { + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } else { + sV.data() = sV.data() + sK_offset; + } } cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); From dbd8c307eb1b28660a85cc26aba01c31377f7c59 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Wed, 26 Feb 2025 22:06:04 +0800 Subject: [PATCH 16/30] fix sV --- csrc/flash_fwd_mla_kernel.h | 39 ++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 79d6ba7..12e9883 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -14,18 +14,29 @@ using namespace cute; #include "flash_mla.h" -template +template constexpr auto getSmemLayoutK() { constexpr int headSizeBytes = sizeof(PrecType) * DIM; constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; - if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { - return GMMA::Layout_K_SW64_Atom{}; + if constexpr (major == GMMA::Major::K) { + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } } else { - return GMMA::Layout_K_SW32_Atom{}; + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } else { + return GMMA::Layout_MN_SW32_Atom{}; + } } + } template @@ -75,11 +86,16 @@ struct Flash_fwd_kernel_traits_mla { getSmemLayoutK(), Shape, Int>{})); + // ------ for f16 ------ using SmemLayoutV = decltype(tile_to_shape( getSmemLayoutK(), Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + // ------ for f8 ------ + using SmemLayoutVtLoad = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); using SmemLayoutVtMMa = decltype(tile_to_shape( getSmemLayoutK(), Shape, Int >{})); @@ -135,11 +151,6 @@ struct Flash_fwd_kernel_traits_mla { Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>>{})); // Val layout, 4 vals per store - - - - // for fp8 trans-v - }; namespace flash { @@ -278,7 +289,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); auto sV = cute::conditional_return( - cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtMMa{})), + cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})), make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{})); auto sVt = cute::conditional_return( @@ -476,9 +487,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx); auto r2s_thr_copy_v = 
r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx); // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8) - Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sVt, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) + Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64)) - Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sV, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) + Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); From 1757a6db07c25bcc257255b04676b946d1f21624 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Thu, 27 Feb 2025 09:11:17 +0800 Subject: [PATCH 17/30] try fix --- csrc/flash_fwd_mla_kernel.h | 5 +++-- csrc/utils.h | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 12e9883..3b3cd9c 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -367,6 +367,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f softmax.template softmax(tSrS, params.scale_softmax_log2) : softmax.template softmax(tSrS, params.scale_softmax_log2); + if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } Tensor rP = flash::convert_type(tSrS); cute::copy(rP, tPsP); cute::copy(scale_o, tScale_osScale_o); @@ -379,7 +380,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); } - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -536,7 +537,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f flash::rescale_o(tOrO, scale_o); - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/utils.h b/csrc/utils.h index 3b8dd52..854c75f 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -235,4 +235,24 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor +CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { + // frag has shape ((2, 2, N / 8), MMA_M, MMA_N), each element is 32 bits + static_assert(decltype(size<0, 0>(frag))::value == 2); + static_assert(decltype(size<0, 1>(frag))::value == 2); + static_assert(decltype(size<0, 2>(frag))::value % 2 == 0); + static_assert(decltype(stride<0, 0>(frag))::value == 1); + static_assert(sizeof(typename Fragment::value_type) == 4); + Tensor frag_64b = group_modes<1, 3>(recast(frag)); // ((1, 2, N / 8), (MMA_M, MMA_N)) + #pragma unroll + for (int mi = 0; mi < size<1>(frag_64b); ++mi) { + #pragma unroll + for (int i = 0; i < size<0, 2>(frag_64b) / 2; ++i) { + cutlass::swap(frag_64b(make_coord(_0{}, _1{}, 2 * i), mi), 
frag_64b(make_coord(_0{}, _0{}, 2 * i + 1), mi)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash From d1689ab64f3d7db5e5e7f4f068c7a1b0679e253c Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Thu, 27 Feb 2025 10:56:43 +0800 Subject: [PATCH 18/30] use mm1's Aregs instead of mma0's Cregs --- csrc/flash_fwd_mla_kernel.h | 24 ++++++++++++++---------- csrc/utils.h | 16 ++++++++++++++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 3b3cd9c..0c575c7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -100,7 +100,11 @@ struct Flash_fwd_kernel_traits_mla { getSmemLayoutK(), Shape, Int >{})); - using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutP = std::conditional_t< + Is_FP8, + Layout, Int, _1, _2, Int>>, + Layout, Int, _1, _2, Int>> + >; using SmemLayoutRow = Layout>, Stride<_1, _2>>; using SmemLayoutAtomO = decltype(composition( @@ -297,7 +301,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); - Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); @@ -368,8 +372,11 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f : softmax.template softmax(tSrS, params.scale_softmax_log2); if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } - Tensor rP = flash::convert_type(tSrS); - cute::copy(rP, tPsP); + Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); + Tensor tOrP = make_tensor_like(tOrP_acc); + convert_type_out(tOrP_acc, tOrP); + + cute::copy(tOrP, tPsP); // send Aregs of MMA1 instead of Cregs of MMA0 cute::copy(scale_o, tScale_osScale_o); cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); @@ -380,7 +387,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); } - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -529,15 +535,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f } typename Kernel_traits::TiledMma tiled_mma; - auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); - Tensor rP = make_tensor(tSrS_layout); + auto tSrS_layout = flash::convert_layout_acc_Aregs(partition_fragment_C(tiled_mma, Shape, Int>{}).layout()); + Tensor tOrP = make_tensor(tSrS_layout); Tensor scale_o = make_tensor(Shape<_2>{}); cute::copy(tScale_osScale_o, scale_o); - cute::copy(tPsP, rP); + cute::copy(tPsP, tOrP); flash::rescale_o(tOrO, scale_o); - - Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/utils.h 
b/csrc/utils.h index 854c75f..716c50c 100644 --- a/csrc/utils.h +++ b/csrc/utils.h @@ -255,4 +255,20 @@ CUTLASS_DEVICE void permute_Cregs_fp8(Fragment &frag) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTLASS_DEVICE void convert_type_out(Tensor const &tensor, Tensor &out) { + // Somehow if we allocate out inside this function and return it, e2e is slower and the output can be wrong. + using From_type = typename Engine::value_type; + using To_type = typename EngineOut::value_type; + static constexpr int FragmentSize = std::max(sizeof(From_type) / sizeof(To_type), sizeof(To_type) / sizeof(From_type)); + static_assert(CUTE_STATIC_V(size(tensor)) % FragmentSize == 0, "Fragment size does not vectorize properly"); + Tensor frag = recast const>(tensor); + Tensor out_frg = recast>(out); + static_assert(size(frag) == size(out_frg)); + cutlass::NumericArrayConverter convert_op; + #pragma unroll + for (int i = 0; i < size(frag); ++i) { out_frg[i] = convert_op(frag[i]); } +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash From 855c985b007bb03a6395530d6c0c23b8dc8cb154 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Thu, 27 Feb 2025 22:28:45 +0800 Subject: [PATCH 19/30] use 64x64 transpose_v --- csrc/flash_fwd_mla_kernel.h | 89 +++++++++++-------------------------- csrc/fp8_transpose_v.h | 82 ++++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 64 deletions(-) create mode 100644 csrc/fp8_transpose_v.h diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 0c575c7..512fb9b 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -12,6 +12,7 @@ using namespace cute; #include "softmax.h" #include "static_switch.h" #include "flash_mla.h" +#include "fp8_transpose_v.h" template @@ -86,20 +87,11 @@ struct Flash_fwd_kernel_traits_mla { getSmemLayoutK(), Shape, Int>{})); - // ------ for f16 ------ using SmemLayoutV = decltype(tile_to_shape( getSmemLayoutK(), Shape, Int>{})); using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); - // ------ for f8 ------ - using SmemLayoutVtLoad = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int>{})); - using SmemLayoutVtMMa = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int >{})); - using SmemLayoutP = std::conditional_t< Is_FP8, Layout, Int, _1, _2, Int>>, @@ -155,6 +147,13 @@ struct Flash_fwd_kernel_traits_mla { Copy_Atom, ElementAccum>{}, GmemLayoutAtomOaccum{}, Layout>>{})); // Val layout, 4 vals per store + + + // ------ for f8 ------ + using SmemLayoutVtMMa = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int >{})); + using SmemFp8Tranpose = SmemTransposeFp8_64x64; }; namespace flash { @@ -292,10 +291,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); - auto sV = cute::conditional_return( - cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtLoad{})), - make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{})); - + auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename 
Kernel_traits::SmemLayoutV{}); auto sVt = cute::conditional_return( make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); @@ -438,9 +434,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (!Kernel_traits::Is_FP8) { tOrVt.data() = tOrVt.data() + sK_offset / 8; } - else { - sV.data() = sV.data() + sK_offset; - } } // We need to clear the sK smem tiles because K is V. @@ -474,53 +467,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (Kernel_traits::Is_FP8) { auto TransV = [&]() { - // refer to fa3's TransV: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L697 - using LDSM_divide_shape = Shape<_64, _8>; - using S2RTiledCopyVt = decltype(make_tiled_copy( - Copy_Atom{}, - Layout, Stride<_4, _1, _0, _0>>{}, // thread layout - Layout, Stride<_1, _2, _16, _4>>{} // val layout - )); - - using STSM_divide_shape = Shape<_8, _16>; - using R2STiledCopyV = decltype(make_tiled_copy( - Copy_Atom{}, - Layout, Stride<_4, _1, _32, _0>>{}, // thread layout - Layout, Stride<_0, _1, _4, _8>>{} // val layout - )); - - S2RTiledCopyVt s2r_tiled_copy_vt; - R2STiledCopyV r2s_tiled_copy_v; - auto s2r_thr_copy_vt = s2r_tiled_copy_vt.get_thread_slice(warp_group_thread_idx); - auto r2s_thr_copy_v = r2s_tiled_copy_v.get_thread_slice(warp_group_thread_idx); - // flat_divide(sVt, LDSM_divide_shape{}): (64, 8, kHeadDim / 64, kBlockN / 8) - Tensor tTranssVt_ = s2r_thr_copy_vt.partition_S(flat_divide(sV, LDSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, kBlockN / 32) - // flat_divide(sV, STSM_divide_shape{}): (8, 16, kHeadDim / 8, (4, kBlockN / 64)) - Tensor tTranssV_ = r2s_thr_copy_v.partition_D(flat_divide(sVt, STSM_divide_shape{})); // ((16, 1), 1, 1, kHeadDim / 64, (2, kBlockN / 64)) - CUTE_STATIC_ASSERT_V(rank(tTranssVt_) == rank(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<0>(tTranssVt_) == size<0>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<1>(tTranssVt_) == size<1>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<2>(tTranssVt_) == size<2>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<3>(tTranssVt_) == size<3>(tTranssV_)); - CUTE_STATIC_ASSERT_V(size<4>(tTranssVt_) == size<4>(tTranssV_)); - - static constexpr int Transpose_ILP = (size<2>(tTranssVt_) * size<3>(tTranssVt_)) % 2 == 0 ? 
2 : 1; - Tensor tTranssVt = logical_divide(group_modes<1, rank(tTranssVt_)>(tTranssVt_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2)) - Tensor tTranssV = logical_divide(group_modes<1, rank(tTranssV_)>(tTranssV_), Shape>{}); // ((16, 1), (2, kHeadDim / 64 * kBlockN / 32 / 2)) - #pragma unroll - for (int i = 0; i < size<1, 1>(tTranssVt); ++i) { - Tensor tTransrV = make_fragment_like(tTranssV(_, make_coord(_, _0{}))); - static_assert(size<0>(tTransrV) == 16); - Tensor tTransrV_64 = recast(tTransrV); - cute::copy(s2r_tiled_copy_vt, tTranssVt(_, make_coord(_, i)), tTransrV); - #pragma unroll - for (int j = 0; j < size(tTransrV_64); ++j) { - uint32_t upper = tTransrV_64[j].x; - uint32_t lower = tTransrV_64[j].y; - tTransrV_64[j].x = __byte_perm(upper, lower, 0x6420); - tTransrV_64[j].y = __byte_perm(upper, lower, 0x7531); + using SmemFp8Tranpose = typename Kernel_traits::SmemFp8Tranpose; + SmemFp8Tranpose smem_transpose_V; + Tensor sV_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename SmemFp8Tranpose::SmemLayoutTransposeV{})); + Tensor sVt_divide = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename SmemFp8Tranpose::SmemLayoutTransposeVt{})); + + if (n_block % 2 == 1) { + sV_divide.data() = sV_divide.data() + size(sK); + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < shape<2>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < shape<1>(typename SmemFp8Tranpose::SmemLayoutTransposeV{}); ++i) { + smem_transpose_V.transpose(flatten(sV_divide(_, i, j)), flatten(sVt_divide(_, i, j))); } - cute::copy(r2s_tiled_copy_v, tTransrV, tTranssV(_, make_coord(_, i))); } }; @@ -548,8 +511,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); if constexpr (!Kernel_traits::Is_FP8) { tOrVt.data() = tOrVt.data() + sK_offset / 8; - } else { - sV.data() = sV.data() + sK_offset; } } diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h new file mode 100644 index 0000000..ba9e31e --- /dev/null +++ b/csrc/fp8_transpose_v.h @@ -0,0 +1,82 @@ +#pragma once + +template +struct SmemTransposeFp8_64x64 { + static_assert(sizeof(Element) == 1); + static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); + + using SmemLayoutK = decltype(tile_to_shape( + GMMA::Layout_K_SW64_Atom{}, + Shape, Int>{})); + using SmemLayoutV = decltype(composition( + SmemLayoutK{}, + Layout, Int>, Stride<_1, Int>>{})); + using TransposeShapeAtomV = Shape<_64, _64>; + + // for fp8 in-kernel transpose -- src layout + using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); + using SmemShapeLDSM = Shape, Shape<_16, _4>>; + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, + shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{}))); + using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); + + // For fp8, this is the memory transpose. 
+ using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutVt = decltype(tile_to_shape( + SmemLayoutAtomVt{}, + Shape, Int>{})); + + // for fp8 in-kernel transpose -- dst layout + using SmemLayoutVtTrans = decltype(composition( + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{}))); + using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); + using SmemShapeSTSM = Shape, Shape<_8, _8>>; + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), + shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{}))); + using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); + + + using ldsm_thread_shape = Shape<_4, _1, _8, _4>; + using ldsm_value_shape = Shape<_2, _8, _2, _1>; + using ldsm_value_stride = Stride<_2, _4, _1, _0>; + using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopyLDSM tiled_copy_ldsm; + + using stsm_thread_shape = Shape<_4, _1, _8, _4>; + // using stsm_thread_stride = Stride<_1, _0, _4, _32>; + using stsm_value_shape = Shape<_4, _4, _1, _2>; + using stsm_value_stride = Stride<_1, _8, _0, _4>; + + using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, + Layout{})); + TiledCopySTSM tiled_copy_stsm; + + template + CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { + using namespace cute; + + auto tid = threadIdx.x; + auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); + auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); + + auto tXsX = thr_copy_ldsm.partition_S(s_in); + auto tXrX = make_tensor(shape(tXsX)); + auto tXsX_out = thr_copy_stsm.partition_D(s_out); + + cute::copy(tiled_copy_ldsm, tXsX, tXrX); + + auto data = tXrX.data(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size(tXrX); n += 8) { + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); + } + + cute::copy(tiled_copy_stsm, tXrX, tXsX_out); + } +}; + From 1df91aff33134bc27b763d61dd7320696b14e23f Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Thu, 27 Feb 2025 23:40:02 +0800 Subject: [PATCH 20/30] fix compile --- csrc/flash_fwd_mla_kernel.h | 1 - csrc/fp8_transpose_v.h | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 512fb9b..6dff13a 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -395,7 +395,6 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::copy(softmax.row_sum, tRow_sumsRow_sum); cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); } else { - const int warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; const int *block_table = params.block_table + bidb * params.block_table_batch_stride; int cur_block_table = __ldg(&block_table[n_block]); diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h index ba9e31e..020944b 100644 --- a/csrc/fp8_transpose_v.h +++ b/csrc/fp8_transpose_v.h @@ -16,8 +16,7 @@ struct SmemTransposeFp8_64x64 { // for fp8 in-kernel transpose -- src layout using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); using SmemShapeLDSM = Shape, Shape<_16, _4>>; - 
using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, - shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}), shape<3>(SmemLayoutDivideV{}))); + using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{}))); using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{}))); // For fp8, this is the memory transpose. @@ -28,11 +27,10 @@ struct SmemTransposeFp8_64x64 { // for fp8 in-kernel transpose -- dst layout using SmemLayoutVtTrans = decltype(composition( - SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{}))); + SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); using SmemShapeSTSM = Shape, Shape<_8, _8>>; - using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), - shape<2>(SmemLayoutDivideVt{}), shape<3>(SmemLayoutDivideVt{}))); + using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); From 0337732dc1310d24dbe54ac928da3f685433c0a0 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 08:09:02 +0800 Subject: [PATCH 21/30] reorg --- csrc/flash_fwd_mla_kernel.h | 23 +++++++++++++++-------- csrc/fp8_transpose_v.h | 7 ++----- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index 6dff13a..b4f3ed7 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -150,10 +150,11 @@ struct Flash_fwd_kernel_traits_mla { // ------ for f8 ------ - using SmemLayoutVtMMa = decltype(tile_to_shape( - getSmemLayoutK(), - Shape, Int >{})); - using SmemFp8Tranpose = SmemTransposeFp8_64x64; + using SmemFp8Tranpose = SmemTransposeFp8_64x64; + // using SmemLayoutVtMMa = decltype(tile_to_shape( + // getSmemLayoutK(), + // Shape, Int >{})); + using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; }; namespace flash { @@ -292,9 +293,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); auto sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); - auto sVt = cute::conditional_return( - make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}), - make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{})); + auto sVt = [&](){ + if constexpr(Kernel_traits::Is_FP8){ + return make_tensor(make_smem_ptr(shared_storage.smem_vt.data()), typename Kernel_traits::SmemLayoutVtMMa{}); + } else { + return make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + } + }(); Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); Tensor tPsP = sP(_, tidx % kNThreadsS, _, _, _); @@ -381,8 +386,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (Kernel_traits::Is_FP8) { cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::TransVReady)); + __syncthreads(); } - 
flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK @@ -504,6 +509,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f cute::copy(tPsP, tOrP); flash::rescale_o(tOrO, scale_o); + + if constexpr (Kernel_traits::Is_FP8) __syncthreads(); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); // Double buffer for sK diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h index 020944b..f1a1919 100644 --- a/csrc/fp8_transpose_v.h +++ b/csrc/fp8_transpose_v.h @@ -1,13 +1,10 @@ #pragma once -template +template struct SmemTransposeFp8_64x64 { - static_assert(sizeof(Element) == 1); static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); - using SmemLayoutK = decltype(tile_to_shape( - GMMA::Layout_K_SW64_Atom{}, - Shape, Int>{})); + using Element = cutlass::float_e4m3_t; using SmemLayoutV = decltype(composition( SmemLayoutK{}, Layout, Int>, Stride<_1, Int>>{})); From 061af5fc564b0ff5cdaac113ef04ac3f5de41c8e Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 14:35:07 +0800 Subject: [PATCH 22/30] use fa'3 transv --- csrc/fp8_transpose_v.h | 31 ++++++++++++++++--------------- tests/test_flash_mla.py | 2 +- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h index f1a1919..c566066 100644 --- a/csrc/fp8_transpose_v.h +++ b/csrc/fp8_transpose_v.h @@ -5,10 +5,11 @@ struct SmemTransposeFp8_64x64 { static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0)); using Element = cutlass::float_e4m3_t; - using SmemLayoutV = decltype(composition( - SmemLayoutK{}, - Layout, Int>, Stride<_1, Int>>{})); using TransposeShapeAtomV = Shape<_64, _64>; + using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + Shape, Int>{})); // for fp8 in-kernel transpose -- src layout using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{})); @@ -18,15 +19,15 @@ struct SmemTransposeFp8_64x64 { // For fp8, this is the memory transpose. 
using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom{}, TransposeShapeAtomV{})); - using SmemLayoutVt = decltype(tile_to_shape( - SmemLayoutAtomVt{}, - Shape, Int>{})); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + Shape, Int>{})); // for fp8 in-kernel transpose -- dst layout using SmemLayoutVtTrans = decltype(composition( SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{}))); using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{})); - using SmemShapeSTSM = Shape, Shape<_8, _8>>; + using SmemShapeSTSM = Shape, Shape<_16, _4>>; using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{}))); using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{}))); @@ -40,8 +41,8 @@ struct SmemTransposeFp8_64x64 { using stsm_thread_shape = Shape<_4, _1, _8, _4>; // using stsm_thread_stride = Stride<_1, _0, _4, _32>; - using stsm_value_shape = Shape<_4, _4, _1, _2>; - using stsm_value_stride = Stride<_1, _8, _0, _4>; + using stsm_value_shape = Shape<_4, _4, _2, _1>; + using stsm_value_stride = Stride<_1, _8, _4, _0>; using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom{}, Layout{}, Layout{})); @@ -51,7 +52,7 @@ struct SmemTransposeFp8_64x64 { CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) { using namespace cute; - auto tid = threadIdx.x; + auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup; auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid); auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid); @@ -64,11 +65,11 @@ struct SmemTransposeFp8_64x64 { auto data = tXrX.data(); CUTLASS_PRAGMA_UNROLL for (int n = 0; n < size(tXrX); n += 8) { - uint32_t *data_32bit = reinterpret_cast(&data[n]); - auto upper = data_32bit[0]; - auto lower = data_32bit[1]; - data_32bit[0] = __byte_perm(upper, lower, 0x6420); - data_32bit[1] = __byte_perm(upper, lower, 0x7531); + uint32_t *data_32bit = reinterpret_cast(&data[n]); + auto upper = data_32bit[0]; + auto lower = data_32bit[1]; + data_32bit[0] = __byte_perm(upper, lower, 0x6420); + data_32bit[1] = __byte_perm(upper, lower, 0x7531); } cute::copy(tiled_copy_stsm, tXrX, tXsX_out); diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 91bd6f1..6cfd466 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -33,7 +33,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - #assert cos_diff < 1e-5 + assert cos_diff < 1e-5 @torch.inference_mode() From fd1e662debb640812123111c8e1f1396bbcf94a0 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 16:52:30 +0800 Subject: [PATCH 23/30] fix mma0 --- csrc/flash_fwd_mla_kernel.h | 21 +++++++++++++++------ tests/test_flash_mla.py | 10 +++++++--- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index b4f3ed7..ad52b3c 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -328,8 +328,13 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if (n_block % 2 == 1) { // Double buffer for sK constexpr int sK_offset = size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - if constexpr (!Kernel_traits::Is_FP8) 
tOrVt.data() = tOrVt.data() + sK_offset / 8; + + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } } // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -392,8 +397,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // Double buffer for sK const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tSrK.data() = tSrK.data() + sK_offset / 8; - if constexpr (!Kernel_traits::Is_FP8) tOrVt.data() = tOrVt.data() + sK_offset / 8; + if constexpr (Kernel_traits::Is_FP8) { + tSrK.data() = tSrK.data() + sK_offset / 16; + } else { + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } } cute::copy(softmax.row_max, tRow_maxsRow_max); @@ -513,9 +522,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f if constexpr (Kernel_traits::Is_FP8) __syncthreads(); flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); if constexpr (!Kernel_traits::Is_FP8) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); tOrVt.data() = tOrVt.data() + sK_offset / 8; } } diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 6cfd466..03c9037 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -27,13 +27,17 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): return attn_weight @ value, lse -def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None: x, y = x.double(), y.double() RMSE = ((x - y) * (x - y)).mean().sqrt().item() cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) amax_diff = (x - y).abs().max().item() # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") - assert cos_diff < 1e-5 + + if use_fp8: + assert cos_diff < 1e-3 + else: + assert cos_diff < 1e-5 @torch.inference_mode() @@ -111,7 +115,7 @@ def ref_mla(): out_flash, lse_flash = flash_mla() out_torch, lse_torch = ref_mla() - cal_diff(out_flash, out_torch, "out") + cal_diff(out_flash, out_torch, "out", use_fp8) cal_diff(lse_flash, lse_torch, "lse") t = triton.testing.do_bench(flash_mla) From bfe38ab10649b55c275d454afe93a346e8d5bf20 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 18:45:09 +0800 Subject: [PATCH 24/30] fix combine --- csrc/flash_fwd_mla_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index ad52b3c..bef20ee 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -688,7 +688,7 @@ void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream dim3 grid_combine(params.b * params.h * params.seqlen_q); MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< - typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + typename Kernel_traits::ElementO, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; combine_kernel<<>>(params); }); CHECK_CUDA_KERNEL_LAUNCH(); From 
4e055a6142143dc8cdd70e918fef9eb6dbef3949 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 18:59:02 +0800 Subject: [PATCH 25/30] reorg ut --- tests/test_flash_mla.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 03c9037..b840a97 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -63,8 +63,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) - + init_dtype = q.dtype + def prepare_fp8_input(): q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None @@ -78,33 +79,36 @@ def prepare_fp8_input(): blocked_k_fp8 = blocked_k.to(fp8_dtype) blocked_v_fp8 = blocked_v.to(fp8_dtype) - q = q_fp8.to(q.dtype) * descale_q - blocked_k = blocked_k_fp8.to(blocked_k.dtype) * descale_k - blocked_v = blocked_v_fp8.to(blocked_v.dtype) * descale_k - return q_fp8, blocked_k_fp8, descale_q, descale_k + return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k - q_fp8, blocked_k_fp8, descale_q, descale_k = prepare_fp8_input() - + if use_fp8: + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() + q = q_fp8 + blocked_k = blocked_k_fp8 + blocked_v = blocked_v_fp8 + def flash_mla(): - q_ = q; blocked_k_ = blocked_k - if use_fp8: q_ = q_fp8; blocked_k_ = blocked_k_fp8 return flash_mla_with_kvcache( - q_, blocked_k_, block_table, cache_seqlens, dv, + q, blocked_k, block_table, cache_seqlens, dv, tile_scheduler_metadata, num_splits, causal=causal, descale_q=descale_q, descale_k=descale_k, ) def ref_mla(): + if use_fp8: + q_ = (q.to(torch.float) * descale_q).to(init_dtype) + blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) + blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] O, LSE = scaled_dot_product_attention( - q[i].transpose(0, 1), - blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), - blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), h_q=h_q, h_kv=h_kv, is_causal=causal, From 8b939854d8869d94c9e63299c098ca64535c2173 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 19:43:24 +0800 Subject: [PATCH 26/30] enable scale --- csrc/flash_fwd_mla_kernel.h | 29 +++++++++++++++-------------- csrc/softmax.h | 4 ++-- tests/test_flash_mla.py | 2 +- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/csrc/flash_fwd_mla_kernel.h b/csrc/flash_fwd_mla_kernel.h index bef20ee..6e92b5e 100644 --- a/csrc/flash_fwd_mla_kernel.h +++ b/csrc/flash_fwd_mla_kernel.h @@ -151,9 +151,6 @@ struct Flash_fwd_kernel_traits_mla { // ------ for f8 ------ using SmemFp8Tranpose = SmemTransposeFp8_64x64; - // using SmemLayoutVtMMa = decltype(tile_to_shape( - // getSmemLayoutK(), - // Shape, Int >{})); using SmemLayoutVtMMa = typename SmemFp8Tranpose::SmemLayoutVt; }; @@ -186,7 +183,7 @@ struct SharedStorageMLA { template __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, - SharedStorage &shared_storage, AccO 
tOrO, Softmax softmax) { + SharedStorage &shared_storage, AccO tOrO, Softmax softmax, float descale_k, float scale_softmax) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kHeadDimV = Kernel_traits::kHeadDimV; constexpr int kNThreadsS = Kernel_traits::kNThreadsS; @@ -203,7 +200,7 @@ __forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const const int split_offset = __ldg(params.num_splits_ptr + bidb); - Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(tOrO, scale_softmax, descale_k); using ElementO = std::conditional_t; Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -275,7 +272,7 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f const int bidb, const int bidh, const int m_block, const int n_split_idx, const int seqlen_k, const int n_block_min, const int n_block_max, const bool NoSplit, - SharedStorage &shared_storage, const float descale_q, const float descale_k) { + SharedStorage &shared_storage, const float descale_k, const float scale_softmax, const float scale_softmax_log2) { constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; @@ -372,10 +369,10 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f // We have key_padding_mask so we'll need to Check_inf Tensor scale_o = is_first_masking_step - ? softmax.template softmax(tSrS, params.scale_softmax_log2) + ? softmax.template softmax(tSrS, scale_softmax_log2) : is_masking_step ? - softmax.template softmax(tSrS, params.scale_softmax_log2) - : softmax.template softmax(tSrS, params.scale_softmax_log2); + softmax.template softmax(tSrS, scale_softmax_log2) + : softmax.template softmax(tSrS, scale_softmax_log2); if constexpr (Kernel_traits::Is_FP8) { flash::permute_Cregs_fp8(tSrS); } Tensor tOrP_acc = make_tensor(tSrS.data(), flash::convert_layout_acc_Aregs(tSrS.layout())); @@ -535,9 +532,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_f } if (NoSplit) - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); else - store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax, descale_k, scale_softmax); } template @@ -560,10 +557,14 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (begin_idx >= params.b) return; int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); - float descale_q, descale_k; + float descale_k = 1.f; + float scale_softmax = params.scale_softmax; + float scale_softmax_log2 = params.scale_softmax_log2; if constexpr (Kernel_traits::Is_FP8) { - descale_q = __ldg(params.descale_q_ptr); + float descale_q = __ldg(params.descale_q_ptr); descale_k = __ldg(params.descale_k_ptr); + scale_softmax = scale_softmax * descale_q * descale_k; + scale_softmax_log2 = scale_softmax_log2 * descale_q * descale_k; } #pragma unroll 1 @@ -576,7 +577,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params if (batch_id > begin_idx) { __syncthreads(); // Barrier between two tiles. 
} - flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_q, descale_k); + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage, descale_k, scale_softmax, scale_softmax_log2); } } diff --git a/csrc/softmax.h b/csrc/softmax.h index 4ab6ae9..bcb8cac 100644 --- a/csrc/softmax.h +++ b/csrc/softmax.h @@ -174,7 +174,7 @@ struct Softmax { }; template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float descale_v, float rp_dropout=1.0) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); @@ -184,7 +184,7 @@ struct Softmax { #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { float sum = row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : descale_v / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; #pragma unroll diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index b840a97..ff7cd27 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -35,7 +35,7 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") if use_fp8: - assert cos_diff < 1e-3 + assert cos_diff < 1e-2 else: assert cos_diff < 1e-5 From 9887a5501e8348e6a6056a736c2734d5493fe947 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Fri, 28 Feb 2025 22:18:04 +0800 Subject: [PATCH 27/30] update readme --- README.md | 2 +- tests/test_flash_mla.py | 25 +++++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index b79757c..2e1ebff 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving. 
Currently released: -- BF16, FP16 +- BF16, FP16, E4M3 - Paged kvcache with block size of 64 ## Quick start diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 9c1ddcf..010bda3 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -42,11 +42,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) - @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = False): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, torch_dtype): print( - f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {use_fp8=}" + f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}" ) + use_fp8 = torch_dtype == torch.float8_e4m3fn cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) if varlen: for i in range(b): @@ -73,8 +74,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, use_fp8 = cache_seqlens, s_q * h_q // h_kv, h_kv ) + init_dtype = q.dtype def prepare_fp8_input(): - q_fp8, blocked_k_fp8, descale_q, descale_k = None, None, None, None + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None if use_fp8: nonlocal q, blocked_k, blocked_v @@ -89,8 +91,9 @@ def prepare_fp8_input(): return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k + + q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() if use_fp8: - q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input() q = q_fp8 blocked_k = blocked_k_fp8 blocked_v = blocked_v_fp8 @@ -110,10 +113,9 @@ def flash_mla(): ) def ref_mla(): - if use_fp8: - q_ = (q.to(torch.float) * descale_q).to(init_dtype) - blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) - blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) lse = torch.empty(b, h_q, s_q, dtype=torch.float32) for i in range(b): @@ -158,14 +160,13 @@ def main(torch_dtype): h_kv = 1 d, dv = 576, 512 causal = False - use_fp8 = torch_dtype == torch.float8_e4m3fn for b in [128]: for s in [4096, 8192]: for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 for s_q in [1, 2]: # MTP = 1, 2 for varlen in [False, True]: - test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, use_fp8) + test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen, torch_dtype) if __name__ == "__main__": @@ -183,7 +184,7 @@ def main(torch_dtype): torch_dtype = torch.bfloat16 if args.dtype == "fp16": torch_dtype = torch.float16 - elif args.dtype = "e4m3": - torch.dtype = torch.float8_e4m3fn + elif args.dtype == "e4m3": + torch_dtype = torch.float8_e4m3fn main(torch_dtype) From 90289837fc2d445bb94ec58d9cd2bdf17a202648 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Sat, 1 Mar 2025 02:14:42 +0800 Subject: [PATCH 28/30] update ut --- tests/test_flash_mla.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py index 010bda3..0cd173c 100644 --- a/tests/test_flash_mla.py +++ b/tests/test_flash_mla.py @@ -140,9 +140,7 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = 
(total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * ( - torch.finfo(q.dtype).bits // 8 - ) + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) print( f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" ) From 6199b0b4b56b7e13926115ba4ca84f667bfe97b2 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Sat, 1 Mar 2025 07:53:04 +0800 Subject: [PATCH 29/30] update desc --- csrc/fp8_transpose_v.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/csrc/fp8_transpose_v.h b/csrc/fp8_transpose_v.h index c566066..40bb4d5 100644 --- a/csrc/fp8_transpose_v.h +++ b/csrc/fp8_transpose_v.h @@ -1,3 +1,8 @@ +/** + * ref to Fa3's SmemTranspose64x64: + * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26 +*/ + #pragma once template From 7fafcd217d9fbaf30ef428b1e2e6c9c46d1ddd89 Mon Sep 17 00:00:00 2001 From: "chenhongmin.will" Date: Sat, 1 Mar 2025 14:44:25 +0800 Subject: [PATCH 30/30] add env --- csrc/flash_api.cpp | 5 ++++- setup.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index a865fc5..d6f8108 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -218,9 +218,12 @@ mha_fwd_kvcache_mla( run_mha_fwd_splitkv_mla(params, stream); } #endif + #ifndef FLASH_MLA_DISABLE_FP8 else if (q_dtype == torch::kFloat8_e4m3fn) { run_mha_fwd_splitkv_mla(params, stream); - } else { + } + #endif + else { TORCH_CHECK(false, "Unsupported tensor dtype for query"); } diff --git a/setup.py b/setup.py index ef1a8a7..0b971c4 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ ) DISABLE_FP16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") == "TRUE" +DISABLE_FP8 = os.getenv("FLASH_MLA_DISABLE_FP8", "FALSE") == "TRUE" def append_nvcc_threads(nvcc_extra_args): @@ -23,12 +24,13 @@ def get_sources(): sources = [ "csrc/flash_api.cpp", "csrc/flash_fwd_mla_bf16_sm90.cu", - "csrc/flash_fwd_mla_fp8_sm90.cu", "csrc/flash_fwd_mla_metadata.cu", ] if not DISABLE_FP16: sources.append("csrc/flash_fwd_mla_fp16_sm90.cu") + if not DISABLE_FP8: + sources.append("csrc/flash_fwd_mla_fp8_sm90.cu") return sources @@ -37,6 +39,8 @@ def get_features_args(): features_args = [] if DISABLE_FP16: features_args.append("-DFLASH_MLA_DISABLE_FP16") + if DISABLE_FP8: + features_args.append("-DFLASH_MLA_DISABLE_FP8") return features_args
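Taken together, the series lets the FP8 decode path be driven from Python roughly as in the sketch below. The call signature follows tests/test_flash_mla.py as modified in this series (query and blocked KV cache cast to torch.float8_e4m3fn, with scalar descale_q/descale_k tensors passed to flash_mla_with_kvcache); the block-table layout and the per-tensor quantization/descale convention shown here are illustrative assumptions, not part of the patches.

# Minimal FP8 decode sketch. Shapes follow tests/test_flash_mla.py; the
# block_table construction and the per-tensor descale convention below are
# assumptions made for illustration only.
import torch
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

b, s_q, h_q, h_kv, d, dv, block_size = 1, 1, 128, 1, 576, 512, 64
cache_seqlens = torch.full((b,), 4096, dtype=torch.int32, device="cuda")
max_seqlen_pad = ((cache_seqlens.max().item() + block_size - 1) // block_size) * block_size
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32,
                            device="cuda").view(b, max_seqlen_pad // block_size)

q = torch.randn(b, s_q, h_q, d, dtype=torch.bfloat16, device="cuda")
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d,
                        dtype=torch.bfloat16, device="cuda")

# Per-tensor quantization to e4m3 with scalar descale factors (assumed
# convention: real_value == fp8_value * descale).
fp8_max = torch.finfo(torch.float8_e4m3fn).max
descale_q = (q.abs().amax().float().clamp(min=1e-4) / fp8_max).reshape(1)
descale_k = (blocked_k.abs().amax().float().clamp(min=1e-4) / fp8_max).reshape(1)
q_fp8 = (q / descale_q).to(torch.float8_e4m3fn)
k_fp8 = (blocked_k / descale_k).to(torch.float8_e4m3fn)

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
out, lse = flash_mla_with_kvcache(
    q_fp8, k_fp8, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=False,
    descale_q=descale_q, descale_k=descale_k,
)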
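Note on the descale convention: per PATCH 26/30, the kernel folds descale_q * descale_k into scale_softmax (and its log2 variant) before the QK^T softmax, and applies descale_k (as descale_v) when normalizing the output in normalize_softmax_lse. The descale tensors passed from Python must therefore describe the same per-tensor scaling that was applied when casting Q and the KV cache to e4m3, otherwise the outputs and LSE values are silently mis-scaled.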