diff --git a/CMakeLists.txt b/CMakeLists.txt index 0194cc1c5bb..74ebbce3808 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -178,27 +178,48 @@ endif () if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) # BF16 source files file(GLOB FA3_BF16_GEN_SRCS - "hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") + "hopper/instantiations/flash_fwd_hdim64_bf16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim96_bf16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim128_bf16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim192_bf16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim256_bf16*_sm90.cu") + # Add these for hdim diff cases file(GLOB FA3_BF16_GEN_SRCS_ - "hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu") + # "hopper/instantiations/flash_fwd_hdim64_256_bf16*_sm90.cu" + # "hopper/instantiations/flash_fwd_hdim64_512_bf16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim192_128_bf16*_sm90.cu") list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) file(GLOB FA3_BF16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_bf16_*_sm80.cu") list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) + # FP16 source files file(GLOB FA3_FP16_GEN_SRCS - "hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu") + "hopper/instantiations/flash_fwd_hdim64_fp16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim96_fp16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim128_fp16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim192_fp16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim256_fp16*_sm90.cu") + # Add these for hdim diff cases file(GLOB FA3_FP16_GEN_SRCS_ - "hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu") + # "hopper/instantiations/flash_fwd_hdim64_256_fp16*_sm90.cu" + # "hopper/instantiations/flash_fwd_hdim64_512_fp16*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim192_128_fp16*_sm90.cu") list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) file(GLOB FA3_FP16_GEN_SRCS_ "hopper/instantiations/flash_fwd_*_fp16_*_sm80.cu") list(APPEND FA3_FP16_GEN_SRCS 
${FA3_FP16_GEN_SRCS_}) + # FP8 source files file(GLOB FA3_FP8_GEN_SRCS - "hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu") + "hopper/instantiations/flash_fwd_hdim64_e4m3*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim96_e4m3*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim128_e4m3*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim192_e4m3*_sm90.cu" + "hopper/instantiations/flash_fwd_hdim256_e4m3*_sm90.cu") + # Add these for hdim diff cases (192 only) file(GLOB FA3_FP8_GEN_SRCS_ - "hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu") + "hopper/instantiations/flash_fwd_hdim192_128_e4m3*_sm90.cu") list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_}) set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS}) @@ -244,11 +265,17 @@ if (FA3_ENABLED AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0) FLASHATTENTION_DISABLE_BACKWARD FLASHATTENTION_DISABLE_DROPOUT # FLASHATTENTION_DISABLE_ALIBI - # FLASHATTENTION_DISABLE_SOFTCAP + FLASHATTENTION_DISABLE_SOFTCAP FLASHATTENTION_DISABLE_UNEVEN_K # FLASHATTENTION_DISABLE_LOCAL FLASHATTENTION_DISABLE_PYBIND FLASHATTENTION_VARLEN_ONLY # Custom flag to save on binary size + FLASHATTENTION_DISABLE_CLUSTER # disabled for varlen in any case + # FLASHATTENTION_DISABLE_SM8x + FLASHATTENTION_DISABLE_HDIMDIFF64 + # FLASHATTENTION_DISABLE_HDIMDIFF192 + CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED + CUTLASS_ENABLE_GDC_FOR_SM90 ) elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 12.0) message(STATUS "FA3 is disabled because CUDA version is not 12.0 or later.") diff --git a/hopper/block.h b/hopper/block.h index eda7eaa1c40..3da119cae91 100644 --- a/hopper/block.h +++ b/hopper/block.h @@ -11,28 +11,39 @@ struct BlockMN { static CUTLASS_DEVICE - cute::tuple get_n_block_min_max( + cute::tuple get_n_block_min_max( SeqlenInfo_t const& seqlen_info, int const m_block, int const bidb, int const split_idx, int const num_splits, int const window_size_left, int const window_size_right, cutlass::FastDivmod const& 
qhead_per_khead_divmod) { - int const seqlen_k = seqlen_info.seqlen_k; + int seqlen_k = seqlen_info.seqlen_k; int const seqlen_q = seqlen_info.seqlen_q; + int n_offset = 0; + + // If local, calculate n_offset and update seqlen_k + if constexpr (Is_local) { + int m_idx_min = m_block * kBlockM; + if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } + // unlike previously, we don't divide by kBlockN because we want offset for seqlen_k + n_offset = std::max(int(0), m_idx_min + seqlen_k - seqlen_q - window_size_left); + // Subtract n_offset from seqlen_k for subsequent calculations such as n_block_max + // This is the actual seqlen_k processed for this m_block + seqlen_k -= n_offset; + } + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); if constexpr (Is_causal || Is_local) { int m_idx_max = (m_block + 1) * kBlockM; // TODO: check off-by-1 error if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; } + // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left) n_block_max = std::min(n_block_max, cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN)); } + // Now, only adjust n_block_min if split int n_block_min = 0; - if constexpr (Is_local) { - int m_idx_min = m_block * kBlockM; - if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); } - n_block_min = std::max(int(0), (m_idx_min + seqlen_k - seqlen_q - window_size_left) / kBlockN); - } + // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } if constexpr (Split) { uint32_t num_splits_dynamic_u = reinterpret_cast(split_idx) >> 16; // first 16 bits are for num_splits @@ -45,7 +56,9 @@ struct BlockMN { // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, 
num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); } } // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); } - return {n_block_min, n_block_max}; + + // Return n_offset to add to KV gmem pointers and use in masks + return {n_block_min, n_block_max, n_offset}; } static @@ -55,12 +68,12 @@ struct BlockMN { int const m_block, int const bidb, int const split_idx, int const num_splits, int const window_size_left, int const window_size_right, cutlass::FastDivmod const& qhead_per_khead_divmod) { - - auto [n_block_min, n_block_max] = get_n_block_min_max( + // TODO: check logic with n_offset + auto [n_block_min, n_block_max, n_offset] = get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, num_splits, window_size_left, window_size_right, qhead_per_khead_divmod); - int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0); - int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); + int const idx_k_new_min = std::max(n_block_min * kBlockN + n_offset - seqlen_info.seqlen_k_og, 0); + int const idx_k_new_max = std::min(n_block_max * kBlockN + n_offset - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new); int const n_block_new_min = idx_k_new_min / kBlockN; int const n_block_new_max = idx_k_new_max > idx_k_new_min ? 
cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min; // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);} diff --git a/hopper/epilogue_fwd.hpp b/hopper/epilogue_fwd.hpp index 69102e8c4e6..9725999b27d 100644 --- a/hopper/epilogue_fwd.hpp +++ b/hopper/epilogue_fwd.hpp @@ -21,7 +21,7 @@ namespace flash { using namespace cute; template + int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false, int kBlockH_=1> struct CollectiveEpilogueFwd { using TileShape_MNK_PV = TileShape_MNK_PV_; @@ -32,9 +32,10 @@ struct CollectiveEpilogueFwd { static constexpr int NumEpilogueThreads = NumEpilogueThreads_; static constexpr bool Varlen = Varlen_; static constexpr bool PackGQA = PackGQA_; + static constexpr bool PackGQA_TMA = PackGQA && (kBlockH_ > 1); static constexpr bool Split = Split_; static constexpr bool Use_smem = !(Split && !Varlen); - static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA; + static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && (!PackGQA || PackGQA_TMA); static_assert(ArchTag::kMinComputeCapability >= 80); static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1); @@ -42,6 +43,7 @@ struct CollectiveEpilogueFwd { static constexpr int kBlockM = get<0>(TileShape_MNK_PV{}); static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{}); + static constexpr int kBlockH = kBlockH_; static constexpr bool LargeHeadDimV = kHeadDimV > 256; @@ -83,7 +85,10 @@ struct CollectiveEpilogueFwd { using StrideO = cute::Stride; using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, 
num_splits) // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) - using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>>; + using ShapeOPackedTMA = std::conditional_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>; + using ShapeOPacked = std::conditional_t, int32_t, int32_t, int32_t, int32_t>, + ShapeOPackedTMA>; using StrideOPacked = std::conditional_t, _1, int64_t, int64_t, int64_t>>; // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; @@ -110,7 +115,7 @@ struct CollectiveEpilogueFwd { Use_TMA_O, decltype(make_tma_copy( GmemTiledCopyOTMA{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeO{}, StrideO{}), + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeOPackedTMA{}, StrideOPacked{}), SmemLayoutOTMA{}, select<0, 1>(TileShape_MNK_PV{}), _1{})), // no mcast for O @@ -158,19 +163,13 @@ struct CollectiveEpilogueFwd { static Params to_underlying_arguments(Arguments const& args) { - Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); - TMA_O tma_store_O = [&]{ - if constexpr (Use_TMA_O) { - return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast - } else { - return nullptr; - } - }(); // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits) int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv); auto const shape_O_packed = cute::conditional_return( args.shape_O, - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) + make_shape( + make_shape(cute::conditional_return(Int{}, qhead_per_khead), get<0>(args.shape_O)), + get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O)) ); auto const stride_O_packed = cute::conditional_return( args.stride_O, @@ -180,6 +179,15 @@ struct CollectiveEpilogueFwd { args.stride_O_partial, make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial)) ); + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), shape_O_packed, stride_O_packed); + TMA_O tma_store_O = [&]{ + if constexpr (Use_TMA_O) { + return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast + } else { + return nullptr; + } + }(); + // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits) auto const shape_LSE_packed = cute::conditional_return( select<0, 2, 3, 4>(args.shape_O), @@ -308,7 +316,7 @@ struct CollectiveEpilogueFwd { // Step 3: Write O from smem -> gmem if constexpr (Use_TMA_O) { - Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx); + Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O_packed)(_, _, bidh, bidb, split_idx); Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K) auto block_tma_O = params.tma_store_O.get_slice(_0{}); Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 07878dea63e..6a4f0e6ee65 100644 --- a/hopper/flash_api.cpp +++ 
b/hopper/flash_api.cpp @@ -272,6 +272,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); } else if (params.dv > 64 && Arch == 90) { @@ -279,6 +280,9 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { return run_mha_fwd_(params, stream); } + #else + return run_mha_fwd_(params, stream); + #endif } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 @@ -289,11 +293,15 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if (params.dv <= 128 && Arch == 90) { return run_mha_fwd_(params, stream); } else { return run_mha_fwd_(params, stream); } + #else + return run_mha_fwd_(params, stream); + #endif } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 @@ -303,6 +311,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if (params.dv > 256 && Arch == 90) { return run_mha_fwd_(params, stream); } else if (params.dv > 64 && Arch == 90) { @@ -310,6 +319,9 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } else { return run_mha_fwd_(params, stream); } + #else + return run_mha_fwd_(params, stream); + #endif } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 @@ -320,11 +332,15 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if (params.dv <= 128 && Arch == 90) { return run_mha_fwd_(params, stream); } else { return run_mha_fwd_(params, stream); } + #else + return run_mha_fwd_(params, stream); + #endif } #endif #ifndef 
FLASHATTENTION_DISABLE_HDIM256 @@ -347,11 +363,15 @@ void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if (params.dv <= 128 && Arch == 90) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } else { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } + #else + return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); + #endif } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 @@ -397,9 +417,10 @@ void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool ena } inline bool get_pagedkv_tma(Flash_fwd_params const& params) { - if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; } + // disable for local since we move k_ptr to start of sliding window by m_block + if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr || params.is_local) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f, use_one_mma_wg(params)); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90); // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower, @@ -411,13 +432,19 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) { // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size. // Has little effect on speed. if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; } + // Always enable PackGQA for special case of hdim = 64, qheads/kvheads = 8, local attention + // TODO: investigate more cases where PackGQA improves perf due to better tile quantization + bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 && + params.is_local && + params.d == 64 && (params.dv == params.d); + if (packgqa_override) { return true; } #ifdef FLASHATTENTION_DISABLE_PACKGQA return false; #else // params.page_table must already be set if (params.h == params.h_k) { return false; } // This needs to match the kernel configs - auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); + auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 
1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f, use_one_mma_wg(params)); int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90); return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM); #endif @@ -450,6 +477,12 @@ inline int get_num_splits(Flash_fwd_params const& params) { // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending // that batch = 1. int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks; + // For debugging + // printf("num sm = %d.\n", params.num_sm); + // printf("bM = %d, bN = %d.\n", kBlockM, kBlockN); + // printf("num_n_blocks, num_m_blocks = %d, %d.\n", num_n_blocks, num_m_blocks); + // printf("total m blocks = %d.\n", total_mblocks); + // printf("seqlen_k = %d, size_one_kv_head = %d.\n", params.seqlen_k, size_one_kv_head); return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128); #endif } @@ -581,15 +614,16 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; + // disable dynamic split if given explicit instructions to not split + bool const use_dynamic_split = params.b <= 992 && num_splits != 1; params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + params.pack_gqa = pack_gqa_.has_value() ? 
pack_gqa_.value() : get_pack_gqa(params); // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; + params.pack_gqa |= params.num_splits > 1; + // printf("Num splits (metadata) = %d.\n", params.num_splits); bool is_varlen = true; @@ -748,6 +782,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "HeaddimV > 256 requires fp16 and bf16 data type"); } + #ifdef FLASHATTENTION_DISABLE_HDIMDIFF64 + TORCH_CHECK(head_size > 64, "This flash attention build does not support hdim != hdim_v when hdim <= 64"); + #endif + #ifdef FLASHATTENTION_DISABLE_HDIMDIFF192 + TORCH_CHECK(head_size <= 64, "This flash attention build does not support hdim != hdim_v when hdim in (128, 192]"); + #endif } // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM @@ -936,16 +976,16 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + bool const use_dynamic_split = is_varlen && params.b <= 992 && num_splits != 1; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); - // Determine if we should pack GQA before num_splits since it impacts use_one_mma_wg (in get_num_splits) - params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; + // printf("Num splits = %d.\n", params.num_splits); + params.pack_gqa = pack_gqa_.has_value() ? 
pack_gqa_.value() : get_pack_gqa(params); // Always enable PackGQA for Split - params.pack_gqa = params.num_splits > 1; + params.pack_gqa |= (params.num_splits > 1); // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic @@ -975,6 +1015,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1093,9 +1134,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq } if(s_aux_.has_value()) { + TORCH_CHECK(params.arch == 90, "S aux is currently only supported for Hopper GPUs"); + TORCH_CHECK(num_heads <= 64, "We only support query heads <= 64 with S aux"); + TORCH_CHECK(head_size == head_size_v, "We don't support S aux with hdim != hdim_v"); auto s_aux = s_aux_.value(); TORCH_CHECK(s_aux.scalar_type() == at::ScalarType::BFloat16, - "We only support bf16 dtype for S extra."); + "We only support bf16 dtype for S aux."); CHECK_DEVICE(s_aux); CHECK_SHAPE(s_aux, num_heads); CHECK_CONTIGUOUS(s_aux); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index 2c9363300a5..616380b3d2d 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -27,7 +27,7 @@ using namespace cute; template + bool PackGQA, bool Split, bool V_colmajor, bool Use_one_mma_wg, int kBlockH=1> void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time"); static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot
be enabled at the same time"); @@ -53,10 +53,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) { using ClusterShape = cute::Shape, _1, _1>; using CollectiveMainloop = std::conditional_t< Arch >= 90, - flash::CollectiveMainloopFwdSm90, + flash::CollectiveMainloopFwdSm90, flash::CollectiveMainloopFwdSm80 >; - using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; using SchedulerPersistent = std::conditional_t= 90 && kHeadDim == 128; - + // Avoid over compilation by making sure this only gets set if it is actually used, i.e. we currently only support one mma wg for 64/128 head dim and hopper + static constexpr bool Use_one_mma_wg = Use_one_mma_wg_ && Arch >= 90 && (kHeadDim == 128 || kHeadDim == 64) && (kHeadDimV == kHeadDim); // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap, Use_one_mma_wg)) : 128; - - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; - BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { + static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen && !Use_one_mma_wg; + QV_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] { // Only use Cluster if number of tiles along seqlen_q is even and not varlen CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ?
1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] { static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1; - run_flash_fwd(params, stream); + int const qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k); + PACK_GQA_BLOCK_SWITCH(qhead_per_khead, kBlockH_, [&] { + // TODO: look at pack gqa tma for hdim diff + static constexpr int kBlockH = !PackGQA || Arch < 90 || (kHeadDim != kHeadDimV) ? 1 : kBlockH_; + run_flash_fwd(params, stream); + }); }); }); }); diff --git a/hopper/heuristics.h b/hopper/heuristics.h index 43d06f54825..621dc16bc34 100644 --- a/hopper/heuristics.h +++ b/hopper/heuristics.h @@ -8,8 +8,9 @@ #include "flash.h" inline bool use_one_mma_wg(Flash_fwd_params const& params) { - return params.arch >= 90 && params.d == 128 && - params.seqlen_q * (!params.pack_gqa ? 1 : params.h / params.h_k) <= 64; + // assume pack_gqa for seqlen calculation + return params.arch >= 90 && (params.d == 128 || params.d == 64) && + params.seqlen_q * (params.h / params.h_k) <= 64; }; inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) { diff --git a/hopper/mainloop_fwd_sm80.hpp b/hopper/mainloop_fwd_sm80.hpp index 4ce024f346a..7dcae77109d 100644 --- a/hopper/mainloop_fwd_sm80.hpp +++ b/hopper/mainloop_fwd_sm80.hpp @@ -329,6 +329,7 @@ struct CollectiveMainloopFwdSm80 { params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); int const n_block_min = get<0>(n_block_min_max); int const n_block_max = get<1>(n_block_min_max); + int const n_offset = get<2>(n_block_min_max); // It's possible to have n_block_max <= n_block_min. We don't want to load Q or change any barrier if constexpr (Is_causal || Is_local || Varlen || Split) { if (n_block_max <= n_block_min) { return false; } @@ -345,9 +346,9 @@ struct CollectiveMainloopFwdSm80 { int const bidb_kv = params.kv_batch_idx == nullptr ? 
bidb : params.kv_batch_idx[bidb]; Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q + seqlen_info.offset_q * get<0>(params.stride_Q)), params.shape_Q_packed, params.stride_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor gQ = local_tile(mQ, select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) - Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + seqlen_info.offset_k * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K + (seqlen_info.offset_k + n_offset) * get<0>(params.stride_K)), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); Tensor gK = local_tile(mK, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + seqlen_info.offset_k * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb_kv : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V + (seqlen_info.offset_k + n_offset) * get<0>(params.stride_V)), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? 
bidb_kv : 0); Tensor gV = local_tile(mV, select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) GmemTiledCopyQKV gmem_tiled_copy_QKV; @@ -385,7 +386,7 @@ struct CollectiveMainloopFwdSm80 { for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; + int const seqlen_k = seqlen_info.seqlen_k - n_offset; int n_block = n_block_max - 1; // Prologue: load Q, K, V @@ -423,7 +424,7 @@ struct CollectiveMainloopFwdSm80 { params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, params.page_size_divmod /*blockN_per_page_size_divmod, not used since we don't use TMA*/, - bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, + bidb_kv, bidh_kv, thread_idx, seqlen_k, seqlen_info.leftpad_k + n_offset, 0 /*bidb_kv_idx, not used since we don't use TMA for Sm8x*/ ); @@ -436,8 +437,8 @@ struct CollectiveMainloopFwdSm80 { // Instead of passing in tKVcKV, we pass in t0KVcKV and subtract the offset from the limit // (seqlen_k - n_block * kBlockN). This is because the entries of t0KVcKV are known at compile time. int const seqlenk_row_limit = -int(get<0>(tKVcKV(_0{}, _0{}, _0{}))) + (EvenN - ? seqlen_info.seqlen_k - n_block * kBlockN - : (!Seqlenk_mask ? kBlockN : std::min(seqlen_info.seqlen_k - n_block * kBlockN, kBlockN))); + ? seqlen_k - n_block * kBlockN + : (!Seqlenk_mask ? kBlockN : std::min(seqlen_k - n_block * kBlockN, kBlockN))); // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. flash::copy( gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK_cur, t0KVcKV, tKVpKV, seqlenk_row_limit); @@ -456,7 +457,7 @@ struct CollectiveMainloopFwdSm80 { // We don't call flash::copy since it doesn't support bound checking // to not overshot kBlockN when writing to smem. 
Tensor tVgV_cur = tVgV(_, _, _, n_block); - int const seqlenk_row_limit = seqlen_info.seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); + int const seqlenk_row_limit = seqlen_k - n_block * kBlockN - get<0>(tKVcKV(_0{}, _0{}, _0{})); #pragma unroll for (int m = 0; m < size<1>(tVsV); ++m) { // If kBlockN doesn't evenly divide the tiled copy, only the last `m` needs to be checked diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp index 0bdd4191538..50a2d7fb80f 100644 --- a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -30,7 +30,7 @@ using namespace cute; template + bool MmaPV_is_RS, bool IntraWGOverlap, bool PackGQA_, bool Split_, bool V_colmajor_, class ElementSAux_, int kBlockH_=1> struct CollectiveMainloopFwdSm90 { static constexpr int kStages = Stages; @@ -51,10 +51,11 @@ struct CollectiveMainloopFwdSm90 { static constexpr bool AppendKV = AppendKV_; static constexpr bool HasQv = HasQv_; static constexpr bool PackGQA = PackGQA_; + static constexpr bool PackGQA_TMA = PackGQA && kBlockH_ > 1; static constexpr bool Split = Split_; static constexpr bool V_colmajor = V_colmajor_; static constexpr bool Transpose_V = Is_FP8 && !V_colmajor; - static constexpr bool Use_TMA_Q = !PackGQA; + static constexpr bool Use_TMA_Q = !PackGQA || PackGQA_TMA; static constexpr bool Use_TMA_KV = !PagedKVNonTMA; static_assert(Use_TMA_KV || CUTE_STATIC_V(size(ClusterShape{})) == 1, "If not using TMA for KV, ClusterShape must be 1"); static_assert(Use_TMA_KV || !V_colmajor, "If not using TMA for KV, V_colmajor is not supported"); @@ -69,6 +70,7 @@ struct CollectiveMainloopFwdSm90 { static constexpr int kBlockM = get<0>(TileShape_MNK{}); static constexpr int kBlockN = get<1>(TileShape_MNK{}); static constexpr int kHeadDim = get<2>(TileShape_MNK{}); + static constexpr int kBlockH = kBlockH_; using SeqlenInfo_t = flash::SeqlenInfoQKNewK; using BlockMN_t = flash::BlockMN; @@ -241,7 +243,12 
@@ struct CollectiveMainloopFwdSm90 { using StrideQK = cute::Stride; using StrideV = std::conditional_t>; // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits) - using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + // using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>>; + using ShapeQPackedTMA = std::conditional_t, int32_t>, int32_t, int32_t, int32_t>>; + using ShapeQPacked = std::conditional_t, int32_t, int32_t, int32_t>, + ShapeQPackedTMA>; + using ShapeQvPacked = std::conditional_t, int32_t, int32_t, int32_t>>; using StrideQPacked = std::conditional_t, _1, int64_t, int64_t>>; using ShapePageTable = cute::Shape; // (batch, max_num_pages_per_seq) using StridePageTable = cute::Stride; @@ -251,7 +258,7 @@ struct CollectiveMainloopFwdSm90 { using TMA_Q = decltype(make_tma_copy_A_sm90( GmemTiledCopyQ{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQKV{}, StrideQK{}), + make_tensor(make_gmem_ptr(static_cast(nullptr)), ShapeQPackedTMA{}, StrideQPacked{}), SmemLayoutQ{}, TileShape_MNK{}, ClusterShape{})); @@ -427,7 +434,7 @@ struct CollectiveMainloopFwdSm90 { StrideV const stride_V_new; Element const* const ptr_Qv; StrideV const stride_Qv; - ShapeQPacked const shape_Qv_packed; + ShapeQvPacked const shape_Qv_packed; StrideQPacked const stride_Qv_packed; Element const* const ptr_rotary_cos; ShapeRotary const shape_rotary; @@ -466,7 +473,21 @@ struct CollectiveMainloopFwdSm90 { static Params to_underlying_arguments(Arguments const& args) { - Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) + int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); + auto const shape_Q_packed_tma = cute::conditional_return( + args.shape_Q, + make_shape(make_shape(Int{}, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) + ); + auto const shape_Q_packed = cute::conditional_return( + make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)), + shape_Q_packed_tma + ); + auto const stride_Q_packed = cute::conditional_return( + args.stride_Q, + make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) + ); + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), shape_Q_packed_tma, stride_Q_packed); TMA_Q tma_load_Q = make_tma_copy_A_sm90( GmemTiledCopyQ{}, mQ, @@ -519,16 +540,7 @@ struct CollectiveMainloopFwdSm90 { return nullptr; } }(); - // If PackGQA, reshape Q to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size) - int const qhead_per_khead = !PackGQA ? 
1 : cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K)); - auto const shape_Q_packed = cute::conditional_return( - args.shape_Q, - make_shape(make_shape(qhead_per_khead, get<0>(args.shape_Q)), get<1>(args.shape_Q), get<2>(args.shape_K), get<3>(args.shape_Q)) - ); - auto const stride_Q_packed = cute::conditional_return( - args.stride_Q, - make_stride(make_stride(get<2>(args.stride_Q), get<0>(args.stride_Q)), get<1>(args.stride_Q), get<2>(args.stride_Q) * qhead_per_khead, get<3>(args.stride_Q)) - ); + auto const shape_Qv_packed = cute::conditional_return( shape_Qv, make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv)) @@ -545,6 +557,7 @@ struct CollectiveMainloopFwdSm90 { if (!PagedKVNonTMA && args.ptr_pagetable != nullptr) { assert(page_size % kBlockN == 0); assert(!args.leftpad_k); + assert(!Is_local); // Since we now use leftpad_k with local, we can't use TMA with PagedKV } // If there's tanh softcapping, we do tanh(scores * softmax_scale / softcap_val) * softcap_val. // Right after this, we multiply by log2(e) before applying exp2. @@ -612,7 +625,11 @@ struct CollectiveMainloopFwdSm90 { int const bidh = get<1>(block_coord); int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + // Update seqlen_info using n_offset: + // leftpad_k -> leftpad_k + n_offset + // offset_k -> offset_k + n_offset + // seqlen_k -> seqlen_k - n_offset + auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. Loading K can cause illegal memory access. 
@@ -658,15 +675,23 @@ struct CollectiveMainloopFwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; - Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q_packed)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK_TMA = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, _); auto shape_V = make_shape(params.headdim_v, get<0>(params.shape_K), get<2>(params.shape_K), get<3>(params.shape_K)); Tensor mVt_TMA = params.tma_load_V.get_tma_tensor(shape_V)(_, _, bidh_kv, _); - Tensor gQ = local_tile(domain_offset(make_coord(seqlen_info.offset_q, _0{}), mQ), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gQ = local_tile( + domain_offset( + cute::conditional_return( + make_coord(seqlen_info.offset_q, _0{}), + make_coord(make_coord(_0{}, seqlen_info.offset_q), _0{})), + mQ), + select<0, 2>(TileShape_MNK{}), + make_coord(m_block, _0{})); // (M, K) // if (cute::thread0()) { printf("Varlen = %d, params.leftpad_k = %p, leftpad_k = %d\n", Varlen, params.leftpad_k, leftpad_k); } - Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) - Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) + // Now add n_offset to update KV gmem pointers + Tensor gK_TMA = local_tile(domain_offset(make_coord(seqlen_info.offset_k + n_offset, _0{}, _0{}), mK_TMA), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}, _)); // (N, K, _, _) + Tensor gVt_TMA = local_tile(domain_offset(make_coord(_0{}, seqlen_info.offset_k + n_offset, _0{}), mVt_TMA), select<1, 2>(TileShape_MNK_PV{}), make_coord(_0{}, _, _)); // (K, N, _, _) auto block_tma_Q = 
params.tma_load_Q.get_slice(_0{}); Tensor tQgQ = group_modes<0, 3>(block_tma_Q.partition_S(gQ)); // (TMA) @@ -701,7 +726,7 @@ struct CollectiveMainloopFwdSm90 { params.ptr_K, params.shape_K, params.stride_K, params.ptr_V, params.headdim_v, params.stride_V, params.page_size_divmod, params.blockN_per_page_size_divmod, - bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k, seqlen_info.leftpad_k, bidb_kv_idx + bidb_kv, bidh_kv, thread_idx, seqlen_info.seqlen_k - n_offset, seqlen_info.leftpad_k + n_offset, bidb_kv_idx ); // Set up for transposing V, only used if Transpose_V @@ -979,7 +1004,7 @@ struct CollectiveMainloopFwdSm90 { int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); int const bidh_kv = !PackGQA ? params.qhead_per_khead_divmod.divide(bidh) : bidh; - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier @@ -1060,11 +1085,14 @@ struct CollectiveMainloopFwdSm90 { }; int const seqlen_q = seqlen_info.seqlen_q; - int const seqlen_k = seqlen_info.seqlen_k; + // Compute actual seqlen_k for this mma worktile + int const seqlen_k = seqlen_info.seqlen_k - n_offset; int n_block = n_block_max - 1; + // NOTE: sink_token_length is dead code + // But we subtract n_offset for consistency in mask calculations flash::Mask mask( - thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 /*sink_token_length*/, + thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/, params.qhead_per_khead_divmod ); @@ -1246,6 +1274,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); + // If local, blocking (window_size_right + window_size_left) int const n_block_min_causal_local_mask = std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); #pragma unroll 1 @@ -1255,6 +1284,7 @@ struct CollectiveMainloopFwdSm90 { } int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + // If local, blocking (m_idx_max - m_idx_min) int const n_block_min_before_local_mask = !Is_local ? n_block_min : std::max(n_block_min, @@ -1350,6 +1380,7 @@ struct CollectiveMainloopFwdSm90 { if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply(tSrS, m_block, n_block); }; int const m_idx_min = !PackGQA ? 
m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM); + // If local, blocking (window_size_right + window_size_left) int const n_block_min_causal_local_mask = std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN); #pragma unroll 1 @@ -1358,6 +1389,7 @@ struct CollectiveMainloopFwdSm90 { } } int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1; + // If local, blocking (m_idx_max - m_idx_min) int const n_block_min_before_local_mask = !Is_local ? n_block_min : std::max(n_block_min, @@ -1413,7 +1445,7 @@ struct CollectiveMainloopFwdSm90 { int const m_block = get<0>(block_coord); int const bidb = get<2>(block_coord); int const split_idx = get<3>(block_coord); - auto [n_block_min, n_block_max] = BlockMN_t::get_n_block_min_max( + auto [n_block_min, n_block_max, n_offset] = BlockMN_t::get_n_block_min_max( seqlen_info, m_block, bidb, split_idx, params.num_splits, params.window_size_left, params.window_size_right, params.qhead_per_khead_divmod); // It's possible to have n_block_max <= n_block_min. 
We don't want to load Q or change any barrier diff --git a/hopper/setup.py b/hopper/setup.py index e12d98b7cff..611eac28532 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -64,6 +64,29 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" + +# DISABLE_BACKWARD = True +# DISABLE_SPLIT = True +# DISABLE_PAGEDKV = True +# DISABLE_APPENDKV = True +# DISABLE_LOCAL = True +# DISABLE_SOFTCAP = True +# DISABLE_PACKGQA = True +# DISABLE_FP16 = True +# DISABLE_FP8 = True +# DISABLE_VARLEN = True +# DISABLE_CLUSTER = True +# DISABLE_HDIM64 = True +# DISABLE_HDIM96 = True +# DISABLE_HDIM128 = True +# DISABLE_HDIM192 = True +# DISABLE_HDIM256 = True +# DISABLE_SM8x = True + +# DISABLE_HDIMDIFF64 = True +# DISABLE_HDIMDIFF192 = True # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -467,10 +490,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -480,7 +506,28 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all", "diff"] + # 
HEAD_DIMENSIONS_FWD = ["all", "diff"] + # HEAD_DIMENSIONS_FWD = ( + # ["all"] + # + (["diff"] if not DISABLE_HDIMDIFF else []) + # ) + HEAD_DIMENSIONS_FWD = ( + [] + + ([64] if not DISABLE_HDIM64 else []) + + ([96] if not DISABLE_HDIM96 else []) + + ([128] if not DISABLE_HDIM128 else []) + + ([192] if not DISABLE_HDIM192 else []) + + ([256] if not DISABLE_HDIM256 else []) + ) + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -494,6 +541,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = 
[f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 4701fa202ea..45bf0cf771e 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -61,6 +61,16 @@ }() #endif +#ifdef FLASHATTENTION_DISABLE_HDIMDIFF64 + #define QV_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define QV_SWITCH BOOL_SWITCH +#endif + #ifdef FLASHATTENTION_DISABLE_SOFTCAP #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ [&] { \ @@ -190,3 +200,22 @@ return __VA_ARGS__(); \ } \ }() + +#ifdef FLASH_ATTENTION_DISABLE_PACKGQA + #define PACK_GQA_BLOCK_SWITCH(QHEADS_PER_KHEADS, BLOCK_H, ...) \ + [&] { \ + constexpr static int BLOCK_H = 1; \ + return __VA_ARGS__(); \ + }() +#else + #define PACK_GQA_BLOCK_SWITCH(QHEADS_PER_KHEADS, BLOCK_H, ...) \ + [&] { \ + if (QHEADS_PER_KHEADS == 8) { \ + constexpr static int BLOCK_H = 8; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int BLOCK_H = 1; \ + return __VA_ARGS__(); \ + } \ + }() +#endif \ No newline at end of file diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index 4d20ff8af2b..9b390fa4313 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -37,6 +37,25 @@ DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE" DISABLE_HDIM192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM192", "FALSE") == "TRUE" DISABLE_HDIM256 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM256", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" + +DISABLE_BACKWARD = True +# DISABLE_SPLIT = True +# DISABLE_PAGEDKV = True +# DISABLE_APPENDKV = True +# DISABLE_LOCAL = True +# DISABLE_SOFTCAP = True +# DISABLE_PACKGQA = True +# DISABLE_FP16 = True +# DISABLE_FP8 = True +# DISABLE_HDIM64 = True +# 
DISABLE_HDIM96 = True +# DISABLE_HDIM128 = True +# DISABLE_HDIM192 = True +# DISABLE_HDIM256 = True +DISABLE_HDIMDIFF64 = True +# DISABLE_HDIMDIFF192 = True COMPILED_HDIMS = ( [] @@ -54,10 +73,10 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv_", [False] + ([True] if not DISABLE_HDIMDIFF64 else [])) +# @pytest.mark.parametrize("has_qv_", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) -@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("deterministic", [True]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) # @pytest.mark.parametrize("softcap", [0.0]) @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) @@ -74,7 +93,9 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("test_sink", [False, True]) +# @pytest.mark.parametrize("test_sink", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -102,10 +123,14 @@ ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be 
float16 or bfloat16 (not float8_e4m3fn)") + if test_sink and has_qv_: + pytest.skip("Sink disabled for Qv") device = "cuda" # set seed torch.random.manual_seed(0) @@ -113,14 +138,24 @@ def test_flash_attn_output( # nheads = 16 batch_size = 9 if seqlen_k <= 2048 else 2 # batch_size = 1 - nheads = 6 + nheads = 16 # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + if d == 192 and not DISABLE_HDIMDIFF192: + dv_vals = [128, d] + elif d == 64 and not DISABLE_HDIMDIFF64 and dtype != torch.float8_e4m3fn: + dv_vals = [256, 512, d] + else: + dv_vals = [d] + s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None + # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None + # print("s_aux ", s_aux) + if test_sink: dv_vals = [d] for dv in dv_vals: + print("dv =", dv) + has_qv = has_qv_ and d == 64 and dv >= 256 q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. 
@@ -153,7 +188,8 @@ def test_flash_attn_output( qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - softcap=softcap + softcap=softcap, + s_aux=s_aux, ) out_pt, attn_pt = attention_ref( q_ref, @@ -169,6 +205,7 @@ def test_flash_attn_output( upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + s_aux=s_aux, ) # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float() @@ -199,8 +236,11 @@ def test_flash_attn_output( window_size=window_size, softcap=softcap, pack_gqa=pack_gqa, - num_splits=num_splits + num_splits=num_splits, + s_aux=s_aux, ) + print("Pack GQA =", pack_gqa) + print("Num splits =", num_splits) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") # if not causal: @@ -211,7 +251,7 @@ def test_flash_attn_output( # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not test_sink: g = torch.randn_like(out) do_o = ((g.float() * out.float()).sum(-1)).transpose(1, 2) # import flash_attn_3_cuda @@ -262,7 +302,7 @@ def test_flash_attn_output( # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor and not has_qv and not test_sink: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -277,8 +317,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", 
[torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("has_qv", [False] + ([True] if not DISABLE_HDIMDIFF64 else [])) +@pytest.mark.parametrize("has_qv_", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -297,6 +337,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) # @pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("test_sink", [False, True]) +# @pytest.mark.parametrize("test_sink", [True]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -323,8 +365,12 @@ def test_flash_attn_output( ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink ): + if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") + if test_sink and has_qv_: + pytest.skip("Sink disabled for Qv") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) @@ -336,10 +382,20 @@ def test_flash_attn_varlen_output( # nheads = 1 nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype - dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + if d == 192 and not DISABLE_HDIMDIFF192: + dv_vals = [128, d] + elif d == 64 and not DISABLE_HDIMDIFF64 and dtype != 
torch.float8_e4m3fn: + dv_vals = [256, 512, d] + else: + dv_vals = [d] + s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None + # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None + # print("s_aux", s_aux) + if test_sink: dv_vals = [d] for dv in dv_vals: + print("dv =", dv) + has_qv = has_qv_ and d == 64 and dv >= 256 q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -416,7 +472,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): qv=qv_ref, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, window_size=window_size, - softcap=softcap + softcap=softcap, + s_aux=s_aux, ) out_pt, attn_pt = attention_ref( q_ref, @@ -432,6 +489,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): upcast=False, reorder_ops=True, intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + s_aux=s_aux, ) @@ -446,7 +504,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): out_unpad, lse = flash_attn_varlen_func( q_unpad, @@ -464,7 +522,10 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): k_descale=k_descale, v_descale=v_descale, window_size=window_size, softcap=softcap, + s_aux=s_aux, ) + print("Pack GQA =",pack_gqa) + print("Num splits =",num_splits) out = output_pad_fn(out_unpad) if query_unused_mask is not None: out.masked_fill_(q_zero_masking, 0.0) @@ -479,7 +540,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): assert (out - 
out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not test_sink: g_unpad = torch.randn_like(out_unpad) do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) # import flash_attn_3_cuda @@ -547,7 +608,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") # breakpoint() - if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv and not test_sink: dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4) @@ -563,32 +624,35 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(False, True)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) # 
@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) -# @pytest.mark.parametrize("page_size", [None]) +# @pytest.mark.parametrize("page_size", [4]) @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) # @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) @pytest.mark.parametrize("d", [64]) # @pytest.mark.parametrize("d", [192]) +# @pytest.mark.parametrize("test_sink", [False, True]) +@pytest.mark.parametrize("test_sink", [False]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -625,6 +689,7 @@ def test_flash_attn_kvcache( new_kv, mha_type, dtype, + test_sink, ): if page_size is not None and seqlen_k % page_size != 0: pytest.skip() @@ -635,6 +700,8 @@ def test_flash_attn_kvcache( if rotary_fraction == 0.0 and has_rotary_seqlens: pytest.skip() device = "cuda" + print("causal: ", causal) + print("local: ", local) # set seed torch.random.manual_seed(0) batch_size = 5 @@ -648,9 +715,12 @@ def test_flash_attn_kvcache( assert nheads % 
nheads_k == 0 dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) - if dtype == torch.float8_e4m3fn: + # dv_vals = [d] + s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None + if dtype == torch.float8_e4m3fn and d != 192: dv_vals = [d] for dv in dv_vals: + print("dv =", dv) has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -799,6 +869,7 @@ def test_flash_attn_kvcache( qv=qv, window_size=window_size, key_leftpad=cache_leftpad, + s_aux=s_aux, ) out_pt, _ = attention_ref( q_ro, @@ -812,7 +883,8 @@ def test_flash_attn_kvcache( upcast=False, reorder_ops=True, key_leftpad=cache_leftpad, - intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + s_aux=s_aux, ) q = q.to(dtype) q_unpad = q_unpad.to(dtype) if varlen_q else None @@ -830,12 +902,17 @@ def test_flash_attn_kvcache( sin = sin.to(dtype) if sin is not None else None k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] + # precompute_metadata_vals = [False] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + print("Num splits = ",num_splits) + print("Precompute metadata = ",precompute_metadata) + # print("max seqlen_q, seqlen_q ", max_seqlen_q, seqlen_q) if precompute_metadata: + # WARNING: seqlen_k is not max_seqlen_k if using page table, so we can't expect this to make sense? 
scheduler_metadata = get_scheduler_metadata( - batch_size, seqlen_q, seqlen_k, nheads, nheads_k, d, + batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, @@ -868,13 +945,15 @@ def test_flash_attn_kvcache( cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, max_seqlen_q=max_seqlen_q, + # max_seqlen_k=max_seqlen_k, rotary_seqlens=rotary_seqlens, causal=causal, window_size=window_size, rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, - return_softmax_lse=True + return_softmax_lse=True, + s_aux=s_aux, ) if varlen_q: out = output_pad_fn(out) diff --git a/hopper/test_util.py b/hopper/test_util.py index 8c10e2d5dba..6709b79248b 100644 --- a/hopper/test_util.py +++ b/hopper/test_util.py @@ -209,12 +209,13 @@ def attention_ref( upcast=True, reorder_ops=False, intermediate_dtype=None, + s_aux=None ): """ Arguments: q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads, head_dim) - v: (batch_size, seqlen_k, nheads, head_dim_v) + k: (batch_size, seqlen_k, nheads_kv, head_dim) + v: (batch_size, seqlen_k, nheads_kv, head_dim_v) qv: (batch_size, seqlen_q, nheads, head_dim_v) query_padding_mask: (batch_size, seqlen_q) key_padding_mask: (batch_size, seqlen_k) @@ -227,16 +228,21 @@ def attention_ref( reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) without changing the math. This is to estimate the numerical error from operation reordering. 
+ s_aux: (nheads) Output: output: (batch_size, seqlen_q, nheads, head_dim_v) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout """ + batch_size = q.shape[0] + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + nheads = q.shape[2] if causal: window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: q, k, v = q.float(), k.float(), v.float() qv = qv.float() if qv is not None else None + s_aux = s_aux.float() if s_aux is not None else None if q_descale is not None: q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) q = (q.float() * q_descale).to(q.dtype) @@ -245,7 +251,6 @@ def attention_ref( k = (k.float() * rearrange(k_descale, "b h -> b 1 h 1")).to(dtype=k.dtype) if v_descale is not None: v = (v.float() * rearrange(v_descale, "b h -> b 1 h 1")).to(dtype=v.dtype) - seqlen_q, seqlen_k = q.shape[1], k.shape[1] k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) d = q.shape[-1] @@ -275,7 +280,14 @@ def attention_ref( scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias + if s_aux is not None: + # concatenate sink column before softmax + s_aux = s_aux.reshape(1, nheads, 1, 1).expand(batch_size, -1, seqlen_q, -1) + scores = torch.cat([scores, s_aux], dim=-1) attention = torch.softmax(scores, dim=-1).to(v.dtype) + if s_aux is not None: + # remove sink column + attention = attention[..., :-1] # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: diff --git a/hopper/tile_size.h b/hopper/tile_size.h index b87a83afff8..d63999c6384 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -20,9 +20,18 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim_v == 256) { return {128, 112, true, false}; } else { - // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || 
is_local; - return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; + if (use_one_mma_wg) { + return {64, 192, true, true}; + } else { + // Switch to tile size 192 x 192 for now + // bool const use_blockN_128 = is_causal || is_local; + // return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; // BASE + // Benefits SWA when window length <= 128 + return {192, is_causal ? 128 : is_local || paged_kv_non_TMA ? 160 : 192, is_causal || is_local, !is_local}; + // return {192, is_causal ? 128 : 160, true, !is_local}; + // return {128, use_blockN_128 ? 160 : 192, use_blockN_128, !use_blockN_128}; + // return {192, is_local ? 160 : 192, true, false}; + } } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen // return {192, is_causal || is_local ? 192 : 176, true, false}; diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index 5012cbc61bc..06de7fd17b7 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -416,63 +416,34 @@ def flash_attn_with_kvcache( cache_batch_idx = maybe_contiguous(cache_batch_idx) block_table = maybe_contiguous(block_table) - if fa_version == 2: - if scheduler_metadata is not None and q_descale is not None \ - and k_descale is not None and v_descale is not None: - raise NotImplementedError( - "FA2 does not support scheduler_metadata, q_descale, " - "k_descale, v_descale" - ) - if s_aux is not None: - raise NotImplementedError("FA2 does not support s_aux") - out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache( - q, k_cache, v_cache, - k, v, # k_new, v_new - cache_seqlens, - rotary_cos, - rotary_sin, - cache_batch_idx, - cache_leftpad, - block_table, - alibi_slopes, - out, - softmax_scale, - causal, - window_size[0], - window_size[1], - softcap, - rotary_interleaved, - num_splits, - ) - elif fa_version == 3: - assert alibi_slopes is None, "Alibi is not supported in FA3" - out, softmax_lse, _, _ = 
torch.ops._vllm_fa3_C.fwd( - q, k_cache, v_cache, # q, k, v - k, v, # k_new, v_new - None, # q_v - out, - None, None, # cu_seqlens_q, cu_seqlens_k - None, # cu_seqlens_k_new - None, cache_seqlens, # seqused_q, seqused_k - None, None, # max_seqlen_q, max_seqlen_k - block_table, - cache_batch_idx, # kv_batch_idx - None, # leftpad_k - None, None, None, # rotary_cos, rotary_sin, seqlens_rotary - q_descale, k_descale, v_descale, - softmax_scale, - causal, - window_size[0], window_size[1], - softcap, - rotary_interleaved, # rotary_interleaved - scheduler_metadata, - num_splits, # num_splits - None, # pack_gqa - 0, # sm_margin - s_aux, # s_aux - ) - else: - raise ValueError(f"Unsupported FA version: {fa_version}") + if s_aux is not None: + raise NotImplementedError("FA2 does not support s_aux") + if scheduler_metadata is not None and q_descale is not None \ + and k_descale is not None and v_descale is not None: + raise NotImplementedError( + "FA2 does not support scheduler_metadata, q_descale, " + "k_descale, v_descale" + ) + + out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache( + q, k_cache, v_cache, + k, v, # k_new, v_new + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + out, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + rotary_interleaved, + num_splits, + ) return (out, softmax_lse) if return_softmax_lse else out