bugfix: fix the stride bug in page append (#527)
We introduced a bug in #513 because we didn't consider non-contiguous KV-caches in the page append operator; this PR fixes the bug.
yzh119 authored Oct 11, 2024
1 parent 0dcd505 commit 93b5d4e
Showing 2 changed files with 18 additions and 12 deletions.
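
To make the bug concrete, here is a minimal sketch (hypothetical shapes and allocation pattern, not code from this repository) of how a non-contiguous paged KV-cache arises in practice: serving stacks often allocate K and V as one fused tensor and slice it, so each cache is a strided view of the parent buffer rather than a contiguous array.

#include <torch/torch.h>

int main() {
  // Hypothetical layout: [num_pages, 2, page_size, num_heads, head_dim],
  // with K and V interleaved along dim 1.
  torch::Tensor kv = torch::empty({128, 2, 16, 8, 64}, torch::kHalf);
  torch::Tensor paged_k = kv.select(1, 0);  // a view, not a copy
  torch::Tensor paged_v = kv.select(1, 1);
  // Both views are non-contiguous: their page stride still spans K *and* V,
  // so any kernel that assumes contiguity reads the wrong elements.
  TORCH_CHECK(!paged_k.is_contiguous());
  TORCH_CHECK(paged_k.strides() == paged_v.strides());  // identical layouts
  return 0;
}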
11 changes: 5 additions & 6 deletions include/flashinfer/page.cuh
@@ -80,8 +80,8 @@ struct paged_kv_t {
    * \param head_dim The dimension of each head
    * \param batch_size The batch size
    * \param layout The layout of last 3 dimensions in KV-Cache.
-   * \param k_data The flattened key cache
-   * \param v_data The flattened value cache
+   * \param k_data The start pointer of key cache, k_cache should be contiguous
+   * \param v_data The start pointer of value cache, v_cache should be contiguous
    * \param indices The page indices array
    * \param indptr The page indptr array
    * \param last_page_len The offset of the last page for each request in the batch
@@ -107,20 +107,19 @@
   }
 
   /*!
-   * \brief Construct a paged key-value cache
+   * \brief Construct a paged key-value cache with custom kv-cache strides
    * \param num_heads The number of heads
    * \param page_size The size of each page
    * \param head_dim The dimension of each head
    * \param batch_size The batch size
    * \param layout The layout of last 3 dimensions in KV-Cache.
-   * \param k_data The flattened key cache
-   * \param v_data The flattened value cache
+   * \param k_data The start pointer of key cache, k_cache doesn't have to be contiguous
+   * \param v_data The start pointer of value cache, v_cache doesn't have to be contiguous
    * \param kv_strides custom strides of each dimensions of k_data and v_data
    * \param indices The page indices array
    * \param indptr The page indptr array
    * \param last_page_len The offset of the last page for each request in the batch
    * \param rope_pos_offset The start position of each request in the batch.
-   * \note This constructor should only be used when page_storage == kIndices
    */
   __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
                                       uint32_t batch_size, QKVLayout layout, DType* k_data,
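
The kv_strides parameter generalizes the addressing: instead of deriving offsets from (page_size, num_heads, head_dim) under an assumption of contiguity, the operator multiplies each index by the corresponding stride. An illustrative helper (names are ours, not flashinfer's) for an NHD layout:

#include <cstddef>

// Offset of element (page, entry, head, feat) given explicit strides.
// A contiguous cache is just the special case
//   stride_page = page_size * num_heads * head_dim,
//   stride_n    = num_heads * head_dim,
//   stride_h    = head_dim,
// which is all the contiguous constructor above can express.
size_t elem_offset(size_t page, size_t entry, size_t head, size_t feat,
                   size_t stride_page, size_t stride_n, size_t stride_h) {
  return page * stride_page + entry * stride_n + head * stride_h + feat;
}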
19 changes: 13 additions & 6 deletions python/csrc/page.cu
@@ -69,6 +69,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
     num_heads = paged_k_cache.size(2);
   }
 
+  // get kv_cache_strides
+  const int64_t* kv_cache_strides = nullptr;
+  auto k_strides = paged_k_cache.strides();
+  auto v_strides = paged_v_cache.strides();
+  TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical");
+  kv_cache_strides = k_strides.data();
+
   CHECK_EQ(append_key.size(1), num_heads);
   CHECK_EQ(append_key.size(2), head_dim);
   CHECK_EQ(append_value.size(1), num_heads);
@@ -79,12 +86,12 @@
   auto kv_scalar_dtype = paged_k_cache.scalar_type();
 
   bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] {
-    paged_kv_t<c_type, int32_t> paged_kv(num_heads, page_size, head_dim, batch_size, kv_layout,
-                                         static_cast<c_type*>(paged_k_cache.data_ptr()),
-                                         static_cast<c_type*>(paged_v_cache.data_ptr()),
-                                         static_cast<int32_t*>(kv_indices.data_ptr()),
-                                         static_cast<int32_t*>(kv_indptr.data_ptr()),
-                                         static_cast<int32_t*>(kv_last_page_len.data_ptr()));
+    paged_kv_t<c_type, int32_t> paged_kv(
+        num_heads, page_size, head_dim, batch_size, kv_layout,
+        static_cast<c_type*>(paged_k_cache.data_ptr()),
+        static_cast<c_type*>(paged_v_cache.data_ptr()), kv_cache_strides,
+        static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+        static_cast<int32_t*>(kv_last_page_len.data_ptr()));
     cudaError_t status =
         AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()),
                            static_cast<c_type*>(append_value.data_ptr()),
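
Note that the binding passes a single kv_cache_strides array to paged_kv_t for both k_data and v_data, hence the TORCH_CHECK that the two layouts match. A contrived example (hypothetical shapes) of an input this guard rejects: a V cache with the same shape as K but permuted strides.

#include <torch/torch.h>

int main() {
  auto k = torch::empty({128, 16, 8, 64});                        // contiguous
  auto v = torch::empty({128, 16, 64, 8}).permute({0, 1, 3, 2});  // same shape, permuted strides
  // k.strides() == {8192, 512, 64, 1}; v.strides() == {8192, 512, 1, 8}
  TORCH_CHECK(k.strides() == v.strides(), "k/v strides must be identical");  // throws
  return 0;
}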
