diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index e15adbcb..7ffeb121 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -80,8 +80,8 @@ struct paged_kv_t { * \param head_dim The dimension of each head * \param batch_size The batch size * \param layout The layout of last 3 dimensions in KV-Cache. - * \param k_data The flattened key cache - * \param v_data The flattened value cache + * \param k_data The start pointer of key cache, k_cache should be contiguous + * \param v_data The start pointer of value cache, v_cache should be contiguous * \param indices The page indices array * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch @@ -107,20 +107,19 @@ struct paged_kv_t { } /*! - * \brief Construct a paged key-value cache + * \brief Construct a paged key-value cache with custom kv-cache strides * \param num_heads The number of heads * \param page_size The size of each page * \param head_dim The dimension of each head * \param batch_size The batch size * \param layout The layout of last 3 dimensions in KV-Cache. - * \param k_data The flattened key cache - * \param v_data The flattened value cache + * \param k_data The start pointer of key cache, k_cache doesn't have to be contiguous + * \param v_data The start pointer of value cache, v_cache doesn't have to be contiguous * \param kv_strides custom strides of each dimensions of k_data and v_data * \param indices The page indices array * \param indptr The page indptr array * \param last_page_len The offset of the last page for each request in the batch * \param rope_pos_offset The start position of each request in the batch. 
- * \note This constructor should only be used when page_storage == kIndices */ __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, uint32_t batch_size, QKVLayout layout, DType* k_data, diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 2e002338..79aab9ee 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -69,6 +69,13 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, num_heads = paged_k_cache.size(2); } + // get kv_cache_strides + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); + CHECK_EQ(append_key.size(1), num_heads); CHECK_EQ(append_key.size(2), head_dim); CHECK_EQ(append_value.size(1), num_heads); @@ -79,12 +86,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, auto kv_scalar_dtype = paged_k_cache.scalar_type(); bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] { - paged_kv_t<c_type, int32_t> paged_kv(num_heads, page_size, head_dim, batch_size, kv_layout, - static_cast<c_type*>(paged_k_cache.data_ptr()), - static_cast<c_type*>(paged_v_cache.data_ptr()), - static_cast<int32_t*>(kv_indices.data_ptr()), - static_cast<int32_t*>(kv_indptr.data_ptr()), - static_cast<int32_t*>(kv_last_page_len.data_ptr())); + paged_kv_t<c_type, int32_t> paged_kv( + num_heads, page_size, head_dim, batch_size, kv_layout, + static_cast<c_type*>(paged_k_cache.data_ptr()), + static_cast<c_type*>(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()), + static_cast<int32_t*>(kv_last_page_len.data_ptr())); cudaError_t status = AppendPagedKVCache(paged_kv, static_cast<c_type*>(append_key.data_ptr()), static_cast<c_type*>(append_value.data_ptr()),