18
18
#include < cstdint>
19
19
20
20
#include " cpu_reference.h"
21
+ #include " flashinfer/pos_enc.cuh"
21
22
#include " flashinfer_ops.cuh"
22
23
#include " utils.h"
23
24
@@ -237,12 +238,13 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
237
238
std::vector<int32_t > q_lens (batch_size);
238
239
utils::vec_randint_ (q_lens, 1 , 64 );
239
240
std::vector<int32_t > kv_lens (q_lens);
241
+
240
242
std::vector<int32_t > q_indptr{0 };
241
- for (uint32_t i = 0 ; i < batch_size; ++i ) {
242
- q_indptr.push_back (q_indptr.back () + q_lens[i ]);
243
+ for (uint32_t request_idx = 0 ; request_idx < batch_size; ++request_idx ) {
244
+ q_indptr.push_back (q_indptr.back () + q_lens[request_idx ]);
243
245
}
244
246
std::vector<int32_t > append_indptr{0 };
245
- for (size_t request_idx = 0 ; request_idx < batch_size; ++request_idx) {
247
+ for (uint32_t request_idx = 0 ; request_idx < batch_size; ++request_idx) {
246
248
append_indptr.push_back (append_indptr.back () + kv_lens[request_idx]);
247
249
}
248
250
std::vector<T> kv_data;
@@ -295,7 +297,6 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
295
297
q.push_back (qi);
296
298
}
297
299
for (uint32_t request_idx = 0 ; request_idx < batch_size; ++request_idx) {
298
- // create one-hot queries
299
300
int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
300
301
std::vector<T> o_ref_i = cpu_reference::single_mha<T, T, T>(
301
302
q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
@@ -318,7 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
318
319
thrust::device_vector<char > buffer (workspace_size_in_bytes);
319
320
320
321
handler.BeginForward <T, int32_t >((void *)thrust::raw_pointer_cast (buffer.data ()),
321
- workspace_size_in_bytes, append_indptr .data (), kv_indptr.data (),
322
+ workspace_size_in_bytes, q_indptr .data (), kv_indptr.data (),
322
323
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
323
324
324
325
auto status =
@@ -350,6 +351,128 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
350
351
EXPECT_EQ (nan_detected, false ) << " NaN detected in output." ;
351
352
}
352
353
354
// Compares BatchPrefillWithPagedKVCacheWrapper against cpu_reference::single_mha on one randomly
// generated batch whose per-request query and KV lengths are drawn by utils::vec_randint_ with
// caller-supplied bounds.
//
// Parameters:
//   batch_size               number of requests in the batch.
//   num_kv_heads             number of key/value heads.
//   num_qo_heads             number of query/output heads.
//   page_size                number of KV entries per page of the paged cache.
//   head_dim                 per-head feature dimension.
//   allow_fp16_qk_reduction  forwarded to the prefill wrapper (see the note at the call site).
//   q_len_min, q_len_max     bounds handed to utils::vec_randint_ for query lengths.
//   kv_len_min, kv_len_max   bounds handed to utils::vec_randint_ for KV lengths.
//
// Attention is evaluated non-causally with PosEncodingMode::kNone on both the CPU and the GPU
// side. The test passes when more than 99% of the output elements are within atol = rtol = 1e-3
// of the reference and no NaN appears in the GPU output.
template <typename T>
void _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness(
    size_t batch_size, size_t num_kv_heads, size_t num_qo_heads, size_t page_size, size_t head_dim,
    bool allow_fp16_qk_reduction, uint32_t q_len_min, uint32_t q_len_max, uint32_t kv_len_min,
    uint32_t kv_len_max) {
  // Draw the per-request sequence lengths.
  std::vector<int32_t> q_lens(batch_size);
  utils::vec_randint_(q_lens, q_len_min, q_len_max);
  std::vector<int32_t> kv_lens(batch_size);
  utils::vec_randint_(kv_lens, kv_len_min, kv_len_max);

  // Prefix sums over the query lengths and over the KV lengths (both start at 0).
  std::vector<int32_t> q_indptr{0};
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    q_indptr.push_back(q_indptr.back() + q_lens[request_idx]);
  }
  std::vector<int32_t> append_indptr{0};
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    append_indptr.push_back(append_indptr.back() + kv_lens[request_idx]);
  }

  // Generate K/V for every request and lay out the page table. Pages are numbered consecutively
  // across requests, in request order.
  std::vector<T> kv_data;
  std::vector<int32_t> kv_indptr{0};
  std::vector<int32_t> kv_indices;
  std::vector<int32_t> kv_last_page_len;
  size_t page_counter = 0;
  std::vector<std::vector<T>> key, value;
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    size_t kv_len = kv_lens[request_idx];
    size_t num_pages = (kv_len + page_size - 1) / page_size;
    // A request with kv_len == 0 owns no pages; guard it so that (kv_len - 1) is never evaluated
    // on an unsigned zero.
    size_t last_page_len = num_pages == 0 ? 0 : (kv_len - 1) % page_size + 1;
    std::vector<T> k(kv_len * num_kv_heads * head_dim), v(kv_len * num_kv_heads * head_dim);
    utils::vec_normal_(k);
    utils::vec_normal_(v);
    key.push_back(k);
    value.push_back(v);
    kv_last_page_len.push_back(last_page_len);
    kv_indptr.push_back(kv_indptr.back() + num_pages);
    for (size_t j = 0; j < num_pages; ++j) {
      kv_indices.push_back(page_counter++);
    }
  }

  // Build the host-side paged KV cache and append every request's K/V into it.
  kv_data.resize(page_counter * 2 * num_kv_heads * page_size * head_dim);
  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
      num_kv_heads, page_size, head_dim, batch_size, kv_data.data(), kv_indices.data(),
      kv_indptr.data(), kv_last_page_len.data());
  cpu_reference::append_paged_kv_cache<kv_layout, T, int32_t>(paged_kv_cpu, key, value,
                                                              append_indptr);

  // copy data to device
  thrust::device_vector<T> kv_data_device(kv_data);
  thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
  thrust::device_vector<int32_t> kv_indices_device(kv_indices);
  thrust::device_vector<int32_t> kv_last_page_len_device(kv_last_page_len);

  // create paged_kv object: a copy of the host descriptor with its pointers redirected to the
  // device copies made above
  flashinfer::paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv = paged_kv_cpu;
  paged_kv.data = thrust::raw_pointer_cast(kv_data_device.data());
  paged_kv.indices = thrust::raw_pointer_cast(kv_indices_device.data());
  paged_kv.indptr = thrust::raw_pointer_cast(kv_indptr_device.data());
  paged_kv.last_page_len = thrust::raw_pointer_cast(kv_last_page_len_device.data());

  // Generate the queries, then compute the CPU reference output request by request.
  std::vector<std::vector<T>> q, o_ref;
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    int32_t q_len = q_lens[request_idx];
    std::vector<T> qi(q_len * num_qo_heads * head_dim);
    utils::vec_normal_(qi);
    q.push_back(qi);
  }
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    int32_t q_len = q_lens[request_idx], kv_len = kv_lens[request_idx];
    std::vector<T> o_ref_i = cpu_reference::single_mha<T, T, T>(
        q[request_idx], key[request_idx], value[request_idx], q_len, kv_len, num_qo_heads,
        num_kv_heads, head_dim, /*causal=*/false, QKVLayout::kNHD,
        /*pos_encoding_mode*/ PosEncodingMode::kNone);
    o_ref.push_back(o_ref_i);
  }

  // Concatenate the per-request queries / reference outputs in request order, matching q_indptr.
  std::vector<T> q_concat, o_concat_ref;
  for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
    q_concat.insert(q_concat.end(), q[request_idx].begin(), q[request_idx].end());
    o_concat_ref.insert(o_concat_ref.end(), o_ref[request_idx].begin(), o_ref[request_idx].end());
  }
  thrust::device_vector<T> q_device(q_concat);

  thrust::device_vector<int32_t> q_indptr_device(q_indptr);
  thrust::device_vector<T> o_device(o_concat_ref.size());

  BatchPrefillHandler handler;
  size_t workspace_size_in_bytes = 32 * 1024 * 1024;
  thrust::device_vector<char> buffer(workspace_size_in_bytes);

  // BeginForward is given the host-side q_indptr / kv_indptr vectors.
  handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
                                   workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
                                   batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);

  // Forward allow_fp16_qk_reduction so the flag requested by the caller is the one actually
  // exercised; previously the parameter was accepted and then dropped.
  // NOTE(review): assumes the wrapper takes the flag as the argument right after
  // pos_encoding_mode -- confirm against its declaration.
  auto status =
      BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
          &handler, thrust::raw_pointer_cast(q_device.data()),
          thrust::raw_pointer_cast(q_indptr_device.data()),
          /*q_offset=*/nullptr, paged_kv, thrust::raw_pointer_cast(o_device.data()),
          /*lse=*/nullptr, num_qo_heads, /*causal=*/false,
          /*pos_encoding_mode*/ PosEncodingMode::kNone, allow_fp16_qk_reduction);
  EXPECT_EQ(status, cudaSuccess) << "CUDA error: " + std::string(cudaGetErrorString(status));

  // Copy the result back and count the elements outside tolerance.
  thrust::host_vector<T> o_host(o_device);
  size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
  bool nan_detected = false;
  for (size_t i = 0; i < o_concat_ref.size(); ++i) {
    if (std::isnan(float(o_host[i]))) {
      nan_detected = true;
    }
    num_result_errors_atol_1e_3_rtol_1e_3 +=
        (!utils::isclose(float(o_host[i]), float(o_concat_ref[i]), 1e-3, 1e-3));
  }
  // The denominator is clamped to 1 so that an empty output yields an accuracy of 1 rather than
  // a division by zero.
  float result_accuracy =
      1.f - float(num_result_errors_atol_1e_3_rtol_1e_3) / max(float(o_concat_ref.size()), 1.f);
  std::cout << "batch_size=" << batch_size << ", page_size=" << page_size
            << ", num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads
            << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
  EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
  EXPECT_EQ(nan_detected, false) << "NaN detected in output.";
}
475
+
353
476
template <typename T>
354
477
void _TestBatchPagedPrefillKernelLongContextCorrectness (size_t num_kv_heads, size_t num_qo_heads,
355
478
size_t page_size, size_t head_dim,
@@ -505,6 +628,27 @@ void TestBatchPagedPrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduct
505
628
}
506
629
}
507
630
631
// Sweeps _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness over a grid of batch sizes, head
// counts, page sizes and head dimensions. Query-length bounds are fixed at (1, 3); the KV-length
// lower bound is fixed at 0 and the upper bound takes the values 0 and 3.
template <typename T>
void TestBatchPagedPrefillKernelZeroContextCorrectness(bool allow_fp16_qk_reduction) {
  // Length bounds shared by every configuration in the sweep.
  constexpr uint32_t kQLenMin = 1;
  constexpr uint32_t kQLenMax = 3;
  constexpr uint32_t kKVLenMin = 0;

  for (size_t bs : {1, 4, 7, 11, 19, 37, 99}) {
    for (size_t kv_heads : {1, 4}) {
      for (size_t qo_per_kv : {1, 8}) {
        // The number of query heads is a multiple of the number of KV heads.
        const size_t qo_heads = kv_heads * qo_per_kv;
        for (size_t entries_per_page : {1, 16}) {
          for (size_t dim : {64, 128, 256}) {
            for (size_t kv_upper : {0, 3}) {
              _TestBatchPagedPrefillKernelQMinMaxKVMinMaxCorrectness<T>(
                  bs, kv_heads, qo_heads, entries_per_page, dim, allow_fp16_qk_reduction,
                  kQLenMin, kQLenMax, kKVLenMin, kv_upper);
            }
          }
        }
      }
    }
  }
}
651
+
508
652
template <typename T>
509
653
void TestBatchRaggedPrefillKernelCorrectness (bool allow_fp16_qk_reduction) {
510
654
for (size_t num_kv_heads : {4 , 8 , 32 }) {
@@ -534,6 +678,10 @@ TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16) {
534
678
TestBatchPagedPrefillKernelLongContextCorrectness<half>(false );
535
679
}
536
680
681
// Runs the zero-context paged prefill sweep for half-precision inputs with
// allow_fp16_qk_reduction set to false.
TEST(FlashInferCorrectnessTest, BatchPagedPrefillZeroContextTestFP16) {
  constexpr bool kAllowFp16QkReduction = false;
  TestBatchPagedPrefillKernelZeroContextCorrectness<half>(kAllowFp16QkReduction);
}
684
+
537
685
// Runs the long-context paged prefill sweep for half-precision inputs with
// allow_fp16_qk_reduction set to true.
TEST(FlashInferCorrectnessTest, BatchPagedPrefillLongContextTestFP16QKHalfAccum) {
  constexpr bool kAllowFp16QkReduction = true;
  TestBatchPagedPrefillKernelLongContextCorrectness<half>(kAllowFp16QkReduction);
}
0 commit comments