diff --git a/hopper/flash.h b/hopper/flash.h index bee89e5f054..6848e8c9dbd 100644 --- a/hopper/flash.h +++ b/hopper/flash.h @@ -152,10 +152,16 @@ struct Flash_fwd_params : public Qkv_params { bool pack_gqa; int * __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; + int * __restrict__ num_m_blocks_ptr; // int * __restrict__ num_n_blocks_ptr; int * __restrict__ num_splits_dynamic_ptr; + int * __restrict__ varlen_batch_idx_ptr; // virtual -> actual + int * __restrict__ num_nheads_in_l2_ptr; bool skip_scheduler_metadata_computation; + bool varlen_sort_batches; + int tile_count_semaphore_offset; + bool head_swizzle; + bool prepare_varlen_pdl; int arch; int num_sm; diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp index 33185bf2304..8ffd0d0baf9 100644 --- a/hopper/flash_api.cpp +++ b/hopper/flash_api.cpp @@ -39,6 +39,8 @@ PyObject* PyInit__C(void) #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define PREPARE_VARLEN_MAX_BATCHES_1CTA 992 + void set_params_fprop(Flash_fwd_params ¶ms, // sizes const size_t b, @@ -250,6 +252,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { if (params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -257,6 +260,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -268,11 +272,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -283,6 +289,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 if (params.d <= 64) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF64 if constexpr (Arch == 90) { if (params.dv > 256) { return run_mha_fwd_(params, stream); @@ -290,6 +297,7 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -301,11 +309,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_(params, stream); } } + #endif return run_mha_fwd_(params, stream); } #endif @@ -329,11 +339,13 @@ void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) { #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 if (params.d <= 192) { + #ifndef FLASHATTENTION_DISABLE_HDIMDIFF192 if constexpr (Arch == 90) { if (params.dv <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } } + #endif return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); } #endif @@ -525,8 +537,7 @@ mha_fwd_get_scheduler_metadata( bool has_softcap, int64_t num_splits, std::optional pack_gqa_, - int64_t sm_margin - ) { + int64_t sm_margin) { TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn, "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type"); @@ -585,8 +596,9 @@ mha_fwd_get_scheduler_metadata( params.page_size = page_size.has_value() ? page_size.value() : 1; params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast(1); - bool const use_dynamic_split = params.b <= 992; - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + bool const use_prepare_varlen = true; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -603,18 +615,35 @@ mha_fwd_get_scheduler_metadata( // This needs to be set after get_num_splits at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1; - if (scheduler_needs_semaphore || use_dynamic_split) { - tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32)); + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + // printf("(Metadata) num prepare batch vectors = %d.\n", num_prepare_batch_vectors); + tile_count_semaphore = torch::empty( + {int(scheduler_needs_semaphore) + tile_count_semaphore_offset}, + opts.dtype(torch::kInt32)); + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; if (scheduler_needs_semaphore) { - if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing - params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + if (!use_prepare_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing + params.tile_count_semaphore = tile_count_semaphore.data_ptr() + tile_count_semaphore_offset; } else { params.tile_count_semaphore = nullptr; } - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; } - if (params.num_splits_dynamic_ptr) { + if (use_prepare_varlen) { auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f); auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr); int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x); @@ -938,11 +967,11 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql params.cu_seqlens_knew = static_cast(cu_seqlens_k_new.data_ptr()); } } - - // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel - bool const use_dynamic_split = is_varlen && params.b <= 992; + + bool const use_prepare_varlen = is_varlen; + params.prepare_varlen_pdl = use_prepare_varlen && params.b <= PREPARE_VARLEN_MAX_BATCHES_1CTA; // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it - params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast(1); + params.num_splits_dynamic_ptr = !use_prepare_varlen ? nullptr : reinterpret_cast(1); params.pagedkv_tma = get_pagedkv_tma(params); params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits; @@ -955,8 +984,17 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql bool const scheduler_needs_semaphore = params.arch >= 90 ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1)); - if (scheduler_needs_semaphore || use_dynamic_split) { - int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b; + params.varlen_sort_batches = !params.is_local; // Use this value for Sort in scheduler template + params.head_swizzle = params.is_causal || params.is_local; // Use this value for LPT in scheduler template + if (scheduler_needs_semaphore || use_prepare_varlen) { + int b_rounded = round_multiple(params.b, 4); // for 16 byte alignment of pointers + int num_prepare_batch_vectors = use_prepare_varlen ? 2 : 0; + if(params.varlen_sort_batches) { num_prepare_batch_vectors += 1; } + if(params.head_swizzle) { num_prepare_batch_vectors += 1; } + int head_swizzle_offset = b_rounded * (params.varlen_sort_batches ? 3 : 2); + int tile_count_semaphore_offset = b_rounded * num_prepare_batch_vectors; + int metadata_size = int(scheduler_needs_semaphore) + tile_count_semaphore_offset; + // printf("Num prepare batch vectors = %d, metadata_size = %d.\n", num_prepare_batch_vectors, metadata_size); params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value(); if (scheduler_metadata_.has_value()) { at::Tensor scheduler_metadata = scheduler_metadata_.value(); @@ -968,15 +1006,22 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql } else { tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32)); } - if (scheduler_needs_semaphore && !use_dynamic_split) { + if (scheduler_needs_semaphore && !use_prepare_varlen) { tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing } - params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() : nullptr; - params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr() + 1 : nullptr; + // {num_splits_dynamic, num_m_blocks, varlen_batch_idx, num_nheads_in_l2} + params.num_splits_dynamic_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() : nullptr; + params.num_m_blocks_ptr = use_prepare_varlen ? tile_count_semaphore.data_ptr() + b_rounded : nullptr; + params.varlen_batch_idx_ptr = use_prepare_varlen && params.varlen_sort_batches ? tile_count_semaphore.data_ptr() + b_rounded * 2 : nullptr; + // params.num_n_blocks_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.num_nheads_in_l2_ptr = use_prepare_varlen && params.head_swizzle ? tile_count_semaphore.data_ptr() + head_swizzle_offset : nullptr; + params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr() + tile_count_semaphore_offset : nullptr; + params.tile_count_semaphore_offset = tile_count_semaphore_offset; // might need to zero out semaphore later } if (q_v_.has_value()) { TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK(head_size_v >= 256, "q_v is only supported for hdim_v >= 256."); TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, "q_v is only supported for fp16 and bf16 data type"); TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); @@ -1134,7 +1179,7 @@ mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seql run_mha_fwd_combine(params, stream, true /*enable_pdl*/); } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) { // need to zero out the semaphore in this case - tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_(); + tile_count_semaphore.index({torch::indexing::Slice(params.tile_count_semaphore_offset, params.tile_count_semaphore_offset + 1)}).zero_(); } } else if (total_q > 0 && num_heads_k > 0) { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py index 5547f426da5..a2eb9594896 100644 --- a/hopper/flash_attn_interface.py +++ b/hopper/flash_attn_interface.py @@ -50,7 +50,8 @@ def _flash_attn_forward( scheduler_metadata=None, num_splits=1, pack_gqa=None, - sm_margin=0): + sm_margin=0, + ): q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)] v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [ diff --git a/hopper/flash_fwd_combine_kernel.h b/hopper/flash_fwd_combine_kernel.h index a22e05969d9..05667698006 100644 --- a/hopper/flash_fwd_combine_kernel.h +++ b/hopper/flash_fwd_combine_kernel.h @@ -145,6 +145,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -164,6 +165,7 @@ class FlashAttnFwdCombine { int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; int* const semaphore_to_reset = nullptr; }; @@ -187,7 +189,9 @@ class FlashAttnFwdCombine { args.cu_seqlens, args.seqused, args.num_splits_dynamic_ptr, - args.semaphore_to_reset + args.varlen_batch_idx_ptr, + args.semaphore_to_reset, + }; } @@ -203,8 +207,9 @@ class FlashAttnFwdCombine { int const thread_idx = threadIdx.x; int const m_block = blockIdx.x; int const k_block = blockIdx.y; - int const batch = blockIdx.z; - int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial); + int const maybe_virtual_batch = blockIdx.z; + int const batch = params.varlen_batch_idx_ptr ? params.varlen_batch_idx_ptr[maybe_virtual_batch] : maybe_virtual_batch; + int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[maybe_virtual_batch] : get<1>(params.shape_LSE_partial); if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) { cutlass::arch::wait_on_dependent_grids(); diff --git a/hopper/flash_fwd_combine_launch_template.h b/hopper/flash_fwd_combine_launch_template.h index 11d422924b4..a2ff25dcd5f 100644 --- a/hopper/flash_fwd_combine_launch_template.h +++ b/hopper/flash_fwd_combine_launch_template.h @@ -35,7 +35,7 @@ void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool e {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O static_cast(params.softmax_lse_ptr), {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE - params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore + params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.varlen_batch_idx_ptr, params.tile_count_semaphore }; typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args); diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h index b8af2977f11..d48a4fd9562 100644 --- a/hopper/flash_fwd_launch_template.h +++ b/hopper/flash_fwd_launch_template.h @@ -57,8 +57,10 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { using CollectiveEpilogue = flash::CollectiveEpilogueFwd; static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads; + static constexpr bool LPT = Is_causal || Is_local; + static constexpr bool Sort = !Is_local; using SchedulerPersistent = std::conditional_t= 90 /*WarpSpecialized*/>, + flash::VarlenDynamicPersistentTileScheduler= 90 /*WarpSpecialized*/, LPT, Sort, true /*Prepared*/>, std::conditional_t, flash::DynamicPersistentTileScheduler= 90 /*WarpSpecialized*/> @@ -149,14 +151,16 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits, params.h / params.h_k, params.seqlen_q, - params.seqlen_k, params.d, params.dv, sizeof(Element), + params.seqlen_k, params.d, params.dv, sizeof(Element), params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q, - // params.num_m_blocks_ptr, params.num_splits_dynamic_ptr, + params.num_m_blocks_ptr, + params.varlen_batch_idx_ptr, + params.num_nheads_in_l2_ptr }; - if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) { - prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/); + if (Varlen && !params.skip_scheduler_metadata_computation) { + prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 && params.prepare_varlen_pdl /*enable_pdl*/); CHECK_CUDA_KERNEL_LAUNCH(); } @@ -189,7 +193,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } // kernel<<>>(kernel_params); cutlass::kernel_launch(grid_dims, block_dims, smem_size, stream, kernel_params, - Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/); + Arch >= 90 && Varlen && !params.skip_scheduler_metadata_computation && params.prepare_varlen_pdl /*launch_with_pdl*/); } CHECK_CUDA_KERNEL_LAUNCH(); } @@ -205,7 +209,6 @@ void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] { // Only needed here to decide if we should use cluster static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128; - static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen; BOOL_SWITCH(params.qv_ptr, HasQV_, [&] { static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256; diff --git a/hopper/flash_prepare_scheduler.cu b/hopper/flash_prepare_scheduler.cu index 7093fff32b6..1d810c015ed 100644 --- a/hopper/flash_prepare_scheduler.cu +++ b/hopper/flash_prepare_scheduler.cu @@ -2,6 +2,7 @@ * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. ******************************************************************************/ +#include #include "cutlass/fast_math.h" #include "cutlass/barrier.h" #include "cutlass/arch/barrier.h" @@ -10,8 +11,35 @@ #include "flash.h" +#include "static_switch.h" + namespace flash { +// Sort in descending order +template +struct PrepareSortOp +{ + __device__ __forceinline__ bool operator()(T const & lhs, T const & rhs) + { + return lhs > rhs; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int2 const & lhs, int2 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template <> +struct PrepareSortOp { + __device__ __forceinline__ bool operator()(int4 const & lhs, int4 const & rhs) const { + return lhs.x > rhs.x; + } +}; + +template __global__ void prepare_varlen_num_blocks_kernel( int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static, int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new, @@ -19,16 +47,28 @@ __global__ void prepare_varlen_num_blocks_kernel( int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static, cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod, int* const tile_count_semaphore, - // int* const num_m_blocks_ptr, + int* const num_m_blocks_ptr, int* const num_splits_dynamic_ptr, - bool enable_pdl) { + int* const varlen_batch_idx_ptr, + // int* const num_n_blocks_ptr, + int* const num_nheads_in_l2_ptr, + bool enable_pdl, + bool is_causal, + bool packgqa, + int max_kvblocks_in_l2) { static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1; static constexpr int kSmemSize = 1; - // Assume that there's only one block in the grid + static constexpr int BLOCK_DIM_X = NumWarps * 32; + static constexpr int ITEMS_PER_THREAD = 1; + static_assert(BLOCK_DIM_X * ITEMS_PER_THREAD == NumWarps * 32); + using BlockMergeSort = cub::BlockMergeSort; + __shared__ int total_blocks_smem[kSmemSize]; - // There's only 1 block in the grid, so might as well start launching the main attn kernel + // Allocate shared memory for BlockMergeSort operations + __shared__ typename BlockMergeSort::TempStorage temp_storage; + if (enable_pdl) { cutlass::arch::launch_dependent_grids(); } if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; } @@ -38,8 +78,7 @@ __global__ void prepare_varlen_num_blocks_kernel( int lane = threadIdx.x % cutlass::NumThreadsPerWarp; - auto get_num_m_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_m_blocks = [&](int batch_idx) { int seqlen; if (seqused_q) { seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0; @@ -50,13 +89,12 @@ __global__ void prepare_varlen_num_blocks_kernel( } else { seqlen = seqlen_q_static; } - seqlen *= qhead_per_khead; + if(packgqa) { seqlen *= qhead_per_khead; } return batch_idx < num_batch && lane < kNumBatchPerWarp ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0; }; - auto get_num_n_blocks = [&](int bidb_start) { - int batch_idx = lane + bidb_start; + auto get_num_n_blocks = [&](int batch_idx) { int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0; int seqlen; if (seqused_k) { @@ -83,42 +121,130 @@ __global__ void prepare_varlen_num_blocks_kernel( }; int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp; - int bidb_start = kNumBatchPerWarp * warp_idx; - int num_m_blocks = get_num_m_blocks(bidb_start); - int num_n_blocks = get_num_n_blocks(bidb_start); - - int total_blocks = num_m_blocks * num_n_blocks; - // Warp sum - #pragma unroll - for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { - total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + int batch_cta_idx_offset = int(blockIdx.x) * 992; + int bidb_start = batch_cta_idx_offset + kNumBatchPerWarp * warp_idx; + int batch_idx = lane + bidb_start; + int num_m_blocks = get_num_m_blocks(batch_idx); + int num_n_blocks = get_num_n_blocks(batch_idx); + + auto get_nheads_in_l2 = [&](int n_blocks) { + int nheads_in_l2 = n_blocks * 16 <= max_kvblocks_in_l2 ? 16 + : n_blocks * 8 <= max_kvblocks_in_l2 ? 8 + : n_blocks * 4 <= max_kvblocks_in_l2 ? 4 + : n_blocks * 2 <= max_kvblocks_in_l2 ? 2 + : 1; + if(!packgqa) { nheads_in_l2 *= qhead_per_khead; } + return min(nheads_in_l2, num_head); + }; + + int num_splits_dynamic; + if (int(gridDim.x) > 1 || num_splits_static == 1) { + // set num splits for all batches to 1 (note that user expects num_splits_static to mean upper bound on splits) + // for batch size > 992, we expect GPU occupancy to not be an issue except in degenerate cases (e.g., most are zero-length) + num_splits_dynamic = 1; + } else { + int total_blocks = num_m_blocks * num_n_blocks; + // Warp sum + #pragma unroll + for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) { + total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i); + } + if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } + __syncthreads(); + total_blocks = total_blocks_smem[0]; + // 10% margin + int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); + // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM + num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); + // num_n_blocks per work tile for the batch + num_n_blocks = cutlass::ceil_div(num_n_blocks, num_splits_dynamic); } - if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); } - __syncthreads(); - total_blocks = total_blocks_smem[0]; - // 10% margin - int blocks_per_sm = static_cast(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm))); - // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM - int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1); - if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) { - num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic; - // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + + if constexpr (Sort) { + if(lane == kNumBatchPerWarp || batch_idx >= num_batch) { + num_n_blocks = INT_MIN; // sort last + } else if (is_causal) { + // sort by shortest member to process + num_n_blocks = num_n_blocks * blockn_divmod.divisor - num_m_blocks * blockm_divmod.divisor; + } + int4 batch_coords[ITEMS_PER_THREAD]; // 1 item per thread + batch_coords[0] = make_int4(num_n_blocks, num_m_blocks, num_splits_dynamic, batch_idx); + + // if (threadIdx.x == 0) { + // printf("Unsorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + // Sort batches by num_n_blocks in descending order + BlockMergeSort(temp_storage).Sort(batch_coords, PrepareSortOp()); + + // if (threadIdx.x == 0) { + // printf("Sorted: num_n_blocks - num_m_blocks = %d, num_m_blocks = %d, num_splits = %d, batch_idx = %d.\n", + // batch_coords[0].x, batch_coords[0].y, batch_coords[0].z, batch_coords[0].w); + // } __syncthreads(); + + if (is_causal) { + // reset value to num_n_blocks + batch_coords[0].x = blockn_divmod.div(batch_coords[0].x + batch_coords[0].y * blockm_divmod.divisor); + } + + // When sorting, we re-index some metadata by 'virtual batch index' + // and also store the vbidx -> bidx mapping. + // 1. num_nheads_in_l2_ptr: virtual_batch_idx -> num_nheads_in_l2[batch_idx] + // 2. num_splits_dynamic_ptr: virtual_batch_idx -> num_splits[batch_idx] + // 3. num_m_blocks_ptr: virtual_batch_idx -> num_m_blocks[batch_idx] + // 4. varlen_batch_idx_ptr: virtual_batch_idx -> batch_idx + batch_idx = batch_cta_idx_offset + threadIdx.x; + if (batch_idx < num_batch && threadIdx.x < 992) { + // num_n_blocks_ptr[threadIdx.x] = max(batch_coords[0].x, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(batch_coords[0].x, 1)); } + num_m_blocks_ptr[batch_idx] = batch_coords[0].y; + num_splits_dynamic_ptr[batch_idx] = batch_coords[0].z; + varlen_batch_idx_ptr[batch_idx] = batch_coords[0].w; + } + } else { + if (batch_idx < num_batch && lane < kNumBatchPerWarp) { + // num_n_blocks_ptr[batch_idx] = max(num_n_blocks, 1); + if(num_nheads_in_l2_ptr) { num_nheads_in_l2_ptr[batch_idx] = get_nheads_in_l2(max(num_n_blocks, 1)); } + num_splits_dynamic_ptr[batch_idx] = num_splits_dynamic; + num_m_blocks_ptr[batch_idx] = num_m_blocks; + // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic); + } } + } } // flash void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl) { - // Only support batch <= 992 (32 warps, each with 31 batches) - int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k); - flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>( - params.seqlen_q, params.seqlen_k, params.seqlen_knew, - params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, - params.seqused_q, params.seqused_k, params.leftpad_k, - params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, - cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), - params.tile_count_semaphore, - // params.num_m_blocks_ptr, - params.num_splits_dynamic_ptr, enable_pdl); + int qhead_per_khead = cutlass::ceil_div(params.h, params.h_k); + int num_warps = cutlass::ceil_div(params.b, 31); // warp switch will cap this at 32 + int num_ctas = cutlass::ceil_div(params.b, 31 * 32); + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + int const size_l2 = 8 * 1024 * 1024; // underestimate seems better in practice + int const element_size = params.is_e4m3 ? 1 : 2; + int const size_one_kvblock = blockN * (params.d + params.dv) * element_size; + // printf("block size = %d, element size = %d, headdim = %d, headdim_v = %d, size 1 kblock = %d.\n", blockN, element_size, params.d, params.dv, size_one_kvblock); + int const max_kvblocks_in_l2 = size_l2 / size_one_kvblock; + BOOL_SWITCH(params.varlen_sort_batches, Sort, [&] { + NUM_WARP_SWITCH(num_warps, NumWarps, [&] { + flash::prepare_varlen_num_blocks_kernel<<>>( + params.seqlen_q, params.seqlen_k, params.seqlen_knew, + params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew, + params.seqused_q, params.seqused_k, params.leftpad_k, + params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits, + cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN), + params.tile_count_semaphore, + params.num_m_blocks_ptr, + params.num_splits_dynamic_ptr, + params.varlen_batch_idx_ptr, + // params.num_n_blocks_ptr, + params.num_nheads_in_l2_ptr, + enable_pdl, + params.is_causal, + packgqa, + max_kvblocks_in_l2); + }); + }); } diff --git a/hopper/setup.py b/hopper/setup.py index c15c438f56c..850fb0b520c 100644 --- a/hopper/setup.py +++ b/hopper/setup.py @@ -64,6 +64,8 @@ ENABLE_VCOLMAJOR = os.getenv("FLASH_ATTENTION_ENABLE_VCOLMAJOR", "FALSE") == "TRUE" +DISABLE_HDIMDIFF64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF64", "FALSE") == "TRUE" +DISABLE_HDIMDIFF192 = os.getenv("FLASH_ATTENTION_DISABLE_HDIMDIFF192", "FALSE") == "TRUE" # HACK: we monkey patch pytorch's _write_ninja_file to pass # "-gencode arch=compute_sm90a,code=sm_90a" to files ending in '_sm90.cu', @@ -468,10 +470,13 @@ def nvcc_threads_args(): + (["-DFLASHATTENTION_DISABLE_HDIM256"] if DISABLE_HDIM256 else []) + (["-DFLASHATTENTION_DISABLE_SM8x"] if DISABLE_SM8x else []) + (["-DFLASHATTENTION_ENABLE_VCOLMAJOR"] if ENABLE_VCOLMAJOR else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF64"] if DISABLE_HDIMDIFF64 else []) + + (["-DFLASHATTENTION_DISABLE_HDIMDIFF192"] if DISABLE_HDIMDIFF192 else []) ) DTYPE_FWD_SM80 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) + (["e4m3"] if not DISABLE_FP8 else []) + HALF_DTYPE_FWD_SM90 = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) DTYPE_BWD = ["bf16"] + (["fp16"] if not DISABLE_FP16 else []) HEAD_DIMENSIONS_BWD = ( [] @@ -481,7 +486,18 @@ def nvcc_threads_args(): + ([192] if not DISABLE_HDIM192 else []) + ([256] if not DISABLE_HDIM256 else []) ) - HEAD_DIMENSIONS_FWD = ["all", "diff"] + # build will now explode with this compilation grouping given all our templating + # HEAD_DIMENSIONS_FWD = ["all", "diff"] + HEAD_DIMENSIONS_FWD = HEAD_DIMENSIONS_BWD + HEAD_DIMENSIONS_DIFF64_FWD = ( + [] + + (["64_256"] if not DISABLE_HDIMDIFF64 else []) + + (["64_512"] if not DISABLE_HDIMDIFF64 else []) + ) + HEAD_DIMENSIONS_DIFF192_FWD = ( + [] + + (["192_128"] if not DISABLE_HDIMDIFF192 else []) + ) HEAD_DIMENSIONS_FWD_SM80 = HEAD_DIMENSIONS_BWD SPLIT = [""] + (["_split"] if not DISABLE_SPLIT else []) PAGEDKV = [""] + (["_paged"] if not DISABLE_PAGEDKV else []) @@ -495,6 +511,14 @@ def nvcc_threads_args(): sources_fwd_sm90 = [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF64: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF64_FWD, HALF_DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] + if not DISABLE_HDIMDIFF192: + sources_fwd_sm90 += [f"instantiations/flash_fwd_hdim{hdim}_{dtype}{paged}{split}{softcap}{packgqa}_sm90.cu" + for hdim, dtype, split, paged, softcap, packgqa in itertools.product(HEAD_DIMENSIONS_DIFF192_FWD, DTYPE_FWD_SM90, SPLIT, PAGEDKV, SOFTCAP, PACKGQA) + if not (packgqa and (paged or split))] sources_bwd_sm80 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm80.cu" for hdim, dtype, softcap in itertools.product(HEAD_DIMENSIONS_BWD, DTYPE_BWD, SOFTCAP)] sources_bwd_sm90 = [f"instantiations/flash_bwd_hdim{hdim}_{dtype}{softcap}_sm90.cu" diff --git a/hopper/static_switch.h b/hopper/static_switch.h index 5e13b5f93a8..15a7d51364b 100644 --- a/hopper/static_switch.h +++ b/hopper/static_switch.h @@ -179,3 +179,26 @@ return __VA_ARGS__(); \ } \ }() + +#define NUM_WARP_SWITCH(VALUE, CONST_NAME, ...) \ + [&] { \ + if (VALUE <= 1) { \ + constexpr static int CONST_NAME = 1; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 2) { \ + constexpr static int CONST_NAME = 2; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 4) { \ + constexpr static int CONST_NAME = 4; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 8) { \ + constexpr static int CONST_NAME = 8; \ + return __VA_ARGS__(); \ + } else if (VALUE <= 16) { \ + constexpr static int CONST_NAME = 16; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int CONST_NAME = 32; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py index f1247e689da..0b5a0e2af98 100644 --- a/hopper/test_flash_attn.py +++ b/hopper/test_flash_attn.py @@ -55,8 +55,8 @@ # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [True]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -75,7 +75,7 @@ # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128, 192]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -107,6 +107,8 @@ def test_flash_attn_output( ): if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn): pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn") + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(0) @@ -121,8 +123,11 @@ def test_flash_attn_output( dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -193,6 +198,7 @@ def test_flash_attn_output( pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out = flash_attn_func( q, k, @@ -286,8 +292,8 @@ def test_flash_attn_output( # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) -# @pytest.mark.parametrize("has_qv", [False, True]) -@pytest.mark.parametrize("has_qv", [False]) +@pytest.mark.parametrize("has_qv", [False, True]) +# @pytest.mark.parametrize("has_qv", [False]) # @pytest.mark.parametrize("deterministic", [False, True]) @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) @@ -295,7 +301,7 @@ def test_flash_attn_output( @pytest.mark.parametrize("local", [False] + ([True] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) -# @pytest.mark.parametrize("causal", [False]) +# @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("add_unused_qkv", [False, True]) # @pytest.mark.parametrize("add_unused_qkv", [True]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) @@ -305,7 +311,7 @@ def test_flash_attn_output( # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) # @pytest.mark.parametrize("d", [64, 96, 128]) @pytest.mark.parametrize("d", COMPILED_HDIMS) -# @pytest.mark.parametrize("d", [128]) +# @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize( "seqlen_q,seqlen_k", [ @@ -328,28 +334,38 @@ def test_flash_attn_output( (1024, 1024), (1023, 1024), (1024, 1023), + (1024, 1024), (2048, 2048), + (4096, 4096), ], ) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype + seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv, mha_type, dtype, ): + if has_qv and (d != 64 or dtype == torch.float8_e4m3fn): + pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)") device = "cuda" # set seed torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) # batch_size = 40 # nheads = 16 batch_size = 9 if seqlen_q <= 2048 else 2 + # batch_size = 32 nheads = 6 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) # batch_size = 2 # nheads = 1 - nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # nheads_kv = nheads + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) if dtype == torch.float8_e4m3fn: dv_vals = [d] + if has_qv: + dv_vals = [256, 512] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if seqlen_q <= seqlen_k and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") q_ref = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref) if softcap > 0.0: # Ensure the values of qk are at least within softcap range. @@ -458,8 +474,15 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): rtol = 2 if softcap == 0.0 else 3 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] - num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + # pack_gqa_vals = [False] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] + # num_splits_vals = [1] + # print("cu_seqlens_q: ", cu_seqlens_q) + # print("cu_seqlens_k: ", cu_seqlens_k) + # print("seqused_q: ", seqused_q) + # print("seqused_k: ", seqused_k) for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + print(f"{pack_gqa = }, {num_splits = }") out_unpad = flash_attn_varlen_func( q_unpad, k_unpad, @@ -477,6 +500,8 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): window_size=window_size, attention_chunk=attention_chunk, softcap=softcap, + pack_gqa=pack_gqa, + num_splits=num_splits, ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -580,16 +605,16 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) # @pytest.mark.parametrize("mha_type", ["mha"]) @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else [])) -# @pytest.mark.parametrize("new_kv", [True]) +# @pytest.mark.parametrize("new_kv", [False]) @pytest.mark.parametrize("causal,local", [(False, False), (True, False)] + ([(False, True)] if not DISABLE_LOCAL else [])) # @pytest.mark.parametrize("causal,local", [(False, False), (True, False)]) -# @pytest.mark.parametrize("causal,local", [(False, False)]) +# @pytest.mark.parametrize("causal,local", [(True, False)]) @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False] if not DISABLE_APPENDKV else [True]) -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) -@pytest.mark.parametrize("has_rotary_seqlens", [False, True]) -# @pytest.mark.parametrize("has_rotary_seqlens", [False]) +# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False]) +# @pytest.mark.parametrize("has_rotary_seqlens", [False, True]) +@pytest.mark.parametrize("has_rotary_seqlens", [False]) @pytest.mark.parametrize("rotary_interleaved", [False, True] if not DISABLE_APPENDKV else [False]) -# @pytest.mark.parametrize("rotary_interleaved", [True]) +# @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0] if (not DISABLE_APPENDKV) and (apply_rotary_emb is not None) else [0.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128] if not DISABLE_PAGEDKV else [])) @@ -597,9 +622,9 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): @pytest.mark.parametrize("has_leftpad", [False, True]) # @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("has_batch_idx", [True]) @pytest.mark.parametrize("varlen_q", [False, True]) -# @pytest.mark.parametrize("varlen_q", [False]) +# @pytest.mark.parametrize("varlen_q", [True]) # @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -669,6 +694,7 @@ def test_flash_attn_kvcache( dv_vals = [d] attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) and not DISABLE_LOCAL else [0] for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals): + print(f"{dv = }, {attention_chunk = }") has_qv = d == 64 and dv >= 256 q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref).to(dtype).to(dtype_ref) if has_qv: @@ -850,17 +876,21 @@ def test_flash_attn_kvcache( sin = sin.to(dtype) if sin is not None else None k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone() v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone() - num_splits_vals = [1, 0] if not DISABLE_SPLIT else [1] + num_splits_vals = [1, 3, 0] if not DISABLE_SPLIT else [1] precompute_metadata_vals = [False, True] for num_splits, precompute_metadata in itertools.product(num_splits_vals, precompute_metadata_vals): + print(f"{num_splits = }, {precompute_metadata = }") if precompute_metadata: scheduler_metadata = get_scheduler_metadata( - batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d, + batch_size, + max_seqlen_q if varlen_q else seqlen_q, + seqlen_k if page_size is None else page_table.shape[1] * page_size, + nheads, nheads_k, d, cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad, max_seqlen_k_new=seqlen_new, page_size=page_size, causal=causal, window_size=window_size, attention_chunk=attention_chunk, - num_splits=num_splits + num_splits=num_splits, ) else: scheduler_metadata = None @@ -895,7 +925,7 @@ def test_flash_attn_kvcache( rotary_interleaved=rotary_interleaved, scheduler_metadata=scheduler_metadata, num_splits=num_splits, - return_softmax_lse=True + return_softmax_lse=True, ) if varlen_q: out = output_pad_fn(out) diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp index 1f90f66adc2..41e0bab1624 100644 --- a/hopper/tile_scheduler.hpp +++ b/hopper/tile_scheduler.hpp @@ -24,8 +24,11 @@ struct TileSchedulerArguments { int* const tile_count_semaphore = nullptr; int const* const cu_seqlens = nullptr; int const* const seqused = nullptr; - // int const* const num_m_blocks_ptr = nullptr; int const* const num_splits_dynamic_ptr = nullptr; + int const* const num_m_blocks_ptr = nullptr; + int const* const varlen_batch_idx_ptr = nullptr; + // int const* const num_n_blocks_ptr = nullptr; + int const* const num_nheads_in_l2_ptr = nullptr; }; /////////////////////////////////////////////////////////////////////////////// @@ -463,7 +466,8 @@ class SingleTileBwdLPTScheduler { /////////////////////////////////////////////////////////////////////////////// -template +template class VarlenDynamicPersistentTileScheduler { static_assert(WarpSpecialized || NumProducerThreads == NumMmaThreads); @@ -482,13 +486,17 @@ class VarlenDynamicPersistentTileScheduler { int num_head, num_batch; int const qhead_per_khead; int const seqlen; + // int const max_kvblocks_in_l2; cutlass::FastDivmod head_divmod; cutlass::FastDivmod nsplits_divmod; int* const tile_count_semaphore; int const* const cu_seqlens; int const* const seqused; - // int* const num_m_blocks_ptr; int const* const num_splits_dynamic_ptr; + int const* const num_m_blocks_ptr; + int const* const varlen_batch_idx_ptr; + // int const* const num_n_blocks_ptr; + int const* const num_nheads_in_l2_ptr; }; static Params @@ -498,13 +506,20 @@ class VarlenDynamicPersistentTileScheduler { assert(args.tile_count_semaphore != nullptr); assert(args.num_head < (1 << 16)); // We use the top 16 bits to store num_splits & split_idx assert(!Split || args.num_splits < (1 << 8)); // We use the top 8 bits to store num_splits + // int const size_l2 = 50 * 1024 * 1024; // 50 MB + // int const size_one_kvblock = kBlockN * (args.headdim + args.headdim_v) * args.element_size; + // int max_kvblocks_in_l2 = size_l2 / size_one_kvblock; return {args.num_head, args.num_batch, args.qhead_per_khead, args.seqlen, + // max_kvblocks_in_l2, cutlass::FastDivmod(args.num_head), cutlass::FastDivmod(!Split ? 1 : args.num_splits), args.tile_count_semaphore, args.cu_seqlens, args.seqused, - // args.num_m_blocks_ptr, - args.num_splits_dynamic_ptr}; + args.num_splits_dynamic_ptr, + args.num_m_blocks_ptr, + args.varlen_batch_idx_ptr, + // aras.num_n_blocks_ptr, + args.num_nheads_in_l2_ptr}; } static dim3 @@ -525,8 +540,15 @@ class VarlenDynamicPersistentTileScheduler { CUTLASS_DEVICE cute::tuple get_block_coord(Params const& params) const { + auto get_actual_batch = [&](int virtual_batch) { + if constexpr(Prepared && Sort) { + return params.varlen_batch_idx_ptr[virtual_batch]; + } else { + return virtual_batch; + } + }; if constexpr (!Split) { - return {block, bidh, bidb, 0 /*split_idx*/}; + return {block, bidh, get_actual_batch(bidb), 0 /*split_idx*/}; } else { // the top 8 bits of bidh store num_splits and the next 8 bits store split_idx // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift @@ -540,7 +562,7 @@ class VarlenDynamicPersistentTileScheduler { // if (threadIdx.x == 128) { // printf("blockIdx.x = %d, bidb = %d, bidh = %d, bidh_actual = %d, split_idx = %d\n", blockIdx.x, bidb, bidh, bidh_actual, split_idx); // } - return {block, bidh_actual, bidb, split_idx}; + return {block, bidh_actual, get_actual_batch(bidb), split_idx}; } } }; @@ -554,31 +576,39 @@ class VarlenDynamicPersistentTileScheduler { int lane = threadIdx.x % cutlass::NumThreadsPerWarp; auto get_num_m_blocks = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); - if (seqlen > kBlock) { - if (params.seqused) { - seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; - } else if (params.cu_seqlens) { - int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; - int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); - seqlen = next_cu_seqlen - cur_cu_seqlen; - } else { - seqlen = params.seqlen; + if constexpr (Prepared) { + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? params.num_m_blocks_ptr[batch_idx] : 0; + } else { + int seqlen = params.seqlen * (!PackGQA ? 1 : params.qhead_per_khead); + if (seqlen > kBlockM) { + if (params.seqused) { + seqlen = batch_idx < params.num_batch ? params.seqused[batch_idx] : 0; + } else if (params.cu_seqlens) { + int cur_cu_seqlen = batch_idx <= params.num_batch ? params.cu_seqlens[batch_idx] : 0; + int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1); + seqlen = next_cu_seqlen - cur_cu_seqlen; + } else { + seqlen = params.seqlen; + } + if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } } - if constexpr (PackGQA) { seqlen *= params.qhead_per_khead; } + return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 + ? cute::ceil_div(seqlen, kBlockM) : 0; + // ? params.num_m_blocks_ptr[batch_idx] : 0; } - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? cute::ceil_div(seqlen, kBlock) : 0; - // ? params.num_m_blocks_ptr[batch_idx] : 0; }; auto get_num_splits = [&] (int bidb_start) { int batch_idx = lane + bidb_start; - return batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1 - ? (!Split ? 1 : (params.num_splits_dynamic_ptr - ? params.num_splits_dynamic_ptr[batch_idx] - : params.nsplits_divmod.divisor)) - : 0; + bool is_valid = batch_idx < params.num_batch && lane < cutlass::NumThreadsPerWarp - 1; + if constexpr (!Split) { + return is_valid ? 1 : 0; + } else if constexpr(Prepared) { + return is_valid ? params.num_splits_dynamic_ptr[batch_idx] : 0; + } else { + return is_valid ? params.nsplits_divmod.divisor : 0; + } }; int num_m_blocks = get_num_m_blocks(current_work.bidb); // Different for each lane @@ -589,12 +619,14 @@ class VarlenDynamicPersistentTileScheduler { // Total number of blocks for the next 31 batches int m_blocks_in_group = __shfl_sync(0xffffffff, num_m_blocks_cumulative, cutlass::NumThreadsPerWarp - 1); // Only the lower 16 bits are the actual bidh - int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); - int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes - if constexpr (Split) { - int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; - group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); - } + // int current_bidh = !Split ? current_work.bidh : (current_work.bidh & 0x0000FFFF); + // int group_end_tile = current_work.tile_idx - current_work.block - current_bidh * __shfl_sync(0xffffffff, num_split_m_blocks, 0 /*lane*/) + m_blocks_in_group * params.num_head; // Same for all lanes + // if constexpr (Split) { + // int current_split_idx = (current_work.bidh & 0x00FF0000) >> 16; + // group_end_tile -= current_split_idx * __shfl_sync(0xffffffff, num_m_blocks, 0 /*lane*/); + // } + // NEW: current_work.tile_idx holds group_start_tile for starting batch + int group_end_tile = current_work.tile_idx + m_blocks_in_group * params.num_head; // Same for all lanes int bidb = current_work.bidb; // if (blockIdx.x <= 9 && threadIdx.x == 0) { // printf("Before while, blockIdx.x = %d, threadIdx.x = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, cur tile_idx = %d, cur block = %d, cur bidh = %d, num_split_m_blocks = %d, group_end_tile = %d, m_blocks_in_group = %d\n", blockIdx.x, threadIdx.x, current_work.bidb, num_m_blocks, next_tile_idx, current_work.tile_idx, current_work.block, current_bidh, num_split_m_blocks, group_end_tile, m_blocks_in_group); @@ -626,27 +658,81 @@ class VarlenDynamicPersistentTileScheduler { bidb += batch_idx_in_group; num_m_blocks = __shfl_sync(0xffffffff, num_m_blocks, batch_idx_in_group); if constexpr (Split) { num_splits = __shfl_sync(0xffffffff, num_splits, batch_idx_in_group); } - int mh_block = next_tile_idx - group_start_tile - (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; - int bidh = mh_block / num_m_blocks; - int block = mh_block - bidh * num_m_blocks; - if constexpr (Split) { - int bidh_actual = bidh / num_splits; - int split_idx = bidh - bidh_actual * num_splits; - // TODO: idk why this gives wrong answer nondeterministically - // int bidh_actual, split_idx; - // split_idx = params.head_divmod.divmod(bidh_actual, bidh); - // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx - // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift - uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); - // if (threadIdx.x == 0) { - // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + group_start_tile += (batch_idx_in_group == 0 ? 0 : __shfl_sync(0xffffffff, num_m_blocks_cumulative, batch_idx_in_group - 1)) * params.num_head; + int mh_block = next_tile_idx - group_start_tile; + int block, bidh; + if constexpr (LPT) { + if (!Split || num_splits == 1) { + // NOTE: code for computing nheads_in_l2 directly left as reference + // int num_n_blocks = params.num_n_blocks_ptr ? params.num_n_blocks_ptr[bidb] : num_m_blocks; + // auto find_log2_floor = [&](int n) { return 31 - cutlass::clz(n); }; + // int nheads_in_l2 = params.max_kvblocks_in_l2 < num_n_blocks + // ? 1 : 1 << find_log2_floor(params.max_kvblocks_in_l2 / num_n_blocks); + // if constexpr (!PackGQA) { nheads_in_l2 *= params.qhead_per_khead; } + // nheads_in_l2 = min(nheads_in_l2, params.num_head); + auto get_nheads_in_l2 = [&](int batch_idx) { + if constexpr(Prepared) { + return params.num_nheads_in_l2_ptr[batch_idx]; + } else { + return !PackGQA ? params.qhead_per_khead : 1; + } + }; + int nheads_in_l2 = get_nheads_in_l2(bidb); + int mh_in_l2 = nheads_in_l2 * num_m_blocks; + int section_idx = mh_block / mh_in_l2; + int l2_mod = mh_block - section_idx * mh_in_l2; + // tail section + int nheads_remainder = params.num_head - section_idx * nheads_in_l2; + int nheads_in_this_section = nheads_in_l2 <= nheads_remainder ? nheads_in_l2 : nheads_remainder; + block = l2_mod / nheads_in_this_section; + int bidh_residual = l2_mod - block * nheads_in_this_section; + bidh = section_idx * nheads_in_l2 + bidh_residual; + if constexpr(Split) { + // remember to set num_splits = 1 in work tile + uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } else { + // NOTE: leave traverse heads first version for reference + // block = params.head_divmod.divmod(bidh, mh_block); + // if constexpr (Split) { + // int split_idx = block / num_m_blocks; + // block = block - split_idx * num_m_blocks; + // uint32_t bidh_packed = reinterpret_cast(bidh) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // bidh = reinterpret_cast(bidh_packed); + // } + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + bidh = reinterpret_cast(bidh_packed); + } + } + block = num_m_blocks - 1 - block; + } else { + bidh = mh_block / num_m_blocks; + block = mh_block - bidh * num_m_blocks; + if constexpr (Split) { + int bidh_actual = bidh / num_splits; + int split_idx = bidh - bidh_actual * num_splits; + // TODO: idk why this gives wrong answer nondeterministically + // int bidh_actual, split_idx; + // split_idx = params.head_divmod.divmod(bidh_actual, bidh); + // Use the top 8 bits to store num_splits and the next 8 bits to store split_idx + // reinterpret_cast to uint32_t to make sure we're not doing sign extension when we shift + uint32_t bidh_packed = reinterpret_cast(bidh_actual) + (reinterpret_cast(split_idx) << 16) + (reinterpret_cast(num_splits) << 24); + // if (threadIdx.x == 0) { + // printf("blockIdx.x = %d, group_start_tiled = %d, bidb = %d, batch_idx_in_group = %d, mh_block = %d, num_m_blocks = %d, bidh = %d, bidh_actual = %d, split_idx = %d, num_splits = %d, bidh_packed = %d\n", blockIdx.x, group_start_tile, bidb, batch_idx_in_group, mh_block, num_m_blocks, bidh, bidh_actual, split_idx, num_splits, bidh_packed); + // } + bidh = reinterpret_cast(bidh_packed); + } + // if (blockIdx.x <= 9 && threadIdx.x == 0) { + // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); // } - bidh = reinterpret_cast(bidh_packed); } - // if (blockIdx.x <= 9 && threadIdx.x == 0) { - // printf("Before returning, blockIdx.x = %d, threadIdx.x = %d, group_start_tile = %d, batch_idx_in_group = %d, bidb = %d, num_m_blocks = %d, next_tile_idx = %d, group_end_tile = %d, m_blocks_in_group = %d, mh_block = %d, bidh = %d, block = %d\n", blockIdx.x, threadIdx.x, group_start_tile, batch_idx_in_group, bidb, num_m_blocks, next_tile_idx, group_end_tile, m_blocks_in_group, mh_block, bidh, block); - // } - return {next_tile_idx, block, bidh, bidb}; + return {group_start_tile, block, bidh, bidb}; } template diff --git a/hopper/tile_size.h b/hopper/tile_size.h index e6cb31515c7..8353542c477 100644 --- a/hopper/tile_size.h +++ b/hopper/tile_size.h @@ -21,7 +21,7 @@ constexpr std::tuple tile_size_fwd_sm90( return {128, 96, true, false}; } else { // Switch to tile size 192 x 192 for now - bool const use_blockN_128 = is_causal || is_local; + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; return {192, use_blockN_128 ? 128 : 192, use_blockN_128, true}; } // Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen @@ -29,8 +29,9 @@ constexpr std::tuple tile_size_fwd_sm90( } else if (headdim <= 96) { return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true}; } else if (headdim <= 128) { - return {128, is_causal || is_local || paged_kv_non_TMA ? 128 : 176, true, true}; - // {128, 192, false, false} and {192, 128, false, true} are quite good too + bool const use_blockN_128 = is_causal || is_local || paged_kv_non_TMA; + return {128, use_blockN_128 ? 128 : 176, true, true}; + // {128, 192, true, false} and {192, 128, false, true} are quite good too // 128 x 192 hits the limit of smem if MmaPV_is_RS, 128 x 144 hits the limit if !MmaPV_is_RS } else if (headdim <= 192) { return {128, paged_kv_non_TMA || is_local ? 96 : (headdim_v <= 128 ? 128 : 112), true, true}; // 128 x 112 hits the limit of smem