diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 018215fa5a..e63d068d75 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -36,7 +36,6 @@ #include "../utils.cuh" #include "../vec_dtypes.cuh" #include "cascade.cuh" -#include "handler.cuh" #include "state.cuh" namespace flashinfer { diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index b982c04862..6e313b2dd2 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -240,10 +241,10 @@ cudaError_t PartitionPagedKVCacheComputeAuxiliaryInfo( return cudaSuccess; } -struct AlignedAlloactor { +struct AlignedAllocator { void* ptr; size_t space; - AlignedAlloactor(void* buf, size_t space) : ptr(buf), space(space) {} + AlignedAllocator(void* buf, size_t space) : ptr(buf), space(space) {} template T* aligned_alloc(size_t size, size_t alignment) { if (std::align(alignment, size, ptr, space)) { @@ -303,34 +304,41 @@ class BatchDecodeHandler { FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, - /*enable_cuda_graph=*/false, stream_)); + /*enable_cuda_graph=*/IsCUDAGraphEnabled(), stream_)); batch_size_after_partition_ = new_batch_size; - if (tmp_size > 0) { - AlignedAlloactor allocator(buffer, workspace_size_in_bytes); - float_buffer_ = allocator.aligned_alloc(tmp_size, 16); - new_indptr_ = - allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); + if (IsCUDAGraphEnabled()) { + // NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization. + // the value should not be changed during the lifetime of the handler. + // So it should be compatible with CUDAGraph which requires fixed pointer. + size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ * + (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float)); + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + float_buffer_ = allocator.aligned_alloc(max_tmp_size, 16); + new_indptr_ = allocator.aligned_alloc( + (max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); + void* new_indptr_h_ = page_locked_buffer_; new_last_page_len_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); void* new_last_page_len_h_ = (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = - allocator.aligned_alloc((batch_size_before_partition_ + 1) * sizeof(IdType), 16); + chunk_indptr_ = allocator.aligned_alloc( + (max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); void* chunk_indptr_h_ = (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); batch_idx_map_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); void* batch_idx_map_h_ = (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); chunk_start_pos_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); void* chunk_start_pos_h_ = (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); seq_lengths_before_partition_ = - allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); void* seq_lengths_before_partition_h_ = (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, @@ -338,6 +346,42 @@ class BatchDecodeHandler { (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_, num_bytes_to_copy, stream_)); + } else { + if (tmp_size > 0) { + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + float_buffer_ = allocator.aligned_alloc(tmp_size, 16); + new_indptr_ = + allocator.aligned_alloc((batch_size_after_partition_ + 1) * sizeof(IdType), 16); + void* new_indptr_h_ = page_locked_buffer_; + new_last_page_len_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* new_last_page_len_h_ = + (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); + chunk_indptr_ = + allocator.aligned_alloc((batch_size_before_partition_ + 1) * sizeof(IdType), 16); + void* chunk_indptr_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); + batch_idx_map_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* batch_idx_map_h_ = + (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); + chunk_start_pos_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* chunk_start_pos_h_ = + (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); + seq_lengths_before_partition_ = + allocator.aligned_alloc(batch_size_after_partition_ * sizeof(IdType), 16); + void* seq_lengths_before_partition_h_ = + (char*)page_locked_buffer_ + + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); + size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; + FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( + max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, + (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_, + (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, + (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_, + num_bytes_to_copy, stream_)); + } } forward_started_ = true; return cudaSuccess; @@ -372,7 +416,12 @@ class BatchDecodeHandler { void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } - BatchDecodeHandler(size_t max_workspace_size_in_bytes = 64 * 1024 * 1024) + /*! + * \note (Zihao): when enable_cuda_graph is true, max_workspace_size_in_bytes will be ignored, + * when enable_cuda_graph is false, max_batch_size will be ignored. + */ + BatchDecodeHandler(size_t max_workspace_size_in_bytes = 128 * 64 * 64, + size_t max_batch_size = 16384, bool enable_cuda_graph = false) : batch_size_after_partition_(0U), float_buffer_(nullptr), new_indptr_(nullptr), @@ -382,7 +431,27 @@ class BatchDecodeHandler { chunk_start_pos_(nullptr), seq_lengths_before_partition_(nullptr), forward_started_(false), + cuda_graph_enabled_(enable_cuda_graph), stream_(nullptr) { + if (enable_cuda_graph) { + int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0; + cudaGetDevice(&dev_id); + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); + cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor, + dev_id); + max_batch_size_after_partition_ = + std::max(max_thread_blocks_per_sm * num_sm, max_batch_size); + size_t required_max_workspace_size_in_bytes = + 6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16); + if (required_max_workspace_size_in_bytes > max_workspace_size_in_bytes) { + std::ostringstream err_msg; + err_msg << "RuntimeError: reserved workspace size is not enough, required size: " + << required_max_workspace_size_in_bytes + << " bytes, actual size: " << max_workspace_size_in_bytes + << " bytes, please increase workspace buffer size."; + throw std::runtime_error(err_msg.str()); + } + } cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); } ~BatchDecodeHandler() { @@ -390,7 +459,7 @@ class BatchDecodeHandler { cudaFreeHost(page_locked_buffer_); } - virtual bool IsCUDAGraphMode() const { return false; } + bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; } protected: uint32_t batch_size_before_partition_; @@ -404,87 +473,9 @@ class BatchDecodeHandler { void* chunk_start_pos_; void* seq_lengths_before_partition_; bool forward_started_; - cudaStream_t stream_; -}; - -class CUDAGraphBatchDecodeHandler : public BatchDecodeHandler { - public: - template - cudaError_t CUDAGraphBeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, - IdType* indptr, IdType* last_page_len, - uint32_t batch_size, uint32_t num_qo_heads, - uint32_t page_size) { - batch_size_before_partition_ = batch_size; - uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; - auto work_estimation_func = - BatchDecodeWithPagedKVCacheWorkEstimationDispatched; - FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, - new_batch_size, batch_size, indptr, num_qo_heads, - page_size, - /*enable_cuda_graph=*/true, stream_)); - // NOTE(Zihao): max_batch_size_after_partition_ is determined in handler initialization. - // the value should not be changed during the lifetime of the handler. - // So it should be compatible with CUDAGraph which requires fixed pointer. - batch_size_after_partition_ = new_batch_size; - size_t max_tmp_size = num_qo_heads * max_batch_size_after_partition_ * - (HEAD_DIM * sizeof(DTypeOut) + 2 * sizeof(float)); - AlignedAlloactor allocator(buffer, workspace_size_in_bytes); - float_buffer_ = allocator.aligned_alloc(max_tmp_size, 16); - new_indptr_ = - allocator.aligned_alloc((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); - - void* new_indptr_h_ = page_locked_buffer_; - new_last_page_len_ = - allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); - void* new_last_page_len_h_ = - (char*)page_locked_buffer_ + ((char*)new_last_page_len_ - (char*)new_indptr_); - chunk_indptr_ = - allocator.aligned_alloc((max_batch_size_after_partition_ + 1) * sizeof(IdType), 16); - void* chunk_indptr_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_indptr_ - (char*)new_indptr_); - batch_idx_map_ = - allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); - void* batch_idx_map_h_ = - (char*)page_locked_buffer_ + ((char*)batch_idx_map_ - (char*)new_indptr_); - chunk_start_pos_ = - allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); - void* chunk_start_pos_h_ = - (char*)page_locked_buffer_ + ((char*)chunk_start_pos_ - (char*)new_indptr_); - seq_lengths_before_partition_ = - allocator.aligned_alloc(max_batch_size_after_partition_ * sizeof(IdType), 16); - void* seq_lengths_before_partition_h_ = - (char*)page_locked_buffer_ + ((char*)seq_lengths_before_partition_ - (char*)new_indptr_); - - size_t num_bytes_to_copy = (char*)allocator.ptr - (char*)new_indptr_; - FLASHINFER_CUDA_CALL(PartitionPagedKVCacheComputeAuxiliaryInfo( - max_num_pages_per_batch, batch_size, page_size, indptr, last_page_len, - (IdType*)new_indptr_h_, (IdType*)new_last_page_len_h_, (IdType*)chunk_indptr_h_, - (IdType*)batch_idx_map_h_, (IdType*)chunk_start_pos_h_, - (IdType*)seq_lengths_before_partition_h_, new_indptr_, page_locked_buffer_, - num_bytes_to_copy, stream_)); - forward_started_ = true; - return cudaSuccess; - } - CUDAGraphBatchDecodeHandler(size_t max_batch_size) { - int dev_id = 0, num_sm = 0, max_thread_blocks_per_sm = 0; - cudaGetDevice(&dev_id); - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id); - cudaDeviceGetAttribute(&max_thread_blocks_per_sm, cudaDevAttrMaxBlocksPerMultiprocessor, - dev_id); - max_batch_size_after_partition_ = - std::max(max_thread_blocks_per_sm * num_sm, max_batch_size); - std::cout << max_thread_blocks_per_sm * num_sm << " " << max_batch_size << std::endl; - size_t max_workspace_size_in_bytes = - 6 * (sizeof(uint64_t) * (max_batch_size_after_partition_ + 1) + 16); - cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); - } - bool IsCUDAGraphMode() const override { return true; } - - private: + bool cuda_graph_enabled_; uint32_t max_batch_size_after_partition_; + cudaStream_t stream_; }; class BatchPrefillHandler { @@ -524,11 +515,19 @@ class BatchPrefillHandler { std::vector request_indices_vec, tile_indices_vec; std::tie(num_frags_x_, num_qo_tiles_, request_indices_vec, tile_indices_vec) = split_qo_indptr(qo_indptr, batch_size, gqa_group_size, head_dim, stream_); - AlignedAlloactor allocator(buffer, workspace_size_in_bytes); - request_indices_ = - allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), 16); + AlignedAllocator allocator(buffer, workspace_size_in_bytes); + if (IsCUDAGraphEnabled()) { + request_indices_ = allocator.aligned_alloc(sizeof(IdType) * max_num_qo_tiles_, 16); + } else { + request_indices_ = + allocator.aligned_alloc(sizeof(IdType) * request_indices_vec.size(), 16); + } void* request_indices_h_ = page_locked_buffer_; - tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * tile_indices_vec.size(), 16); + if (IsCUDAGraphEnabled()) { + tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * max_num_qo_tiles_, 16); + } else { + tile_indices_ = allocator.aligned_alloc(sizeof(IdType) * tile_indices_vec.size(), 16); + } void* tile_indices_h_ = (char*)page_locked_buffer_ + ((char*)tile_indices_ - (char*)request_indices_); std::copy(request_indices_vec.begin(), request_indices_vec.end(), (IdType*)request_indices_h_); @@ -554,12 +553,16 @@ class BatchPrefillHandler { void SetCUDAStream(cudaStream_t stream) { stream_ = stream; } - BatchPrefillHandler(size_t max_workspace_size_in_bytes = 64 * 1024 * 1024) + bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; } + + BatchPrefillHandler(size_t max_workspace_size_in_bytes = 64 * 1024 * 1024, + bool enable_cuda_graph = false) : request_indices_(nullptr), tile_indices_(nullptr), num_frags_x_(0U), num_qo_tiles_(0U), forward_started_(false), + enable_cuda_graph_(enable_cuda_graph), stream_(nullptr) { cudaMallocHost(&page_locked_buffer_, max_workspace_size_in_bytes); } @@ -568,7 +571,7 @@ class BatchPrefillHandler { cudaFreeHost(page_locked_buffer_); } - private: + protected: void* page_locked_buffer_; void* request_indices_; void* tile_indices_; @@ -576,6 +579,8 @@ class BatchPrefillHandler { uint32_t num_qo_tiles_; bool forward_started_; cudaStream_t stream_; + bool enable_cuda_graph_; + static constexpr uint32_t max_num_qo_tiles_ = 1024 * 1024; }; } // namespace flashinfer diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 293ed1188e..cce0773d67 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -35,9 +35,7 @@ #include "../pos_enc.cuh" #include "../utils.cuh" #include "cascade.cuh" -#include "handler.cuh" #include "mask.cuh" -#include "state.cuh" namespace flashinfer { diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 660a90a32c..f29d47f4bf 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -679,7 +679,6 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token auto& temp_storage = reinterpret_cast< SamplingTempStorage&>(smem); - bool rejected = false; uint32_t pos = 0; for (pos = 0; pos < num_speculative_tokens; ++pos) { IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos]; diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index fa9275a988..253542b851 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -141,32 +141,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - if (handler_->IsCUDAGraphMode()) { - // NOTE(Zihao): use runtime dispatch because template function is not virtual - auto cuda_graph_handler_ = - dynamic_cast(handler_.get()); - cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, - c_type, nv_half, int32_t>(static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), - batch_size, num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ", - cudaGetErrorString(status)); - } else { - cudaError_t status = handler_->BeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, - c_type, nv_half, int32_t>(static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), - batch_size, num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - } + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); return true; }); }); @@ -180,32 +165,17 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_pos_encoding_mode( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - if (handler_->IsCUDAGraphMode()) { - // NOTE(Zihao): use runtime dispatch because template function is not virtual - auto cuda_graph_handler_ = - dynamic_cast(handler_.get()); - auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, - c_type, c_type, int32_t>(static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), - batch_size, num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ", - cudaGetErrorString(status)); - } else { - cudaError_t status = handler_->BeginForwardDispatched< - GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE, - c_type, c_type, int32_t>(static_cast(workspace_buffer.data_ptr()), - workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), - batch_size, num_qo_heads, page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - } + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); return true; }); }); diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu index 2653d913fa..b088d07e26 100644 --- a/python/csrc/flashinfer_ops.cu +++ b/python/csrc/flashinfer_ops.cu @@ -44,34 +44,30 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); py::class_(m, "BatchDecodeWithPagedKVCachePyTorchWrapper") - .def(py::init()) + .def(py::init()) .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) .def("update_page_locked_buffer_size", &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward); - py::class_( - m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper") - .def(py::init()) - .def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward) - .def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward) - .def("update_page_locked_buffer_size", - &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) - .def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward); py::class_( m, "BatchPrefillWithPagedKVCachePyTorchWrapper") - .def(py::init()) + .def(py::init()) .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled) .def("update_page_locked_buffer_size", &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward) .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask); py::class_( m, "BatchPrefillWithRaggedKVCachePyTorchWrapper") - .def(py::init()) + .def(py::init()) .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward) .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward) + .def("is_cuda_graph_enabled", + &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled) .def("update_page_locked_buffer_size", &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize) .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward) diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 1dff6ba345..a42cc1282c 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -20,11 +20,6 @@ #include #include -// namespace flashinfer { -// class BatchPrefillHandler; -// class BatchDecodeHandler; -// } // namespace flashinfer - torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, unsigned int pos_encoding_mode, unsigned int layout, float sm_scale, float rope_scale, @@ -84,6 +79,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, torch::Tensor empty_data); void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, @@ -93,31 +89,24 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { std::shared_ptr handler_ptr, flashinfer::QKVLayout kv_layout) : handler_(handler_ptr), kv_layout_(kv_layout) {} BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, - unsigned int max_workspace_size_in_bytes) + unsigned int max_workspace_size_in_bytes, + unsigned int max_batch_size, bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(max_workspace_size_in_bytes)) {} + handler_(std::make_shared( + max_workspace_size_in_bytes, max_batch_size, enable_cuda_graph)) {} protected: std::shared_ptr handler_; flashinfer::QKVLayout kv_layout_; }; -class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper - : public BatchDecodeWithPagedKVCachePyTorchWrapper { - public: - CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, - unsigned int max_batch_size) - : BatchDecodeWithPagedKVCachePyTorchWrapper( - std::make_shared(max_batch_size), - flashinfer::QKVLayout(layout)) {} -}; - class BatchPrefillWithPagedKVCachePyTorchWrapper { public: void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim); void EndForward(); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, @@ -133,9 +122,11 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout, - unsigned int max_workspace_size_in_bytes) + unsigned int max_workspace_size_in_bytes, + bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(max_workspace_size_in_bytes)) {} + handler_(std::make_shared(max_workspace_size_in_bytes, + enable_cuda_graph)) {} private: std::shared_ptr handler_; @@ -148,6 +139,7 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim); void EndForward(); + bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, @@ -162,9 +154,11 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { bool allow_fp16_qk_reduction, float sm_scale, float rope_scale, float rope_theta, bool return_lse); BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout, - unsigned int max_workspace_size_in_bytes) + unsigned int max_workspace_size_in_bytes, + bool enable_cuda_graph) : kv_layout_(flashinfer::QKVLayout(layout)), - handler_(std::make_shared(max_workspace_size_in_bytes)) {} + handler_(std::make_shared(max_workspace_size_in_bytes, + enable_cuda_graph)) {} private: std::shared_ptr handler_; diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 16927ae9f0..5e5aa11527 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -437,28 +437,84 @@ class BatchDecodeWithPagedKVCacheWrapper: manages the lifecycle of these data structures. """ - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + def __init__( + self, + workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + enable_cuda_graph: bool = False, + paged_kv_indptr_buffer: Optional[torch.Tensor] = None, + paged_kv_indices_buffer: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + ): r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. Parameters ---------- workspace_buffer : torch.Tensor The user reserved workspace buffer used to store auxiliary data structures, - recommended size is 16MB, the device of the workspace buffer should be the - same as the device of the input tensors. + recommended size is 16MB (128MB if cudagraph enabled), the device of the workspace + buffer should be the same as the device of the input tensors. + kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + + enable_cuda_graph : bool + Whether to enable CUDAGraph for batch decode attention, if enabled, the + auxiliary data structures will be stored in the provided buffers. + + indptr_buffer : Optional[torch.Tensor] + The user reserved buffer on GPU to store the indptr of the paged kv cache, should + be large enough to store the indptr of maximum batch size (``[max_batch_size + 1]``) + during the lifecycle of this wrapper. + Only needed when ``enable_cuda_graph`` is ``True``. + + indices_buffer : Optional[torch.Tensor] + The user reserved buffer on GPU to store the page indices of the paged kv cache, + should be large enough to store the maximum number of page indices + (``max_num_pages``) during the lifecycle of this wrapper. + Only needed when ``enable_cuda_graph`` is ``True``. + + last_page_len_buffer : Optional[torch.Tensor] + The user reserved buffer on GPU to store the number of entries in the last page, + should be large enough to store the maximum batch size (``[max_batch_size]``) + during the lifecycle of this wrapper. + Only needed when ``enable_cuda_graph`` is ``True``. """ check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer + # NOTE(Zihao): max_batch_size will only be used in cudagraph mode + max_batch_size = len(paged_kv_last_page_len_buffer) if enable_cuda_graph else 0 self._wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, workspace_buffer.numel() * workspace_buffer.element_size(), + max_batch_size, + enable_cuda_graph, ) - self._paged_kv_indptr = None - self._paged_kv_indices = None - self._paged_kv_last_page_len = None + if enable_cuda_graph: + if not torch.is_tensor(paged_kv_indptr_buffer): + raise ValueError( + "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" + ) + if not torch.is_tensor(paged_kv_indices_buffer): + raise ValueError( + "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" + ) + if not torch.is_tensor(paged_kv_last_page_len_buffer): + raise ValueError( + "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" + ) + self._paged_kv_indptr_buf = paged_kv_indptr_buffer + self._paged_kv_indices_buf = paged_kv_indices_buffer + self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer + else: + self._paged_kv_indptr_buf = None + self._paged_kv_indices_buf = None + self._paged_kv_last_page_len_buf = None + + @property + def is_cuda_graph_enabled(self): + return self._wrapper.is_cuda_graph_enabled() def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -519,9 +575,14 @@ def begin_forward( is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - self._paged_kv_indptr = indptr - self._paged_kv_indices = indices - self._paged_kv_last_page_len = last_page_len + if self.is_cuda_graph_enabled: + self._paged_kv_indptr_buf[: len(indptr)] = indptr + self._paged_kv_indices_buf[: len(indices)] = indices + self._paged_kv_last_page_len_buf[: len(last_page_len)] = last_page_len + else: + self._paged_kv_indptr_buf = indptr + self._paged_kv_indices_buf = indices + self._paged_kv_last_page_len_buf = last_page_len batch_size = len(indptr) - 1 # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info @@ -546,9 +607,10 @@ def begin_forward( def end_forward(self): r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" - self._paged_kv_indptr = None - self._paged_kv_indices = None - self._paged_kv_last_page_len = None + if not self.is_cuda_graph_enabled: + self._paged_kv_indptr_buf = None + self._paged_kv_indices_buf = None + self._paged_kv_last_page_len_buf = None self._wrapper.end_forward() def forward( @@ -614,9 +676,9 @@ def forward( out = self._wrapper.forward( q, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, @@ -697,9 +759,9 @@ def forward_return_lse( V, s = self._wrapper.forward( q, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, PosEncodingMode[pos_encoding_mode].value, sm_scale, rope_scale, @@ -711,7 +773,7 @@ def forward_return_lse( return V, s -class CUDAGraphBatchDecodeWithPagedKVCacheWrapper: +class CUDAGraphBatchDecodeWithPagedKVCacheWrapper(BatchDecodeWithPagedKVCacheWrapper): r"""CUDAGraph-compatible Wrapper class for decode attention with paged kv-cache (first proposed in `vLLM `_) for batch of requests. @@ -720,7 +782,6 @@ class CUDAGraphBatchDecodeWithPagedKVCacheWrapper: to accomodate the CUDAGraph requirement. Check :ref:`our tutorial` for page table layout. - # TODO(Zihao): update documentation Note ---- @@ -747,279 +808,30 @@ def __init__( The user reserved workspace buffer on GPU used to store auxiliary data structures, recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. + + kv_layout : str + The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + indptr_buffer : torch.Tensor The user reserved buffer on GPU to store the indptr of the paged kv cache, should be large enough to store the indptr of maximum batch size (``[max_batch_size + 1]``) during the lifecycle of this wrapper. + indices_buffer : torch.Tensor The user reserved buffer on GPU to store the page indices of the paged kv cache, should be large enough to store the maximum number of page indices (``max_num_pages``) during the lifecycle of this wrapper. + last_page_len_buffer : torch.Tensor The user reserved buffer on GPU to store the number of entries in the last page, should be large enough to store the maximum batch size (``[max_batch_size]``) during the lifecycle of this wrapper. - kv_layout : str - The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. - """ - check_kv_layout(kv_layout) - self._kv_layout = kv_layout - self._workspace_buffer = workspace_buffer - max_batch_size = len(last_page_len_buffer) - self._wrapper = _kernels.CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value, - max_batch_size, - ) - self._paged_kv_indptr_buf = indptr_buffer - self._paged_kv_indices_buf = indices_buffer - self._paged_kv_last_page_len_buf = last_page_len_buffer - - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): - r"""Reset the workspace buffer. - - Parameters - ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should - be the same as the device of the input tensors. - """ - self._workspace_buffer = new_workspace_buffer - - def begin_forward( - self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, - num_qo_heads: int, - num_kv_heads: int, - head_dim: int, - page_size: int, - pos_encoding_mode: str = "NONE", - data_type: Union[str, torch.dtype] = "float16", - ): - r"""Create auxiliary data structures for batch decode for multiple forward calls - within the same decode step. - - Parameters - ---------- - indptr : torch.Tensor - The indptr of the paged kv cache, shape: ``[batch_size + 1]`` - indices_host : torch.Tensor - The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` - last_page_len : torch.Tensor - The number of entries in the last page of each request in the paged kv - cache, shape: ``[batch_size]`` - num_qo_heads : int - The number of query/output heads - num_kv_heads : int - The number of key/value heads - head_dim : int - The dimension of the heads - page_size : int - The page size of the paged kv cache - pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - data_type : Union[str, torch.dtype] - The data type of the paged kv cache - - Note - ---- - The :meth:`begin_forward` method should be called before any :meth:`forward` or - :meth:`forward_return_lse` calls, auxiliary data structures will be created - during this call and cached for multiple forward calls. - - The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` - is not equal to ``num_kv_heads``, the function will use - `grouped query attention `_. """ - - self._paged_kv_indptr_buf[: len(indptr)] = indptr - self._paged_kv_indices_buf[: len(indices)] = indices - self._paged_kv_last_page_len_buf[: len(last_page_len)] = last_page_len - - batch_size = len(indptr) - 1 - # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info - empty_data = torch.empty( - 0, - dtype=( - getattr(torch, data_type) if isinstance(data_type, str) else data_type - ), - ) - self._wrapper.begin_forward( - self._workspace_buffer, - indptr, - last_page_len, - batch_size, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - PosEncodingMode[pos_encoding_mode].value, - empty_data, - ) - - def end_forward(self): - r"""Clear auxiliary data structures created by :meth:`begin_forward`.""" - self._wrapper.end_forward() - - def forward( - self, - q: torch.Tensor, - paged_kv_data: torch.Tensor, - pos_encoding_mode: str = "NONE", - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, - ): - r"""Compute batch decode attention between query and paged kv cache. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. - pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - q_scale : Optional[float] - The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to - ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. - - Returns - ------- - torch.Tensor - The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. - """ - check_pos_encoding_mode(pos_encoding_mode) - if sm_scale is None: - head_dim = q.shape[-1] - sm_scale = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - out = self._wrapper.forward( - q, - paged_kv_data, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - PosEncodingMode[pos_encoding_mode].value, - sm_scale, - rope_scale, - rope_theta, - False, - )[0] - if v_scale is not None: - out *= v_scale - return out - - def forward_return_lse( - self, - q: torch.Tensor, - paged_kv_data: torch.Tensor, - pos_encoding_mode: str = "NONE", - q_scale: Optional[float] = None, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - sm_scale: Optional[float] = None, - rope_scale: Optional[float] = None, - rope_theta: Optional[float] = None, - ): - r"""Compute batch decode attention with paged kv cache, return attention output - and logsumexp of attention scores. - - Parameters - ---------- - q : torch.Tensor - The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. - pos_encoding_mode : str - Whether to apply RoPE on-the-fly inside attention kernels, could be - ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - q_scale : Optional[float] - The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``. - k_scale : Optional[float] - The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. - v_scale : Optional[float] - The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. - sm_scale : Optional[float] - The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. - rope_scale : Optional[float] - The scale used in RoPE interpolation, if not provided, will be set to - ``1.0``. - rope_theta : Optional[float] - The theta used in RoPE, if not provided, will be set to ``1e4``. - - Returns - ------- - V : torch.Tensor - The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. - S : torch.Tensor - The logsumexp of attention scores, Shape: ``[batch_size, num_qo_heads]``. - - Notes - ----- - Please refer to the :ref:`tutorial ` for a detailed - explanation of the log-sum-exp function and attention states. - """ - check_pos_encoding_mode(pos_encoding_mode) - if sm_scale is None: - head_dim = q.shape[-1] - sm_scale = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale *= q_scale - if k_scale is not None: - sm_scale *= k_scale - if rope_scale is None: - rope_scale = 1.0 - if rope_theta is None: - rope_theta = 1e4 - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - V, s = self._wrapper.forward( - q, - paged_kv_data, - self._paged_kv_indptr_buf, - self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - self._batch_size, - self._nnz_pages, - PosEncodingMode[pos_encoding_mode].value, - sm_scale, - rope_scale, - rope_theta, + super().__init__( + workspace_buffer, + kv_layout, True, + indptr_buffer, + indices_buffer, + last_page_len_buffer, ) - if v_scale is not None: - V *= v_scale - return V, s diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index fb2f5db14e..897a907c5a 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -473,7 +473,18 @@ class BatchPrefillWithPagedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + def __init__( + self, + workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + enable_cuda_graph: bool = False, + qo_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indices_buf: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, + custom_mask_buf: Optional[torch.Tensor] = None, + qk_indptr_buf: Optional[torch.Tensor] = None, + ): r"""Constructor of :class:`BatchDecodeWithPagedKVCacheWrapper`. Parameters @@ -482,8 +493,49 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): The user reserved workspace buffer used to store auxiliary data structures, recommended size is 16MB, the device of the workspace buffer should be the same as the device of the input tensors. + kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + + enable_cuda_graph : bool + Whether to enable CUDA graph capture for the prefill kernels, if enabled, the + auxiliary data structures will be stored in provided buffers. + + qo_indptr_buf : Optional[torch.Tensor] + The user reserved buffer to store the ``qo_indptr`` array, should be large + enough to store the maximum possible size of the ``qo_indptr`` array during the + lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is set to ``True``. + + paged_kv_indptr_buf : Optional[torch.Tensor] + The user reserved buffer to store the ``paged_kv_indptr`` array, should be large + enough to store the maximum possible size of the ``paged_kv_indptr`` array during + the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is set to ``True``. + + paged_kv_indices_buf : Optional[torch.Tensor] + The user reserved buffer to store the ``paged_kv_indices`` array, should be large + enough to store the maximum possible size of the ``paged_kv_indices`` array during + the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is set to ``True``. + + paged_kv_last_page_len_buf : Optional[torch.Tensor] + The user reserved buffer to store the ``paged_kv_last_page_len`` array, should be + large enough to store the maximum possible size of the ``paged_kv_last_page_len`` array + during the lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is set to ``True``. + + custom_mask_buf : Optional[torch.Tensor] + The user reserved buffer to store the custom mask tensor, should be large enough to + store the maximum possible size of the custom mask tensor during the lifetime of the + wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True`` + and the custom mask will be used in attention computation. + + qk_indptr_buf : Optional[torch.Tensor] + The user reserved buffer to store the ``qk_indptr`` array, should be large enough to + store the maximum possible size of the ``qk_indptr`` array during the lifetime of the + wrapper. This argument is only effective when ``enable_cuda_graph`` is set to ``True`` + and the custom mask will be used in attention computation. """ check_kv_layout(kv_layout) self._kv_layout = kv_layout @@ -491,13 +543,37 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, workspace_buffer.numel() * workspace_buffer.element_size(), + enable_cuda_graph, ) - self._qo_indptr = None - self._paged_kv_indptr = None - self._paged_kv_indices = None - self._paged_kv_last_page_len = None - self._custom_mask = None - self._qk_indptr = None + if enable_cuda_graph: + if not torch.is_tensor(qo_indptr_buf): + raise ValueError( + "qo_indptr_buf should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_indptr_buf): + raise ValueError( + "paged_kv_indptr_buf should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_indices_buf): + raise ValueError( + "paged_kv_indices_buf should be a torch.Tensor in CUDA graph mode" + ) + if not torch.is_tensor(paged_kv_last_page_len_buf): + raise ValueError( + "paged_kv_last_page_len_buf should be a torch.Tensor in CUDA graph mode" + ) + # NOTE(Zihao): do not check custom_mask_buf and qk_indptr_buf here, as they are optional + + self._qo_indptr_buf = qo_indptr_buf + self._paged_kv_indptr_buf = paged_kv_indptr_buf + self._paged_kv_indices_buf = paged_kv_indices_buf + self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf + self._custom_mask_buf = custom_mask_buf + self._qk_indptr_buf = qk_indptr_buf + + @property + def is_cuda_graph_enabled(self): + return self._wrapper.is_cuda_graph_enabled() def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -563,18 +639,44 @@ def begin_forward( `grouped query attention `_. """ batch_size = len(qo_indptr) - 1 - self._qo_indptr = qo_indptr - self._paged_kv_indptr = paged_kv_indptr - self._paged_kv_indices = paged_kv_indices - self._paged_kv_last_page_len = paged_kv_last_page_len - if custom_mask is not None: - self._custom_mask = custom_mask - self._qk_indptr = _compute_page_qk_indptr( - qo_indptr, - paged_kv_indptr, - paged_kv_last_page_len, - page_size, + if self.is_cuda_graph_enabled: + self._qo_indptr_buf[: len(qo_indptr)] = qo_indptr + self._paged_kv_indptr_buf[: len(paged_kv_indptr)] = paged_kv_indptr + self._paged_kv_indices_buf[: len(paged_kv_indices)] = paged_kv_indices + self._paged_kv_last_page_len_buf[: len(paged_kv_last_page_len)] = ( + paged_kv_last_page_len ) + + if custom_mask is not None: + if not torch.is_tensor(self._custom_mask_buf): + raise ValueError( + "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." + ) + if not torch.is_tensor(self._qk_indptr_buf): + raise ValueError( + "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." + ) + self._custom_mask_buf[: len(custom_mask)] = custom_mask + # NOTE(Zihao): qk_indptr has the same length as qo_indptr + self._qk_indptr_buf[: len(qo_indptr)] = _compute_page_qk_indptr( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + page_size, + ) + else: + self._qo_indptr_buf = qo_indptr + self._paged_kv_indptr_buf = paged_kv_indptr + self._paged_kv_indices_buf = paged_kv_indices + self._paged_kv_last_page_len_buf = paged_kv_last_page_len + if custom_mask is not None: + self._custom_mask = custom_mask + self._qk_indptr = _compute_page_qk_indptr( + qo_indptr, + paged_kv_indptr, + paged_kv_last_page_len, + page_size, + ) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -586,12 +688,13 @@ def begin_forward( def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" - self._qo_indptr = None - self._paged_kv_indptr = None - self._paged_kv_indices = None - self._paged_kv_last_page_len = None - self._custom_mask = None - self._qk_indptr = None + if not self.is_cuda_graph_enabled: + self._qo_indptr = None + self._paged_kv_indptr = None + self._paged_kv_indices = None + self._paged_kv_last_page_len = None + self._custom_mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -657,14 +760,14 @@ def forward( paged_kv_data = paged_kv_data.to(torch.float16) paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - if self._custom_mask is None: + if self._custom_mask_buf is None: return self._wrapper.forward( q, - self._qo_indptr, + self._qo_indptr_buf, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, @@ -676,13 +779,13 @@ def forward( else: return self._wrapper.forward_custom_mask( q, - self._qo_indptr, + self._qo_indptr_buf, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - self._custom_mask, - self._qk_indptr, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + self._custom_mask_buf, + self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, @@ -758,11 +861,11 @@ def forward_return_lse( if self._custom_mask is None: return self._wrapper.forward( q, - self._qo_indptr, + self._qo_indptr_buf, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, @@ -774,13 +877,13 @@ def forward_return_lse( else: return self._wrapper.forward( q, - self._qo_indptr, + self._qo_indptr_buf, paged_kv_data, - self._paged_kv_indptr, - self._paged_kv_indices, - self._paged_kv_last_page_len, - self._custom_mask, - self._qk_indptr, + self._paged_kv_indptr_buf, + self._paged_kv_indices_buf, + self._paged_kv_last_page_len_buf, + self._custom_mask_buf, + self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, @@ -841,7 +944,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: ... ) >>> outputs = [] >>> for i in range(num_layers): - ... q = q_at_layer[i] + ... q = q_at_layer[i] ... k = k_at_layer[i] ... v = v_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures @@ -886,7 +989,7 @@ class BatchPrefillWithRaggedKVCacheWrapper: ... >>> # clear auxiliary data structures >>> prefill_wrapper.end_forward() - + Note ---- @@ -896,7 +999,16 @@ class BatchPrefillWithRaggedKVCacheWrapper: wrapper class manages the lifecycle of these data structures. """ - def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): + def __init__( + self, + workspace_buffer: torch.Tensor, + kv_layout: str = "NHD", + enable_cuda_graph: bool = False, + qo_indptr_buf: Optional[torch.Tensor] = None, + kv_indptr_buf: Optional[torch.Tensor] = None, + custom_mask_buf: Optional[torch.Tensor] = None, + qk_indptr_buf: Optional[torch.Tensor] = None, + ): r"""Constructor of :class:`BatchDecodeWithRaggedKVCacheWrapper`. Parameters @@ -905,8 +1017,38 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): The user reserved workspace buffer used to store auxiliary data structures, recommended size is 16MB, the device of the workspace buffer should be the same as the device of the input tensors. + kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. + + enable_cuda_graph : bool + Whether to enable CUDA graph capture for the prefill kernels, if enabled, the + auxiliary data structures will be stored in the provided buffers. + + qo_indptr_buf : Optional[torch.Tensor] + The user reserved GPU buffer to store the ``qo_indptr`` array, should be large + enough to store the maximum possible size of the ``qo_indptr`` array during the + lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is ``True``. + + kv_indptr_buf : Optional[torch.Tensor] + The user reserved GPU buffer to store the ``kv_indptr`` array, should be large + enough to store the maximum possible size of the ``kv_indptr`` array during the + lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is ``True``. + + custom_mask_buf : Optional[torch.Tensor] + The user reserved GPU buffer to store the custom mask tensor, should be large + enough to store the maximum possible size of the custom mask tensor during the + lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is ``True`` and custom mask will be used in attention computation. + + qk_indptr_buf : Optional[torch.Tensor] + The user reserved GPU buffer to store the ``qk_indptr`` array, should be large + enough to store the maximum possible size of the ``qk_indptr`` array during the + lifetime of the wrapper. This argument is only effective when ``enable_cuda_graph`` + is ``True`` and custom mask will be used in attention computation. + """ check_kv_layout(kv_layout) self._kv_layout = kv_layout @@ -914,11 +1056,28 @@ def __init__(self, workspace_buffer: torch.Tensor, kv_layout: str = "NHD"): self._wrapper = _kernels.BatchPrefillWithRaggedKVCachePyTorchWrapper( TensorLayout[kv_layout].value, workspace_buffer.numel() * workspace_buffer.element_size(), + enable_cuda_graph, ) - self._qo_indptr = None - self._kv_indptr = None - self._custom_mask = None - self._qk_indptr = None + if enable_cuda_graph: + if not torch.is_tensor(qo_indptr_buf): + raise ValueError( + "qo_indptr_buf should be a torch.Tensor in cuda graph mode" + ) + if not torch.is_tensor(kv_indptr_buf): + raise ValueError( + "kv_indptr_buf should be a torch.Tensor in cuda graph mode" + ) + # NOTE(Zihao): do not check custom_mask_buf and qk_indptr_buf here, + # as they may not be used. + + self._qo_indptr_buf = qo_indptr_buf + self._kv_indptr_buf = kv_indptr_buf + self._custom_mask_buf = custom_mask_buf + self._qk_indptr_buf = qk_indptr_buf + + @property + def is_cuda_graph_enabled(self): + return self._wrapper.is_cuda_graph_enabled() def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor): r"""Reset the workspace buffer. @@ -974,11 +1133,28 @@ def begin_forward( `grouped query attention `_. """ batch_size = len(qo_indptr) - 1 - self._qo_indptr = qo_indptr - self._kv_indptr = kv_indptr - if custom_mask is not None: - self._custom_mask = custom_mask - self._qk_indptr = _compute_qk_indptr(qo_indptr, kv_indptr) + if self.is_cuda_graph_enabled: + self._qo_indptr_buf[: len(qo_indptr)] = qo_indptr + self._kv_indptr_buf[: len(kv_indptr)] = kv_indptr + if custom_mask is not None: + if not torch.is_tensor(self._custom_mask_buf): + raise ValueError( + "custom_mask_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in attention computation." + ) + if not torch.is_tensor(self._qk_indptr_buf): + raise ValueError( + "qk_indptr_buf must be initialized with a torch.Tensor in cuda graph mode if we use custom mask in the attention computation." + ) + self._custom_mask_buf[: len(custom_mask)] = custom_mask + self._qk_indptr_buf[: len(qo_indptr)] = _compute_qk_indptr( + qo_indptr, kv_indptr + ) + else: + self._qo_indptr_buf = qo_indptr + self._kv_indptr_buf = kv_indptr + if custom_mask is not None: + self._custom_mask_buf = custom_mask + self._qk_indptr_buf = _compute_qk_indptr(qo_indptr, kv_indptr) self._wrapper.begin_forward( self._workspace_buffer, qo_indptr, @@ -990,10 +1166,11 @@ def begin_forward( def end_forward(self): r"""Clear the auxiliary data structures created by :meth:`begin_forward`.""" - self._qo_indptr = None - self._kv_indptr = None - self._custom_mask = None - self._qk_indptr = None + if not self.is_cuda_graph_enabled: + self._qo_indptr = None + self._kv_indptr = None + self._custom_mask = None + self._qk_indptr = None self._wrapper.end_forward() def forward( @@ -1057,13 +1234,13 @@ def forward( q = q.to(torch.float16) k = k.to(torch.float16) v = v.to(torch.float16) - if self._custom_mask is None: + if self._custom_mask_buf is None: return self._wrapper.forward( q, - self._qo_indptr, + self._qo_indptr_buf, k, v, - self._kv_indptr, + self._kv_indptr_buf, causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, @@ -1075,12 +1252,12 @@ def forward( else: return self._wrapper.forward_custom_mask( q, - self._qo_indptr, + self._qo_indptr_buf, k, v, - self._kv_indptr, - self._custom_mask, - self._qk_indptr, + self._kv_indptr_buf, + self._custom_mask_buf, + self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, @@ -1155,10 +1332,10 @@ def forward_return_lse( if self._custom_mask is None: return self._wrapper.forward( q, - self._qo_indptr, + self._qo_indptr_buf, k, v, - self._kv_indptr, + self._kv_indptr_buf, causal, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, @@ -1170,12 +1347,12 @@ def forward_return_lse( else: return self._wrapper.forward_custom_mask( q, - self._qo_indptr, + self._qo_indptr_buf, k, v, - self._kv_indptr, - self._custom_mask, - self._qk_indptr, + self._kv_indptr_buf, + self._custom_mask_buf, + self._qk_indptr_buf, PosEncodingMode[pos_encoding_mode].value, allow_fp16_qk_reduction, sm_scale, diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 13bba6f1a1..78812bcd89 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -192,7 +192,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) - + wrapper.end_forward() + # replay wrapper.begin_forward( kv_indptr_host, kv_indices_host, diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index a7b0754ee9..a72704dc2c 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -31,6 +31,7 @@ @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +@pytest.mark.parametrize("enable_cuda_graph", [False, True]) def test_batch_prefill_with_paged_kv_cache( batch_size, kv_len, @@ -42,9 +43,10 @@ def test_batch_prefill_with_paged_kv_cache( causal, kv_layout, pos_encoding_mode, + enable_cuda_graph ): q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() - q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len + q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( @@ -54,41 +56,99 @@ def test_batch_prefill_with_paged_kv_cache( .to(0) .half() ) - kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq - kv_indices = torch.arange(0, total_num_pages).to(0).int() - kv_last_page_len = torch.full( + kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_cpu = torch.arange(0, total_num_pages).int() + kv_last_page_len_cpu = torch.full( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 - ).to(0) + ) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) - wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, kv_layout - ) - wrapper.begin_forward( - q_indptr, - kv_indptr, - kv_indices, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - ) - o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) + if not enable_cuda_graph: + q_indptr_gpu = q_indptr_cpu.to(0) + kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.begin_forward( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + o = wrapper.forward(q, kv_data, causal=causal, pos_encoding_mode=pos_encoding_mode) + else: + q_indptr_buffer = torch.empty(batch_size + 1).int().to(0) + kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) + kv_indices_buffer = torch.empty(total_num_pages).int().to(0) + kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, enable_cuda_graph=True, + qo_indptr_buf=q_indptr_buffer, + paged_kv_indptr_buf=kv_indptr_buffer, + paged_kv_indices_buf=kv_indices_buffer, + paged_kv_last_page_len_buf=kv_last_page_len_buffer + ) + q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len + kv_indptr_warmup = torch.arange(0, batch_size + 1).int() + kv_indices_warmup = torch.arange(0, batch_size).int() + kv_last_page_len_warmup = torch.full((batch_size,), page_size, dtype=torch.int32) + wrapper.begin_forward( + q_indptr_warmup, + kv_indptr_warmup, + kv_indices_warmup, + kv_last_page_len_warmup, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + # warmup + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + o = wrapper.forward( + q, kv_data, pos_encoding_mode=pos_encoding_mode + ) + torch.cuda.current_stream().wait_stream(s) + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + o = wrapper.forward(q, kv_data, pos_encoding_mode=pos_encoding_mode) + wrapper.end_forward() + + wrapper.begin_forward( + q_indptr_cpu, + kv_indptr_cpu, + kv_indices_cpu, + kv_last_page_len_cpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + + g.replay() for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] - qi = q[q_indptr[i] : q_indptr[i + 1]] + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] ki = torch.cat( [ - kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] + kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( - kv_data[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] + kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] + else kv_data[kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -97,13 +157,13 @@ def test_batch_prefill_with_paged_kv_cache( ) vi = torch.cat( [ - kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] + kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( - kv_data[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] + kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]] if kv_layout == "HND" - else kv_data[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] + else kv_data[kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), @@ -113,7 +173,7 @@ def test_batch_prefill_with_paged_kv_cache( o_ref_i = flashinfer.single_prefill_with_kv_cache( qi, ki, vi, causal=causal, pos_encoding_mode=pos_encoding_mode ) - o_i_np = o[q_indptr[i] : q_indptr[i + 1]].cpu().numpy() + o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() o_ref_i_np = o_ref_i.cpu().numpy() numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) @@ -311,13 +371,14 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache( - 12, 54, 37, 8, 8, 8, 128, True, "HND", "NONE" + 12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True ) test_batch_prefill_with_paged_kv_cache( - 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE" + 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False ) test_batch_prefill_with_paged_kv_cache_custom_mask( 12, 137, 137, 1, 8, 8, 128, "HND", "NONE" ) test_batch_prefill_with_ragged_kv_cache(12, 54, 37, 8, 8, 128, True, "NONE") test_batch_prefill_with_ragged_kv_cache_custom_mask(12, 137, 137, 8, 8, 128, "NONE") +