diff --git a/csrc/cutlass b/csrc/cutlass index bbe579a9e3b..751eb9a8859 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49 +Subproject commit 751eb9a8859ac36bfc77551f9e4a957c31a5a8b1 diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 001acacaf6d..75ba3ed22dc 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -561,7 +561,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 1 : k.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16"); if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case if (is_causal) { window_size_right = 0; } @@ -1285,7 +1285,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : kcache.size(0); const int page_block_size = !paged_KV ? 1 : kcache.size(1); - TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16"); const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? 
kcache.size(0) : batch_size; diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bd29d567077..34922d51943 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -136,6 +136,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); @@ -560,16 +561,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; const index_t row_offset_k = block_table == nullptr ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread + const index_t row_offset_v = block_table == nullptr ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + : (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -581,7 +583,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), Shape, Int>{}, make_stride(params.v_row_stride, _1{})); - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQ{}); Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); @@ -589,15 +590,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV; + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + Tensor tQgQ = 
gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + + Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK); + Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV); + + Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout())); + Tensor tKsK = make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout())); + Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout())); + Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); + + if (block_table != nullptr) { + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); @@ -635,8 +651,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout())); // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -653,11 +670,12 @@ inline __device__ void 
compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Prologue // Copy from Knew to K, optionally apply rotary embedding. - typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; - auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; - auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); if constexpr (Append_KV) { + typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. 
@@ -674,10 +692,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + + Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + + Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout())); + Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout())); + Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout())); + Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout())); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } @@ -698,8 +723,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), Shape, Int>{}, make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new; + auto gmem_thr_copy_KV_new = 
gmem_tiled_copy_KV_new.get_thread_slice(tidx); + Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout())); + auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout())); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); auto tKgK_data = tKgK.data(); @@ -739,14 +769,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { if (n_block > n_block_copy_min) { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - const int offset_diff = block_table_offset_next - block_table_offset_cur; - tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } } } @@ -759,9 +785,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Read Q from gmem to smem, optionally apply rotary embedding. 
if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. @@ -796,7 +826,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); @@ -835,17 +865,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -870,13 +897,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = 
(n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -913,13 +937,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = gV.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -934,13 +955,10 @@ inline __device__ void 
compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = gK.data() + flash::resolve_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index a7a5cf1edd7..04a9b3b2920 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read + + // from how many rows does each thread have to fetch + static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); + // Here we assign a contiguous tile to each thread, rather than a 1x8 row every + // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread + // do not cross a page boundary. This way, each thread need only fetch 1 page index per + // mainloop iteration. 
Rudimentary testing shows no slowdown. + using GmemTiledCopyQKVPaged = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout, _8>, Stride<_8, _1>>{})); using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, @@ -156,6 +167,14 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load + using GmemTiledCopyRotcossinPaged = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinContPaged = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 2b45e87b20f..4f999a6b764 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -292,6 +292,53 @@ void cp_async_wait() { //////////////////////////////////////////////////////////////////////////////////////////////////// +// resolves offset of a slice of a paged kv copy from gmem. +// assumes that the tensor has already been positioned at the correct head. 
+template +__forceinline__ __device__ +int resolve_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, + const int* block_table, const int page_stride, const int row_stride) { + constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; + constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; + constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; + constexpr int kBlockN = Kernel_traits::kBlockN; + + const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; + const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; + const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; + const int page_offset = global_row_offset % page_block_size; + const int virtual_page_idx = global_row_offset / page_block_size; + + return block_table[virtual_page_idx] * page_stride + + page_offset * row_stride + + col_offset; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k), +// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures +// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors. +template +__forceinline__ __device__ +auto reshape_thread_tile(Layout l) { + return make_layout(append(get<0>(l.shape()), get<2>(l.shape())), + append(get<0>(l.stride()), get<2>(l.stride()))); +} + +// reshapes and flattens the thread tile layout. 
A separate function is needed for the case where +// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact +// for the case of swizzled layouts +template +__forceinline__ __device__ +auto reshape_flatten_thread_tile(Layout l) { + auto mode_0 = filter(flatten(get<0>(l))); + return make_layout(append(mode_0.shape(), get<2>(l.shape())), + append(mode_0.stride(), get<2>(l.stride()))); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 308e30bec48..65af6af1025 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1543,7 +1543,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): ], ) # TODO: add smaller page sizes when https://github.com/Dao-AILab/flash-attention/pull/824 is merged -@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512]) # @pytest.mark.parametrize("seqlen_q,seqlen_k", [(256, 128)]) def test_flash_attn_varlen_causal( seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype @@ -1832,9 +1832,8 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_interleaved", [False]) @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) # @pytest.mark.parametrize("rotary_fraction", [0.0]) -@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) -# @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) -# @pytest.mark.parametrize("paged_kv_block_size", [256]) +# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) +@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512]) @pytest.mark.parametrize("has_batch_idx", [False, True]) # @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) @@ -2462,3 +2461,47 @@ def 
test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus assert torch.equal(dv, dv) assert torch.equal(dk, dk) assert torch.equal(dq, dq) + + +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("paged_kv_block_size", [16]) +# @pytest.mark.parametrize("has_batch_idx", [False]) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize("nheads", [32]) +@pytest.mark.parametrize("b", [4]) +@pytest.mark.parametrize("n", [10]) +@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)]) +def test_flash_attn_paged_kvcache_overflow( + seqlen_q, + seqlen_k, + d, + nheads, + b, + n, + paged_kv_block_size, + causal, + dtype, +): + device = "cuda" + num_blocks = 1000*16//paged_kv_block_size + key_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device) + value_cache = torch.rand([num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device) + cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device) + + for _ in range(n): + query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device) + key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device) + value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device) + block_tables = torch.randint(0, num_blocks, size=(b, (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size), dtype=torch.int32, device=device) + output = flash_attn_with_kvcache( + query, + key_cache, + value_cache, + k=key, + v=value, + cache_seqlens=cache_seqlens, + block_table=block_tables, + causal=causal, + )