Commit 7adc8cf
bugfix: fix prefill/append kernel behavior for empty kv-cache. (#353)

The prefill kernels were buggy when some requests had an empty kv-cache; this PR fixes the issue.

1 parent d1d443a commit 7adc8cf
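For context, a request with an empty kv-cache shows up in the paged layout as two equal consecutive kv_indptr entries (the request owns zero pages) and a last_page_len of 0. A minimal, hypothetical three-request batch illustrating the condition (the field names mirror paged_kv_t; the batch itself is made up):

#include <cstdint>
#include <vector>

int main() {
  // Hypothetical batch: request 1 has an empty kv-cache.
  std::vector<int32_t> kv_indptr{0, 2, 2, 5};      // indptr[1] == indptr[2]: zero pages
  std::vector<int32_t> kv_last_page_len{8, 0, 3};  // 0 entries in the (nonexistent) last page
  // kv_len = (indptr[i+1] - indptr[i] - 1) * page_size + last_page_len[i]
  // underflows for request 1, which is what this commit guards against.
  return 0;
}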

File tree

3 files changed: +175 -11

include/flashinfer/attention/handler.cuh
include/flashinfer/attention/prefill.cuh
src/test_batch_prefill.cu

Diff for: include/flashinfer/attention/handler.cuh (+4 -2)

@@ -114,7 +114,8 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(

     new_batch_size = 0;
     for (uint32_t i = 0; i < batch_size; ++i) {
-      new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * ceil_div(kv_len_arr[i], low);
+      new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
+                        ceil_div(std::max(int(kv_len_arr[i]), 1), low);
     }
     return {low < max_kv_len, low, new_batch_size};
   }
@@ -571,7 +572,8 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz
   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    int64_t packed_qo_len = packed_qo_len_arr[request_idx], kv_len = kv_len_arr[request_idx];
+    int64_t packed_qo_len = packed_qo_len_arr[request_idx],
+            kv_len = std::max(int(kv_len_arr[request_idx]), 1);
    int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size),
            num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
    total_num_tiles_q += num_tiles_q;
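The std::max(..., 1) clamp matters because ceil_div(0, kv_chunk_size) is 0, so a request with an empty kv-cache would contribute zero kv tiles and be dropped from the split schedule entirely. A small sketch (assuming the usual ceil_div definition; the surrounding values are illustrative):

#include <algorithm>
#include <cstdint>
#include <iostream>

// The usual ceil_div idiom; flashinfer defines an equivalent helper.
int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  int64_t kv_len = 0, kv_chunk_size = 256;
  // Before the fix: an empty kv-cache yields zero kv tiles for the request.
  std::cout << ceil_div(kv_len, kv_chunk_size) << "\n";                        // 0
  // After the fix: clamping to 1 guarantees at least one (fully masked) tile.
  std::cout << ceil_div(std::max<int64_t>(kv_len, 1), kv_chunk_size) << "\n";  // 1
  return 0;
}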

Diff for: include/flashinfer/attention/prefill.cuh (+18 -4)

@@ -619,7 +619,7 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base,
                          reg_id % 2;
      const bool out_of_boundary =
          (mask_mode == MaskMode::kCausal
-              ? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
+              ? (kv_idx + qo_len > kv_len + q_idx || (partition_kv && kv_idx >= chunk_end))
              : kv_idx >= chunk_end);
      s_frag[fx][fz][reg_id] =
          (out_of_boundary ||
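The rearranged causal condition is algebraically the same test, but it avoids unsigned underflow: the indices here are uint32_t, and when kv_len + q_idx < qo_len (as with an empty kv-cache) the old right-hand side wraps around to a huge value, so nothing gets masked. A sketch of the failure mode:

#include <cstdint>
#include <iostream>

int main() {
  // Empty kv-cache: every position should be masked out under causal masking.
  uint32_t kv_idx = 0, kv_len = 0, q_idx = 0, qo_len = 4;
  // Old form: 0 + 0 - 4 wraps to 4294967292, so the test is never true.
  std::cout << std::boolalpha << (kv_idx > kv_len + q_idx - qo_len) << "\n";  // false (bug)
  // New form: both sides stay non-negative, so masking works.
  std::cout << (kv_idx + qo_len > kv_len + q_idx) << "\n";                    // true
  return 0;
}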
@@ -1503,9 +1503,11 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
   kv_tile_idx = kv_tile_indices[bx];
   constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16;
   const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx],
-                 kv_len = (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] - 1) *
-                              paged_kv.page_size +
-                          paged_kv.last_page_len[request_idx];
+                 kv_len = (paged_kv.indptr[request_idx + 1] != paged_kv.indptr[request_idx])
+                              ? (paged_kv.indptr[request_idx + 1] - paged_kv.indptr[request_idx] -
+                                 1) * paged_kv.page_size +
+                                    paged_kv.last_page_len[request_idx]
+                              : 0;
   const uint32_t chunk_size = partition_kv ? kv_chunk_size : kv_len;
   const uint32_t chunk_start = partition_kv ? kv_tile_idx * chunk_size : 0;
   const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * chunk_size, kv_len) : kv_len;
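The ternary guards the same kind of unsigned wrap-around: when a request owns no pages, indptr[request_idx + 1] == indptr[request_idx], and the old expression computed (0 - 1) * page_size + last_page_len in unsigned arithmetic. Illustrative values:

#include <cstdint>
#include <iostream>

int main() {
  // Request with zero pages: consecutive indptr entries are equal.
  uint32_t indptr_next = 5, indptr_cur = 5, page_size = 16, last_page_len = 0;
  // Old form: (0 - 1) wraps, yielding a bogus, enormous kv_len.
  uint32_t old_kv_len = (indptr_next - indptr_cur - 1) * page_size + last_page_len;
  std::cout << old_kv_len << "\n";  // 4294967280
  // New form: treat the empty request explicitly as kv_len = 0.
  uint32_t new_kv_len =
      (indptr_next != indptr_cur)
          ? (indptr_next - indptr_cur - 1) * page_size + last_page_len
          : 0;
  std::cout << new_kv_len << "\n";  // 0
  return 0;
}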
@@ -1908,6 +1910,12 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
   const uint_fastdiv group_size_fastdiv(group_size);

+  if (padded_batch_size == 0) {
+    // No request, skip
+    // this won't happen in CUDAGraph mode because we fixed the padded_batch_size
+    return cudaSuccess;
+  }
+
   dim3 nblks(padded_batch_size, 1, num_kv_heads);
   dim3 nthrs(32, num_warps_x, num_warps_z);
   constexpr uint32_t num_frags_y = HEAD_DIM / 16;
@@ -2040,6 +2048,12 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
   const uint32_t group_size = num_qo_heads / num_kv_heads;
   const uint_fastdiv group_size_fastdiv(group_size);

+  if (padded_batch_size == 0) {
+    // No request, skip
+    // this won't happen in CUDAGraph mode because we fixed the padded_batch_size
+    return cudaSuccess;
+  }
+
   dim3 nblks(padded_batch_size, 1, num_kv_heads);
   dim3 nthrs(32, num_warps_x, num_warps_z);
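The early return is needed because CUDA rejects a kernel launch whose grid has a zero dimension; without it, dim3 nblks(0, 1, num_kv_heads) would make the launch fail rather than no-op. A quick standalone demonstration (not from the repo):

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop() {}

int main() {
  // Launching with zero blocks is an invalid configuration, not a no-op.
  noop<<<dim3(0, 1, 1), dim3(32, 1, 1)>>>();
  cudaError_t err = cudaGetLastError();
  printf("%s\n", cudaGetErrorString(err));  // e.g. "invalid configuration argument"
  return 0;
}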

Diff for: src/test_batch_prefill.cu (+153 -5)
@@ -18,6 +18,7 @@
 #include <cstdint>

 #include "cpu_reference.h"
+#include "flashinfer/pos_enc.cuh"
 #include "flashinfer_ops.cuh"
 #include "utils.h"

@@ -237,12 +238,13 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   std::vector<int32_t> q_lens(batch_size);
   utils::vec_randint_(q_lens, 1, 64);
   std::vector<int32_t> kv_lens(q_lens);
+
   std::vector<int32_t> q_indptr{0};
-  for (uint32_t i = 0; i < batch_size; ++i) {
-    q_indptr.push_back(q_indptr.back() + q_lens[i]);
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
   }
   std::vector<int32_t> append_indptr{0};
-  for (size_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
     append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
   }
   std::vector<T> kv_data;
@@ -295,7 +297,6 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
     q.push_back(qi);
   }
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    // create one-hot queries
     int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
     std::vector<T> o_ref_i = cpu_reference::single_mha<T, T, T>(
         q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
@@ -318,7 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   thrust::device_vector<char> buffer(workspace_size_in_bytes);

   handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
-                                   workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
+                                   workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
                                    batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

   auto status =
@@ -350,6 +351,128 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
   EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
 }

+template <typename T>
+void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
+    size_t batch_size, size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim,
+    bool allow_fp16_qk_reduction, uint32_t q_len_min, uint32_t q_len_max, uint32_t kv_len_min,
+    uint32_t kv_len_max) {
+  std::vector<int32_t> q_lens(batch_size);
+  utils::vec_randint_(q_lens, q_len_min, q_len_max);
+  std::vector<int32_t> kv_lens(batch_size);
+  utils::vec_randint_(kv_lens, kv_len_min, kv_len_max);
+
+  std::vector<int32_t> q_indptr{0};
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
+  }
+  std::vector<int32_t> append_indptr{0};
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
+  }
+  std::vector<T> kv_data;
+  std::vector<int32_t> kv_indptr{0};
+  std::vector<int32_t> kv_indices;
+  std::vector<int32_t> kv_last_page_len;
+  size_t page_counter = 0;
+  std::vector<std::vector<T>> key, value;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    size_t kv_len = kv_lens[request_idx];
+    size_t num_pages = (kv_len + page_size - 1) / page_size;
+    size_t last_page_len = num_pages == 0 ? 0 : (kv_len - 1) % page_size + 1;
+    std::vector<T> k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim);
+    utils::vec_normal_(k);
+    utils::vec_normal_(v);
+    key.push_back(k);
+    value.push_back(v);
+    kv_last_page_len.push_back(last_page_len);
+    kv_indptr.push_back(kv_indptr.back() + num_pages);
+    for (size_t j = 0; j < num_pages; ++j) {
+      kv_indices.push_back(page_counter++);
+    }
+  }
+
+  kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim);
+  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
+      num_kv_heads, page_size, head_dim, batch_size, kv_data.data(), kv_indices.data(),
+      kv_indptr.data(), kv_last_page_len.data());
+  cpu_reference::append_paged_kv_cache<kv_layout, T, int32_t>(paged_kv_cpu, key, value,
+                                                              append_indptr);
+
+  // copy data to device
+  thrust::device_vector<T> kv_data_device(kv_data);
+  thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
+  thrust::device_vector<int32_t> kv_indices_device(kv_indices);
+  thrust::device_vector<int32_t> kv_last_page_len_device(kv_last_page_len);
+
+  // create paged_kv object
+  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv = paged_kv_cpu;
+  paged_kv.data = thrust::raw_pointer_cast(kv_data_device.data());
+  paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data());
+  paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data());
+  paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data());
+
+  std::vector<std::vector<T>> q, o_ref;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    int32_t q_len = q_lens[request_idx];
+    std::vector<T> qi(q_len * num_qo_heads * head_dim);
+    utils::vec_normal_(qi);
+    q.push_back(qi);
+  }
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
+    std::vector<T> o_ref_i = cpu_reference::single_mha<T, T, T>(
+        q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
+        num_kv_heads, head_dim, /*causal=*/false, QKVLayout::kNHD,
+        /*pos_encoding_mode*/ PosEncodingMode::kNone);
+    o_ref.push_back(o_ref_i);
+  }
+
+  std::vector<T> q_concat, o_concat_ref;
+  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
+    q_concat.insert(q_concat.end(), q[request_idx].begin(), q[request_idx].end());
+    o_concat_ref.insert(o_concat_ref.end(), o_ref[request_idx].begin(), o_ref[request_idx].end());
+  }
+  thrust::device_vector<T> q_device(q_concat);
+
+  thrust::device_vector<int32_t> q_indptr_device(q_indptr);
+  thrust::device_vector<T> o_device(o_concat_ref.size());
+
+  BatchPrefillHandler handler;
+  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
+  thrust::device_vector<char> buffer(workspace_size_in_bytes);
+
+  handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
+                                   workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
+                                   batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
+
+  auto status =
+      BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
+          &handler, thrust::raw_pointer_cast(q_device.data()),
+          thrust::raw_pointer_cast(q_indptr_device.data()),
+          /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()),
+          /*lse=*/nullptr, num_qo_heads, /*causal=*/false,
+          /*pos_encoding_mode*/ PosEncodingMode::kNone);
+  EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));
+
+  thrust::host_vector<T> o_host(o_device);
+  size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
+  bool nan_detected = false;
+  for (size_t i = 0; i < o_concat_ref.size(); ++i) {
+    if (std::isnan(float(o_host[i]))) {
+      nan_detected = true;
+    }
+    num_result_errors_atol_1e_3_rtol_1e_3 +=
+        (!utils::isclose(float(o_host[i]), float(o_concat_ref[i]), 1e-3, 1e-3));
+  }
+  float result_accuracy =
+      1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / max(float(o_concat_ref.size()), 1.f);
+  std::cout << "batch_size=" << batch_size << ", page_size=" << page_size
+            << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads
+            << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
+  EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
+  EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
+}
+
 template <typename T>
 void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, size_t num_qo_heads,
                                                         size_t page_size, size_t head_dim,
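Note how the new test computes last_page_len: for an empty kv-cache there are no pages, so it must be 0, while any non-empty cache gets (kv_len - 1) % page_size + 1. A quick check of the edge cases (values illustrative):

#include <cstddef>
#include <iostream>

int main() {
  size_t page_size = 16;
  for (size_t kv_len : {0, 1, 16, 17}) {
    size_t num_pages = (kv_len + page_size - 1) / page_size;
    // Guarded so (kv_len - 1) is never evaluated for kv_len == 0.
    size_t last_page_len = num_pages == 0 ? 0 : (kv_len - 1) % page_size + 1;
    std::cout << "kv_len=" << kv_len << " -> num_pages=" << num_pages
              << ", last_page_len=" << last_page_len << "\n";
  }
  return 0;
}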
@@ -505,6 +628,27 @@ void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduct
   }
 }

+template <typename T>
+void TestBatchPagedPrefillKernelZeroContextCorrectness(bool allow_fp16_qk_reduction) {
+  for (size_t batch_size : {1, 4, 7, 11, 19, 37, 99}) {
+    for (size_t num_kv_heads : {1, 4}) {
+      for (size_t group_size : {1, 8}) {
+        size_t num_qo_heads = num_kv_heads * group_size;
+        for (size_t page_size : {1, 16}) {
+          for (size_t head_dim : {64, 128, 256}) {
+            for (size_t kv_len_max : {0, 3}) {
+              _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness<T>(
+                  batch_size, num_kv_heads, num_qo_heads, page_size, head_dim,
+                  allow_fp16_qk_reduction,
+                  /*q_len_min=*/1, /*q_len_max=*/3, /*kv_len_min=*/0, kv_len_max);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
 template <typename T>
 void TestBatchRaggedPrefillKernelCorrectness(bool allow_fp16_qk_reduction) {
   for (size_t num_kv_heads : {4, 8, 32}) {
@@ -534,6 +678,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16) {
   TestBatchPagedPrefillKernelLongContextCorrectness<half>(false);
 }

+TEST(FlashInferCorrectnessTest, BatchPagedPrefillZeroContextTestFP16) {
+  TestBatchPagedPrefillKernelZeroContextCorrectness<half>(false);
+}
+
 TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum) {
   TestBatchPagedPrefillKernelLongContextCorrectness<half>(true);
 }
