diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..9d622b98ba --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +# Run the following command to reformat a file: +# clang-format -i -style=Google +# Or use clang-format-diff to only reformat the changed lines: +# https://clang.llvm.org/docs/ClangFormat.html +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 100 +PointerAlignment: Left diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index 9324177149..106f153420 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -31,9 +31,9 @@ using tvm::runtime::ShapeTuple; LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \ } -int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DLTensor *tmp, +int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp, int64_t qkv_layout, int64_t rotary_mode, double rope_scale, - double rope_theta, DLTensor *o) { + double rope_theta, DLTensor* o) { CHECK_EQ(q->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA."; CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA."; CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA."; @@ -68,8 +68,8 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DL SWITCH_TVM_CUDA_DTYPE( q->dtype, dtype_in, {SWITCH_TVM_CUDA_DTYPE(o->dtype, dtype_out, { cudaError_t status = flashinfer::SingleDecodeWithKVCache( - (dtype_in *)q->data, (dtype_in *)k->data, (dtype_in *)v->data, (dtype_out *)o->data, - (float *)tmp->data, num_heads, seq_len, head_dim, flashinfer::QKVLayout(qkv_layout), + (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data, + (float*)tmp->data, num_heads, seq_len, head_dim, flashinfer::QKVLayout(qkv_layout), flashinfer::RotaryMode(rotary_mode), rope_scale, rope_theta, 0, dev_id); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); @@ -78,4 +78,79 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DL return 0; } -TVM_DLL_EXPORT_TYPED_FUNC(FlashInferSingleDecodeWithKVCache, _FlashInferSingleDecodeWithKVCache); \ No newline at end of file +TVM_DLL_EXPORT_TYPED_FUNC(FlashInferSingleDecodeWithKVCache, _FlashInferSingleDecodeWithKVCache); + +void _FlashInferBatchDecodeWithPagedKVCache(DLTensor* q_data, DLTensor* pages, + DLTensor* page_table_indptr, + DLTensor* page_table_values, + DLTensor* last_page_offset, // + int64_t layer_id, // + DLTensor* output) { + CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; + CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA."; + CHECK_EQ(page_table_indptr->device.device_type, kDLCUDA) + << "The device of page_table_indptr matrix must be CUDA."; + CHECK_EQ(page_table_values->device.device_type, kDLCUDA) + << "The device of page_table_values matrix must be CUDA."; + CHECK_EQ(last_page_offset->device.device_type, kDLCUDA) + << "The device of last_page_offset matrix must be CUDA."; + CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; + + int32_t dev_id = q_data->device.device_id; + CHECK_EQ(pages->device.device_id, dev_id); + CHECK_EQ(page_table_indptr->device.device_id, dev_id); + CHECK_EQ(page_table_values->device.device_id, dev_id); + CHECK_EQ(last_page_offset->device.device_id, dev_id); + CHECK_EQ(output->device.device_id, dev_id); + + CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1); + CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code); + + CHECK_EQ(pages->ndim, 7); + CHECK_LT(layer_id, pages->shape[1]); + CHECK_GE(layer_id, 0); + CHECK_EQ(pages->shape[2], 1) << "Page chunk size should be fixed to 1 right now."; + CHECK_EQ(pages->shape[3], 2); + int64_t npage = pages->shape[0]; + int64_t nlayer = pages->shape[1]; + int64_t nhead = pages->shape[4]; + int64_t nfeat = pages->shape[6]; + int64_t page_size = pages->shape[5]; + + CHECK_EQ(last_page_offset->ndim, 1); + int64_t num_total_seqs = last_page_offset->shape[0]; + + CHECK_EQ(page_table_indptr->ndim, 1); + CHECK_EQ(page_table_indptr->shape[0], num_total_seqs + 1); + CHECK_EQ(page_table_values->ndim, 1); + + CHECK_EQ(q_data->ndim, 4); + CHECK_EQ(q_data->shape[0], num_total_seqs); + CHECK_EQ(q_data->shape[1], 1); + CHECK_EQ(q_data->shape[2], nhead); + CHECK_EQ(q_data->shape[3], nfeat); + + CHECK_EQ(output->ndim, 4); + CHECK_EQ(output->shape[0], num_total_seqs); + CHECK_EQ(output->shape[1], 1); + CHECK_EQ(output->shape[2], nhead); + CHECK_EQ(output->shape[3], nfeat); + + SWITCH_TVM_CUDA_DTYPE( + pages->dtype, dtype_in, {SWITCH_TVM_CUDA_DTYPE(output->dtype, dtype_out, { + flashinfer::paged_kv_t cache(npage, nlayer, layer_id, nhead, page_size, nfeat, + static_cast(pages->data)); + cudaError_t status = flashinfer::BatchDecodeWithPagedKVCache( + (dtype_in*)q_data->data, cache, static_cast(page_table_indptr->data), + static_cast(page_table_values->data), + static_cast(last_page_offset->data), static_cast(output->data), + nullptr, num_total_seqs, flashinfer::RotaryMode::kNone, 1.0f, 1e4, 0, + q_data->device.device_id); + if (status != cudaSuccess) { + LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); + } + })}); +} + +TVM_DLL_EXPORT_TYPED_FUNC(FlashInferBatchDecodeWithPagedKVCache, + _FlashInferBatchDecodeWithPagedKVCache);