From dc187ad5dbe6f4a2a60343b8b7ebead0e5d45f8c Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Mon, 26 Feb 2024 20:46:56 -0800 Subject: [PATCH 1/3] [fix] added small page size support for flash attention. *apply change from pull request 824. --- src/engine/engine.cpp | 2 +- src/kernels/flash_attn/flash_api.cpp | 2 +- src/kernels/flash_attn/src/flash_fwd_kernel.h | 136 ++++++++++-------- src/kernels/flash_attn/src/kernel_traits.h | 20 +++ src/kernels/flash_attn/src/utils.h | 47 ++++++ src/layers/attention/attention_test.cpp | 11 +- 6 files changed, 149 insertions(+), 69 deletions(-) diff --git a/src/engine/engine.cpp b/src/engine/engine.cpp index 98f7905f..f85bc663 100644 --- a/src/engine/engine.cpp +++ b/src/engine/engine.cpp @@ -16,7 +16,7 @@ static constexpr int64_t GB = int64_t(1024) * 1024 * 1024; -DEFINE_int32(block_size, 256, "slots per block, value must be multiple of 256"); +DEFINE_int32(block_size, 16, "slots per block, value must be multiple of 16"); DEFINE_int64(max_cache_size, 10 * GB, "max cache size in bytes, default 10GB"); DEFINE_double(max_memory_utilization, 0.9, diff --git a/src/kernels/flash_attn/flash_api.cpp b/src/kernels/flash_attn/flash_api.cpp index 6882ec58..42b1eab8 100644 --- a/src/kernels/flash_attn/flash_api.cpp +++ b/src/kernels/flash_attn/flash_api.cpp @@ -257,7 +257,7 @@ mha_varlen_fwd(at::Tensor& out, // [n_tokens, n_heads, head_dim] const int n_blocks = !paged_KV ? 0 : k.size(0); const int block_size = !paged_KV ? 
1 : k.size(1); // TODO: support smaller block sizes - TORCH_CHECK(!paged_KV || block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + TORCH_CHECK(!paged_KV || block_size % 16 == 0, "Paged KV cache block size must be divisible by 16"); const int batch_size = cu_seqlens_q.numel() - 1; // [n_tokens, n_heads, head_dim] diff --git a/src/kernels/flash_attn/src/flash_fwd_kernel.h b/src/kernels/flash_attn/src/flash_fwd_kernel.h index bb02a301..1af17464 100644 --- a/src/kernels/flash_attn/src/flash_fwd_kernel.h +++ b/src/kernels/flash_attn/src/flash_fwd_kernel.h @@ -515,16 +515,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; - const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; - const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; const index_t row_offset_k = block_table == nullptr ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride - : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + : (bidh / params.h_h_k_ratio) * params.k_head_stride; // block addresses are later resolved per-thread + const index_t row_offset_v = block_table == nullptr ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride - : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + : (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -544,15 +543,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV; + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + + Tensor tKgK_ = gmem_thr_copy_KV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK_ = gmem_thr_copy_KV.partition_D(sK); + Tensor tVgV_ = gmem_thr_copy_KV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV_ = gmem_thr_copy_KV.partition_D(sV); + + Tensor tKgK = make_tensor(tKgK_.data(), reshape_thread_tile(tKgK_.layout())); + Tensor tKsK = 
make_tensor(tKsK_.data(), reshape_thread_tile(tKsK_.layout())); + Tensor tVgV = make_tensor(tVgV_.data(), reshape_thread_tile(tVgV_.layout())); + Tensor tVsV = make_tensor(tVsV_.data(), reshape_thread_tile(tVsV_.layout())); + + if (block_table != nullptr) { + tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block_max, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + } typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); @@ -590,8 +604,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout())); // Allocate predicate tensors for k Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); @@ -608,11 +623,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Prologue // Copy from Knew to K, optionally apply rotary embedding. 
- typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; - auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; - auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); if constexpr (Append_KV) { + typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. @@ -629,10 +645,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_sin_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); - Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); - Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); - Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); - Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + + Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + + Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout())); + Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout())); + Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), 
reshape_flatten_thread_tile(tRgCosCont_.layout())); + Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout())); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } @@ -653,8 +676,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), Shape, Int>{}, make_stride(params.vnew_row_stride, _1{})); - Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) - Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + typename Kernel_traits::GmemTiledCopyQKVPaged gmem_tiled_copy_KV_new; + auto gmem_thr_copy_KV_new = gmem_tiled_copy_KV_new.get_thread_slice(tidx); + Tensor tKgKnew_ = gmem_thr_copy_KV_new.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew_ = gmem_thr_copy_KV_new.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + auto tKgKnew = make_tensor(tKgKnew_.data(), reshape_thread_tile(tKgKnew_.layout())); + auto tVgVnew = make_tensor(tVgVnew_.data(), reshape_thread_tile(tVgVnew_.layout())); const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); auto tKgK_data = tKgK.data(); @@ -694,14 +722,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { if (n_block > n_block_copy_min) { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - 
block_table_idx_next * params.page_block_size; - const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; - const int offset_diff = block_table_offset_next - block_table_offset_cur; - tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; - tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); + tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } } } @@ -714,9 +738,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Read Q from gmem to smem, optionally apply rotary embedding. if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs - flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. 
@@ -751,7 +779,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); @@ -790,17 +818,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -825,13 +850,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const 
int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -868,13 +890,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); } else { - const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; - const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = n_block * kBlockN / params.page_block_size; - const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; - tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block + 1, params.page_block_size, + block_table, params.v_batch_stride, params.v_row_stride); } - 
flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -889,13 +908,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (block_table == nullptr) { tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); } else { - const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; - const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; - const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; - const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; - tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset(tidx, n_block, params.page_block_size, + block_table, params.k_batch_stride, params.k_row_stride); } - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); diff --git a/src/kernels/flash_attn/src/kernel_traits.h b/src/kernels/flash_attn/src/kernel_traits.h index e1ee4b83..c0fd202a 100644 --- a/src/kernels/flash_attn/src/kernel_traits.h +++ b/src/kernels/flash_attn/src/kernel_traits.h @@ -127,6 +127,18 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per read + + // from how many rows does each thread have to fetch + static constexpr int kGmemRowsPerThread = kBlockN / (kNThreads / kGmemThreadsPerRow); + // Here we assign a contiguous tile to each thread, rather than a 1x8 row every + // (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread + // do not cross a page boundary. This way, each thread need only fetch 1 page index per + // mainloop iteration. Rudimentary testing shows no slowdown. + using GmemTiledCopyQKVPaged = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout, _8>, Stride<_8, _1>>{})); + using GmemTiledCopyO = decltype( make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, @@ -152,6 +164,14 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtomRotcossin{}, Layout>{})); // Val layout, 8 vals per load + using GmemTiledCopyRotcossinPaged = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout, _4>, Stride<_4, _1>>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinContPaged = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomRotcossin{}, + Layout, _8>, Stride<_8, _1>>{})); // Val layout, 8 vals per load }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/kernels/flash_attn/src/utils.h b/src/kernels/flash_attn/src/utils.h index 4bcfa7f6..05f1d9e8 100644 --- a/src/kernels/flash_attn/src/utils.h +++ b/src/kernels/flash_attn/src/utils.h @@ -379,4 +379,51 @@ __forceinline__ __device__ void 
copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// +// resolves initial base offset of a slice of a paged kv copy from gmem. +// assumes that the tensor has already been positioned at the correct head. +template +__forceinline__ __device__ +int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size, + const int* block_table, const int page_stride, const int row_stride) { + constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow; + constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread; + constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad; + constexpr int kBlockN = Kernel_traits::kBlockN; + + const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad; + const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; + const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; + const int page_offset = global_row_offset % page_block_size; + const int virtual_page_idx = global_row_offset / page_block_size; + + return block_table[virtual_page_idx] * page_stride + + page_offset * row_stride + + col_offset; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k), +// where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures +// that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors. +template +__forceinline__ __device__ +auto reshape_thread_tile(Layout l) { + return make_layout(append(get<0>(l.shape()), get<2>(l.shape())), + append(get<0>(l.stride()), get<2>(l.stride()))); +} + +// reshapes and flattens the thread tile layout. 
A separate function is needed for the case where +// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact +// for the case of swizzled layouts +template +__forceinline__ __device__ +auto reshape_flatten_thread_tile(Layout l) { + auto mode_0 = filter(flatten(get<0>(l))); + return make_layout(append(mode_0.shape(), get<2>(l.shape())), + append(mode_0.stride(), get<2>(l.stride()))); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash diff --git a/src/layers/attention/attention_test.cpp b/src/layers/attention/attention_test.cpp index 09861289..f10b5774 100644 --- a/src/layers/attention/attention_test.cpp +++ b/src/layers/attention/attention_test.cpp @@ -166,8 +166,7 @@ class AttentionDecodeTest int64_t /*n_kv_heads*/, int64_t /*head_dim*/, float /*scale*/, - bool /*alibi*/, - int32_t /*num_splits*/>> {}; + bool /*alibi*/>> {}; TEST_P(AttentionDecodeTest, KVCache) { const auto& [device, @@ -180,8 +179,7 @@ TEST_P(AttentionDecodeTest, KVCache) { n_kv_heads, head_dim, scale, - alibi, - num_splits] = GetParam(); + alibi] = GetParam(); // make sure kv_max_seq_len >= q_max_seq_len if (kv_max_seq_len < q_max_seq_len) { GTEST_SKIP() << "kv_max_seq_len < q_max_seq_len"; @@ -325,15 +323,14 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(torch::kCUDA), ::testing::Values(torch::kHalf, torch::kBFloat16), ::testing::Values(1, 10), // batch_size - ::testing::Values(256), // block_size + ::testing::Values(16, 80, 256), // block_size ::testing::Values(1, 10), // q_max_seq_len ::testing::Values(100, 200), // kv_max_seq_len ::testing::Values(6), // n_heads ::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/), // n_kv_heads ::testing::Values(32, 40, 64, 128), // head_dim ::testing::Values(0.9, 1.0), // scale - ::testing::Values(false, true), // alibi - ::testing::Values(1) // num_splits + ::testing::Values(false, true) // alibi )); } // namespace llm From 
1908eb24a46500ea29f1f116aad571a17222f469 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 22 Mar 2024 10:37:11 -0700 Subject: [PATCH 2/3] remove check --- src/engine/engine.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/engine/engine.cpp b/src/engine/engine.cpp index f85bc663..da69e655 100644 --- a/src/engine/engine.cpp +++ b/src/engine/engine.cpp @@ -294,8 +294,6 @@ bool Engine::init_kv_cache(int64_t cache_size_in_bytes) { CHECK_GT(cache_size_in_bytes, 0); LOG(INFO) << "Initializing kv cache with size: " << readable_size(cache_size_in_bytes); - CHECK(FLAGS_block_size % 256 == 0) - << "cache block size must be divisible by 256"; const int64_t block_size = FLAGS_block_size; From cd6ba13c3b59cb3e35fe49be6acc210b035a70e2 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 22 Mar 2024 10:39:56 -0700 Subject: [PATCH 3/3] fix check. --- src/request/sequence.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/request/sequence.cpp b/src/request/sequence.cpp index 5720b6d2..7dcf8e39 100644 --- a/src/request/sequence.cpp +++ b/src/request/sequence.cpp @@ -207,7 +207,7 @@ std::vector Sequence::kv_cache_slots(int32_t pos_start, } void Sequence::commit_kv_cache(size_t size) { - CHECK(kv_cache_pos_ + size < kv_cache_capacity()); + CHECK(kv_cache_pos_ + size <= kv_cache_capacity()); kv_cache_pos_ += size; }