
Commit 081a4c5

Revert "feat: support cuda graph for batched multi-query(prefill/append) attention" (#276)
Reverts #275
1 parent 83ceb67 commit 081a4c5


9 files changed: +538 -489 lines changed


include/flashinfer/attention/decode.cuh

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
 #include "../utils.cuh"
 #include "../vec_dtypes.cuh"
 #include "cascade.cuh"
+#include "handler.cuh"
 #include "state.cuh"
 
 namespace flashinfer {

include/flashinfer/attention/handler.cuh

Lines changed: 102 additions & 107 deletions
Large diffs are not rendered by default.

include/flashinfer/attention/prefill.cuh

Lines changed: 2 additions & 0 deletions
@@ -35,7 +35,9 @@
 #include "../pos_enc.cuh"
 #include "../utils.cuh"
 #include "cascade.cuh"
+#include "handler.cuh"
 #include "mask.cuh"
+#include "state.cuh"
 
 namespace flashinfer {

include/flashinfer/sampling.cuh

Lines changed: 1 addition & 0 deletions
@@ -679,6 +679,7 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token
   auto& temp_storage = reinterpret_cast<
       SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem);
 
+  bool rejected = false;
   uint32_t pos = 0;
   for (pos = 0; pos < num_speculative_tokens; ++pos) {
     IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + pos];
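
For context on the rejected flag initialized above: in standard chain speculative sampling, draft tokens are scanned in order, each is accepted with probability min(1, p_target / p_draft), and the first rejection ends the accepted prefix. The following host-side C++ sketch illustrates that acceptance loop only; it is not the FlashInfer kernel, and the probability inputs and RNG are simplified assumptions.

#include <algorithm>
#include <cstdint>
#include <random>
#include <vector>

// Host-side illustration of chain speculative sampling acceptance.
// draft_probs / target_probs: probability of each drafted token under the
// draft and target models. Returns how many draft tokens are kept.
uint32_t AcceptDraftTokens(const std::vector<float>& draft_probs,
                           const std::vector<float>& target_probs,
                           std::mt19937& rng) {
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  bool rejected = false;  // same role as the flag added in the kernel above
  uint32_t pos = 0;
  for (pos = 0; pos < draft_probs.size(); ++pos) {
    float accept_prob = std::min(1.0f, target_probs[pos] / draft_probs[pos]);
    if (uniform(rng) > accept_prob) {
      rejected = true;  // first rejection ends the accepted prefix
      break;
    }
  }
  // On rejection the caller resamples position `pos` from the residual
  // distribution; otherwise all draft tokens (plus a bonus token) are kept.
  return rejected ? pos : static_cast<uint32_t>(draft_probs.size());
}

int main() {
  std::mt19937 rng(0);
  std::vector<float> draft = {0.6f, 0.5f, 0.4f};
  std::vector<float> target = {0.5f, 0.6f, 0.1f};
  uint32_t accepted = AcceptDraftTokens(draft, target, rng);
  (void)accepted;
  return 0;
}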

python/csrc/batch_decode.cu

Lines changed: 52 additions & 22 deletions
@@ -141,17 +141,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
   return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
     return DISPATCH_pos_encoding_mode(
         PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
-          cudaError_t status =
-              handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
-                                               KV_LAYOUT, POS_ENCODING_MODE, c_type,
-                                               nv_half, int32_t>(
-                  static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-                  static_cast<int32_t*>(indptr.data_ptr()),
-                  static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
-                  page_size);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchDecodeWithPagedKVCache failed with error ",
-                      cudaGetErrorString(status));
+          if (handler_->IsCUDAGraphMode()) {
+            // NOTE(Zihao): use runtime dispatch because template function is not virtual
+            auto cuda_graph_handler_ =
+                dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
+            cudaError_t status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
+                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
+                c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
+                                          workspace_size_in_bytes,
+                                          static_cast<int32_t*>(indptr.data_ptr()),
+                                          static_cast<int32_t*>(last_page_len.data_ptr()),
+                                          batch_size, num_qo_heads, page_size);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
+                        cudaGetErrorString(status));
+          } else {
+            cudaError_t status = handler_->BeginForwardDispatched<
+                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
+                c_type, nv_half, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
+                                          workspace_size_in_bytes,
+                                          static_cast<int32_t*>(indptr.data_ptr()),
+                                          static_cast<int32_t*>(last_page_len.data_ptr()),
+                                          batch_size, num_qo_heads, page_size);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchDecodeWithPagedKVCache failed with error ",
+                        cudaGetErrorString(status));
+          }
           return true;
         });
   });
@@ -165,17 +180,32 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(
   return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] {
     return DISPATCH_pos_encoding_mode(
         PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
-          cudaError_t status =
-              handler_->BeginForwardDispatched<GROUP_SIZE, HEAD_DIM, PageStorage::kIndices,
-                                               KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type,
-                                               int32_t>(
-                  static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
-                  static_cast<int32_t*>(indptr.data_ptr()),
-                  static_cast<int32_t*>(last_page_len.data_ptr()), batch_size, num_qo_heads,
-                  page_size);
-          TORCH_CHECK(status == cudaSuccess,
-                      "BatchDecodeWithPagedKVCache failed with error ",
-                      cudaGetErrorString(status));
+          if (handler_->IsCUDAGraphMode()) {
+            // NOTE(Zihao): use runtime dispatch because template function is not virtual
+            auto cuda_graph_handler_ =
+                dynamic_cast<CUDAGraphBatchDecodeHandler*>(handler_.get());
+            auto status = cuda_graph_handler_->CUDAGraphBeginForwardDispatched<
+                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
+                c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
+                                         workspace_size_in_bytes,
+                                         static_cast<int32_t*>(indptr.data_ptr()),
+                                         static_cast<int32_t*>(last_page_len.data_ptr()),
+                                         batch_size, num_qo_heads, page_size);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchDecodeWithPagedKVCache (CUDAGraph Mode) failed with error ",
+                        cudaGetErrorString(status));
+          } else {
+            cudaError_t status = handler_->BeginForwardDispatched<
+                GROUP_SIZE, HEAD_DIM, PageStorage::kIndices, KV_LAYOUT, POS_ENCODING_MODE,
+                c_type, c_type, int32_t>(static_cast<void*>(workspace_buffer.data_ptr()),
+                                         workspace_size_in_bytes,
+                                         static_cast<int32_t*>(indptr.data_ptr()),
+                                         static_cast<int32_t*>(last_page_len.data_ptr()),
+                                         batch_size, num_qo_heads, page_size);
+            TORCH_CHECK(status == cudaSuccess,
+                        "BatchDecodeWithPagedKVCache failed with error ",
+                        cudaGetErrorString(status));
+          }
           return true;
         });
   });
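
The restored BeginForward branches on IsCUDAGraphMode() and downcasts the handler at runtime because a templated member function such as BeginForwardDispatched cannot be virtual. Below is a minimal, self-contained C++ sketch of that dispatch pattern; Handler and GraphHandler are hypothetical stand-ins for BatchDecodeHandler and CUDAGraphBatchDecodeHandler, and the head-dim template argument is illustrative only.

#include <cassert>
#include <cstdio>
#include <memory>

// Hypothetical stand-ins; only the dispatch shape mirrors the wrapper code above.
struct Handler {
  virtual ~Handler() = default;
  virtual bool IsCUDAGraphMode() const { return false; }
  template <int HEAD_DIM>  // templated members cannot be virtual
  int BeginForwardDispatched(int batch_size) {
    std::printf("plain path: head_dim=%d batch_size=%d\n", HEAD_DIM, batch_size);
    return 0;
  }
};

struct GraphHandler : Handler {
  bool IsCUDAGraphMode() const override { return true; }
  template <int HEAD_DIM>
  int CUDAGraphBeginForwardDispatched(int batch_size) {
    std::printf("graph path: head_dim=%d batch_size=%d\n", HEAD_DIM, batch_size);
    return 0;
  }
};

// Same shape as the restored BeginForward: branch on the mode flag, then
// downcast to reach the derived-only templated entry point.
int BeginForward(Handler* handler, int batch_size) {
  if (handler->IsCUDAGraphMode()) {
    auto* graph_handler = dynamic_cast<GraphHandler*>(handler);
    assert(graph_handler != nullptr);
    return graph_handler->CUDAGraphBeginForwardDispatched<128>(batch_size);
  }
  return handler->BeginForwardDispatched<128>(batch_size);
}

int main() {
  std::shared_ptr<Handler> plain = std::make_shared<Handler>();
  std::shared_ptr<Handler> graph = std::make_shared<GraphHandler>();
  BeginForward(plain.get(), 8);
  BeginForward(graph.get(), 8);
  return 0;
}

Note that dynamic_cast yields nullptr when the handler is not actually the graph variant, so the pattern relies on IsCUDAGraphMode() being consistent with the handler's dynamic type.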

python/csrc/flashinfer_ops.cu

Lines changed: 11 additions & 7 deletions
@@ -44,30 +44,34 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int, bool>())
+      .def(py::init<unsigned int, unsigned int>())
       .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
-      .def("is_cuda_graph_enabled", &BatchDecodeWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
           &BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
+  py::class_<CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper>(
+      m, "CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper")
+      .def(py::init<unsigned int, unsigned int>())
+      .def("begin_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
+      .def("end_forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
+      .def("update_page_locked_buffer_size",
+           &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
+      .def("forward", &CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
   py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
       m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int, bool>())
+      .def(py::init<unsigned int, unsigned int>())
       .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
-      .def("is_cuda_graph_enabled", &BatchPrefillWithPagedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
           &BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward)
      .def("forward_custom_mask", &BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask);
   py::class_<BatchPrefillWithRaggedKVCachePyTorchWrapper>(
       m, "BatchPrefillWithRaggedKVCachePyTorchWrapper")
-      .def(py::init<unsigned int, unsigned int, bool>())
+      .def(py::init<unsigned int, unsigned int>())
       .def("begin_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward)
-      .def("is_cuda_graph_enabled",
-           &BatchPrefillWithRaggedKVCachePyTorchWrapper::IsCUDAGraphEnabled)
       .def("update_page_locked_buffer_size",
           &BatchPrefillWithRaggedKVCachePyTorchWrapper::UpdatePageLockedBufferSize)
       .def("forward", &BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward)

python/csrc/flashinfer_ops.h

Lines changed: 17 additions & 15 deletions
@@ -84,7 +84,6 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
                     unsigned int pos_encoding_mode, torch::Tensor empty_data);
   void EndForward();
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
-  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor paged_kv_data,
                                      torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
                                      torch::Tensor paged_kv_last_page_len,
@@ -93,24 +92,32 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
   BatchDecodeWithPagedKVCachePyTorchWrapper(
       std::shared_ptr<flashinfer::BatchDecodeHandler> handler_ptr, flashinfer::QKVLayout kv_layout)
       : handler_(handler_ptr), kv_layout_(kv_layout) {}
-  BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout, unsigned int max_batch_size,
-                                            bool enable_cuda_graph)
+  BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
+                                            unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(
-            std::make_shared<flashinfer::BatchDecodeHandler>(max_batch_size, enable_cuda_graph)) {}
+        handler_(std::make_shared<flashinfer::BatchDecodeHandler>(max_workspace_size_in_bytes)) {}
 
  protected:
   std::shared_ptr<flashinfer::BatchDecodeHandler> handler_;
   flashinfer::QKVLayout kv_layout_;
 };
 
+class CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper
+    : public BatchDecodeWithPagedKVCachePyTorchWrapper {
+ public:
+  CUDAGraphBatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout,
+                                                     unsigned int max_batch_size)
+      : BatchDecodeWithPagedKVCachePyTorchWrapper(
+            std::make_shared<flashinfer::CUDAGraphBatchDecodeHandler>(max_batch_size),
+            flashinfer::QKVLayout(layout)) {}
+};
+
 class BatchPrefillWithPagedKVCachePyTorchWrapper {
  public:
   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
                     unsigned int head_dim);
   void EndForward();
-  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
@@ -126,11 +133,9 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper {
                     unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale,
                     float rope_scale, float rope_theta, bool return_lse);
   BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout,
-                                             unsigned int max_workspace_size_in_bytes,
-                                             bool enable_cuda_graph)
+                                             unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
-                                                                   enable_cuda_graph)) {}
+        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
 
  private:
   std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
@@ -143,7 +148,6 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
                     unsigned int head_dim);
   void EndForward();
-  bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
   void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);
   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k,
                                      torch::Tensor v, torch::Tensor kv_indptr, bool causal,
@@ -158,11 +162,9 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper {
                                      bool allow_fp16_qk_reduction, float sm_scale,
                                      float rope_scale, float rope_theta, bool return_lse);
   BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout,
-                                              unsigned int max_workspace_size_in_bytes,
-                                              bool enable_cuda_graph)
+                                              unsigned int max_workspace_size_in_bytes)
       : kv_layout_(flashinfer::QKVLayout(layout)),
-        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes,
-                                                                   enable_cuda_graph)) {}
+        handler_(std::make_shared<flashinfer::BatchPrefillHandler>(max_workspace_size_in_bytes)) {}
 
  private:
   std::shared_ptr<flashinfer::BatchPrefillHandler> handler_;
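
The restored header lets the CUDA-graph wrapper reuse the base wrapper by building a CUDAGraphBatchDecodeHandler and passing it to the constructor that accepts a ready-made handler pointer. A minimal sketch of that handler-injection pattern under hypothetical names (DecodeHandler, GraphDecodeHandler, DecodeWrapper, GraphDecodeWrapper):

#include <memory>

// Hypothetical stand-ins for BatchDecodeHandler and CUDAGraphBatchDecodeHandler.
struct DecodeHandler {
  explicit DecodeHandler(unsigned int /*workspace_bytes*/) {}
  virtual ~DecodeHandler() = default;
};
struct GraphDecodeHandler : DecodeHandler {
  explicit GraphDecodeHandler(unsigned int /*max_batch_size*/) : DecodeHandler(0) {}
};

class DecodeWrapper {
 public:
  // Convenience constructor: builds the default handler itself.
  DecodeWrapper(unsigned int layout, unsigned int workspace_bytes)
      : layout_(layout), handler_(std::make_shared<DecodeHandler>(workspace_bytes)) {}
  // Injection constructor: accepts an already-built handler so a subclass
  // can supply a specialized one while sharing the rest of the wrapper.
  DecodeWrapper(std::shared_ptr<DecodeHandler> handler, unsigned int layout)
      : layout_(layout), handler_(std::move(handler)) {}

 protected:
  unsigned int layout_;
  std::shared_ptr<DecodeHandler> handler_;
};

class GraphDecodeWrapper : public DecodeWrapper {
 public:
  GraphDecodeWrapper(unsigned int layout, unsigned int max_batch_size)
      : DecodeWrapper(std::make_shared<GraphDecodeHandler>(max_batch_size), layout) {}
};

int main() {
  DecodeWrapper plain(/*layout=*/0, /*workspace_bytes=*/128u << 20);
  GraphDecodeWrapper graph(/*layout=*/0, /*max_batch_size=*/64);
  return 0;
}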
