8 changes: 8 additions & 0 deletions .clang-format
@@ -0,0 +1,8 @@
# Run the following command to reformat a file:
# clang-format -i -style=Google <file>
# Or use clang-format-diff to only reformat the changed lines:
# https://clang.llvm.org/docs/ClangFormat.html
BasedOnStyle: Google
DerivePointerAlignment: false
ColumnLimit: 100
PointerAlignment: Left
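
The practical effect of PointerAlignment: Left together with DerivePointerAlignment: false is visible throughout the wrapper diff below, where declarations change from DLTensor *q to DLTensor* q. A minimal before/after sketch (the function names are illustrative, not from this PR):

#include <dlpack/dlpack.h>

// Before formatting: the '*' sits with the identifier.
int BeforeStyle(DLTensor *q, DLTensor *o);

// After clang-format with PointerAlignment: Left: the '*' sits with the type.
int AfterStyle(DLTensor* q, DLTensor* o);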
85 changes: 80 additions & 5 deletions src/tvm_wrapper.cu
@@ -31,9 +31,9 @@ using tvm::runtime::ShapeTuple;
LOG(FATAL) << "Unsupported data type " << dl_dtype.code; \
}

int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DLTensor *tmp,
int _FlashInferSingleDecodeWithKVCache(DLTensor* q, DLTensor* k, DLTensor* v, DLTensor* tmp,
int64_t qkv_layout, int64_t rotary_mode, double rope_scale,
double rope_theta, DLTensor *o) {
double rope_theta, DLTensor* o) {
CHECK_EQ(q->device.device_type, kDLCUDA) << "The device of q matrix must be CUDA.";
CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA.";
@@ -68,8 +68,8 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DL
  SWITCH_TVM_CUDA_DTYPE(
      q->dtype, dtype_in, {SWITCH_TVM_CUDA_DTYPE(o->dtype, dtype_out, {
        cudaError_t status = flashinfer::SingleDecodeWithKVCache(
-            (dtype_in *)q->data, (dtype_in *)k->data, (dtype_in *)v->data, (dtype_out *)o->data,
-            (float *)tmp->data, num_heads, seq_len, head_dim, flashinfer::QKVLayout(qkv_layout),
+            (dtype_in*)q->data, (dtype_in*)k->data, (dtype_in*)v->data, (dtype_out*)o->data,
+            (float*)tmp->data, num_heads, seq_len, head_dim, flashinfer::QKVLayout(qkv_layout),
            flashinfer::RotaryMode(rotary_mode), rope_scale, rope_theta, 0, dev_id);
        if (status != cudaSuccess) {
          LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
@@ -78,4 +78,79 @@ int _FlashInferSingleDecodeWithKVCache(DLTensor *q, DLTensor *k, DLTensor *v, DL
  return 0;
}

-TVM_DLL_EXPORT_TYPED_FUNC(FlashInferSingleDecodeWithKVCache, _FlashInferSingleDecodeWithKVCache);
\ No newline at end of file
+TVM_DLL_EXPORT_TYPED_FUNC(FlashInferSingleDecodeWithKVCache, _FlashInferSingleDecodeWithKVCache);
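
For context (not part of this diff): TVM_DLL_EXPORT_TYPED_FUNC exposes the wrapper under TVM's packed-function calling convention, so once this file is compiled into a runtime library the symbol can be looked up and invoked by name. A hedged sketch; the library file name and the layout/rotary argument values are placeholders, not values fixed by this PR:

#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

void RunSingleDecode(tvm::runtime::NDArray q, tvm::runtime::NDArray k, tvm::runtime::NDArray v,
                     tvm::runtime::NDArray tmp, tvm::runtime::NDArray o) {
  // "flashinfer_tvm.so" is an assumed artifact name for this sketch.
  tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile("flashinfer_tvm.so");
  tvm::runtime::PackedFunc decode = mod.GetFunction("FlashInferSingleDecodeWithKVCache");
  // qkv_layout = 0, rotary_mode = 0, rope_scale = 1.0, rope_theta = 1e4 are assumed
  // values; NDArray arguments convert to DLTensor* under the packed convention.
  decode(q, k, v, tmp, 0, 0, 1.0, 1e4, o);
}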

void _FlashInferBatchDecodeWithPagedKVCache(DLTensor* q_data, DLTensor* pages,
                                            DLTensor* page_table_indptr,
                                            DLTensor* page_table_values,
                                            DLTensor* last_page_offset,  //
                                            int64_t layer_id,            //
                                            DLTensor* output) {
  CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA.";
  CHECK_EQ(pages->device.device_type, kDLCUDA) << "The device of kv pages must be CUDA.";
  CHECK_EQ(page_table_indptr->device.device_type, kDLCUDA)
      << "The device of page_table_indptr matrix must be CUDA.";
  CHECK_EQ(page_table_values->device.device_type, kDLCUDA)
      << "The device of page_table_values matrix must be CUDA.";
  CHECK_EQ(last_page_offset->device.device_type, kDLCUDA)
      << "The device of last_page_offset matrix must be CUDA.";
  CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA.";

  int32_t dev_id = q_data->device.device_id;
  CHECK_EQ(pages->device.device_id, dev_id);
  CHECK_EQ(page_table_indptr->device.device_id, dev_id);
  CHECK_EQ(page_table_values->device.device_id, dev_id);
  CHECK_EQ(last_page_offset->device.device_id, dev_id);
  CHECK_EQ(output->device.device_id, dev_id);

  CHECK(q_data->dtype.lanes == 1 && pages->dtype.lanes == 1 && output->dtype.lanes == 1);
  CHECK(q_data->dtype.bits == pages->dtype.bits && q_data->dtype.code == pages->dtype.code);

  CHECK_EQ(pages->ndim, 7);
  CHECK_LT(layer_id, pages->shape[1]);
  CHECK_GE(layer_id, 0);
  CHECK_EQ(pages->shape[2], 1) << "Page chunk size should be fixed to 1 right now.";
  CHECK_EQ(pages->shape[3], 2);
  int64_t npage = pages->shape[0];
  int64_t nlayer = pages->shape[1];
  int64_t nhead = pages->shape[4];
  int64_t nfeat = pages->shape[6];
  int64_t page_size = pages->shape[5];

  CHECK_EQ(last_page_offset->ndim, 1);
  int64_t num_total_seqs = last_page_offset->shape[0];

  CHECK_EQ(page_table_indptr->ndim, 1);
  CHECK_EQ(page_table_indptr->shape[0], num_total_seqs + 1);
  CHECK_EQ(page_table_values->ndim, 1);

  CHECK_EQ(q_data->ndim, 4);
  CHECK_EQ(q_data->shape[0], num_total_seqs);
  CHECK_EQ(q_data->shape[1], 1);
  CHECK_EQ(q_data->shape[2], nhead);
  CHECK_EQ(q_data->shape[3], nfeat);

  CHECK_EQ(output->ndim, 4);
  CHECK_EQ(output->shape[0], num_total_seqs);
  CHECK_EQ(output->shape[1], 1);
  CHECK_EQ(output->shape[2], nhead);
  CHECK_EQ(output->shape[3], nfeat);

  SWITCH_TVM_CUDA_DTYPE(
      pages->dtype, dtype_in, {SWITCH_TVM_CUDA_DTYPE(output->dtype, dtype_out, {
        flashinfer::paged_kv_t<dtype_in> cache(npage, nlayer, layer_id, nhead, page_size, nfeat,
                                               static_cast<dtype_in*>(pages->data));
        cudaError_t status = flashinfer::BatchDecodeWithPagedKVCache(
            (dtype_in*)q_data->data, cache, static_cast<size_t*>(page_table_indptr->data),
            static_cast<size_t*>(page_table_values->data),
            static_cast<size_t*>(last_page_offset->data), static_cast<dtype_out*>(output->data),
            nullptr, num_total_seqs, flashinfer::RotaryMode::kNone, 1.0f, 1e4, 0,
            q_data->device.device_id);
        if (status != cudaSuccess) {
          LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status);
        }
      })});
}

TVM_DLL_EXPORT_TYPED_FUNC(FlashInferBatchDecodeWithPagedKVCache,
                          _FlashInferBatchDecodeWithPagedKVCache);
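
Taken together, the shape checks above pin down the paged KV cache layout: pages is a dense 7-D buffer of shape [npage, nlayer, chunk=1, k/v=2, nhead, page_size, nfeat]; the page ids of sequence i are page_table_values[page_table_indptr[i] : page_table_indptr[i+1]] (a CSR-style table, hence the num_total_seqs + 1 indptr length); and last_page_offset[i] presumably records how far the last page of sequence i is filled. A minimal sketch of element addressing, assuming a contiguous row-major buffer (the helper name is illustrative, not part of the wrapper's API):

#include <cstdint>

// Linear offset of element (page, layer, kv, head, pos, feat) in the 7-D pages
// buffer [npage, nlayer, 1, 2, nhead, page_size, nfeat]. kv = 0 selects keys,
// kv = 1 selects values (the dimension CHECKed to equal 2); the chunk dimension
// is CHECKed to 1, so it contributes no stride.
int64_t PageElemOffset(int64_t page, int64_t layer, int64_t kv, int64_t head, int64_t pos,
                       int64_t feat, int64_t nlayer, int64_t nhead, int64_t page_size,
                       int64_t nfeat) {
  return ((((page * nlayer + layer) * 2 + kv) * nhead + head) * page_size + pos) * nfeat + feat;
}

These are the same extents (npage, nlayer, nhead, page_size, nfeat) that the wrapper forwards to the flashinfer::paged_kv_t constructor.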