2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 72 files
+8 −1 .github/workflows/release_wheel.yml
+1 −1 .release-please-manifest.json
+50 −1 CHANGELOG.md
+61 −11 CMakeLists.txt
+12 −2 README.md
+2 −0 cmake/config.cmake
+4 −0 docs/api/python/cascade.rst
+3 −0 docs/api/python/sampling.rst
+2 −2 docs/conf.py
+1 −1 docs/installation.rst
+18 −0 docs/tutorials/kv_layout.rst
+74 −0 include/flashinfer/activation.cuh
+85 −66 include/flashinfer/attention/cascade.cuh
+22 −37 include/flashinfer/attention/decode.cuh
+86 −79 include/flashinfer/attention/handler.cuh
+507 −368 include/flashinfer/attention/prefill.cuh
+9 −8 include/flashinfer/frag_layout_swizzle.cuh
+78 −2 include/flashinfer/mma.cuh
+115 −17 include/flashinfer/norm.cuh
+62 −13 include/flashinfer/permuted_smem.cuh
+19 −18 include/flashinfer/prefill_attention_decl.cuh
+504 −165 include/flashinfer/sampling.cuh
+6 −0 include/flashinfer/utils.cuh
+313 −235 include/flashinfer/vec_dtypes.cuh
+60 −0 python/csrc/activation.cu
+25 −15 python/csrc/batch_decode.cu
+222 −105 python/csrc/batch_prefill.cu
+6 −37 python/csrc/flashinfer_ops.cu
+33 −126 python/csrc/flashinfer_ops.h
+32 −0 python/csrc/flashinfer_ops_decode.cu
+59 −0 python/csrc/flashinfer_ops_decode.h
+47 −0 python/csrc/flashinfer_ops_prefill.cu
+96 −0 python/csrc/flashinfer_ops_prefill.h
+46 −16 python/csrc/norm.cu
+4 −0 python/csrc/pytorch_extension_utils.h
+176 −35 python/csrc/sampling.cu
+1 −1 python/csrc/single_decode.cu
+15 −3 python/csrc/single_prefill.cu
+28 −23 python/flashinfer/__init__.py
+102 −0 python/flashinfer/activation.py
+280 −21 python/flashinfer/cascade.py
+41 −24 python/flashinfer/decode.py
+50 −1 python/flashinfer/group_gemm.py
+27 −6 python/flashinfer/norm.py
+3 −3 python/flashinfer/page.py
+97 −46 python/flashinfer/prefill.py
+22 −0 python/flashinfer/quantization.py
+489 −41 python/flashinfer/sampling.py
+37 −10 python/flashinfer/sparse.py
+8 −6 python/generate_batch_paged_prefill_inst.py
+8 −6 python/generate_batch_ragged_prefill_inst.py
+7 −5 python/generate_single_prefill_inst.py
+75 −46 python/setup.py
+45 −0 python/tests/test_activation.py
+208 −0 python/tests/test_fp8_prefill.py
+34 −3 python/tests/test_norm.py
+179 −26 python/tests/test_sampling.py
+40 −51 python/tests/test_shared_prefix_kernels.py
+24 −14 src/bench_batch_decode.cu
+11 −6 src/bench_batch_prefill.cu
+36 −20 src/bench_cascade.cu
+4 −4 src/bench_sampling.cu
+77 −1 src/bench_single_prefill.cu
+15 −13 src/flashinfer_ops.cuh
+6 −3 src/test_batch_decode.cu
+220 −122 src/test_batch_prefill.cu
+28 −22 src/test_cascade.cu
+71 −0 src/test_fast_dequant.cu
+2 −2 src/test_sampling.cu
+108 −15 src/test_single_prefill.cu
+47 −32 src/tvm_wrapper.cu
+1 −1 version.txt
29 changes: 18 additions & 11 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -57,8 +57,10 @@ namespace relax_vm {
 constexpr const int kPagedKVCacheMaxBlockDepth = 2;
 /*! \brief The maximum tree size of a single sequence in tree attention. */
 constexpr const int kTreeAttnMaxTreeSize = 256;
-/*! \brief The 8MB workspace size for attention auxiliary data. */
-constexpr const int kAttnWorkspaceByte = 128 * 1024 * 1024;
+/*! \brief The 1MB workspace size for integer attention auxiliary data. */
+constexpr const int kIntAttnWorkspaceByte = 1 * 1024 * 1024;
+/*! \brief The 768MB workspace size for floating-point attention auxiliary data. */
+constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
 constexpr const int kPagedKVCacheTempPageId = -1;

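Per the updated comments, the single attention workspace is split into a small buffer for integer auxiliary data (scheduling metadata such as indptr arrays) and a much larger one for floating-point auxiliary data (likely partial attention outputs that are merged later; that characterization is an inference, not stated in the diff). Both are allocated as flat float32 NDArrays, so each byte budget is divided by sizeof(float) at allocation time. A minimal standalone sketch of that sizing arithmetic, with the constant values copied from the hunk above:

    #include <cstdint>
    #include <cstdio>

    // Byte budgets copied from the diff above; each workspace is allocated as a
    // flat float32 array, so the element count is the byte count divided by 4.
    constexpr int64_t kIntAttnWorkspaceByte = 1 * 1024 * 1024;      // 1 MB
    constexpr int64_t kFloatAttnWorkspaceByte = 768 * 1024 * 1024;  // 768 MB

    int main() {
      std::printf("int workspace:   %lld float32 elements\n",
                  static_cast<long long>(kIntAttnWorkspaceByte / 4));
      std::printf("float workspace: %lld float32 elements\n",
                  static_cast<long long>(kFloatAttnWorkspaceByte / 4));
      return 0;
    }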
Expand Down Expand Up @@ -915,7 +917,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
NDArray temp_attn_output_device_;
NDArray temp_attn_scores_device_;
NDArray merged_attn_scores_device_;
std::vector<NDArray> temp_attn_workspace_;
std::vector<NDArray> temp_int_attn_workspace_;
NDArray temp_float_attn_workspace_;

//-------------------------------------------
// Below are the auxiliary data structure on CPU.
@@ -1089,8 +1092,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       if (NeedKernelBeginForward()) {
-        temp_attn_workspace_.push_back(
-            NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+        temp_int_attn_workspace_.push_back(
+            NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
       }
       qo_indptr_on_depths_view_.push_back(NDArray());
       page_indptr_on_depths_view_.push_back(NDArray());
@@ -1103,8 +1106,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     // Additional workspace for the "prefill with ragged kv" kernel.
     if (NeedKernelBeginForward()) {
-      temp_attn_workspace_.push_back(
-          NDArray::Empty({kAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_int_attn_workspace_.push_back(
+          NDArray::Empty({kIntAttnWorkspaceByte / 4}, DataType::Float(32), device));
+      temp_float_attn_workspace_ =
+          NDArray::Empty({kFloatAttnWorkspaceByte / 4}, DataType::Float(32), device);
     }
 
     temp_attn_q_device_ =
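The constructor now allocates one integer workspace per block depth, plus one more for the "prefill with ragged kv" kernel, but only a single floating-point workspace shared by all depths. A rough sketch of this allocation pattern, using std::vector as a hypothetical stand-in for NDArray::Empty:

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for the NDArray allocations in the diff: one small
    // integer workspace per depth plus one extra for the ragged-kv prefill
    // kernel, and a single large float workspace shared by all of them.
    struct AttnWorkspaces {
      std::vector<std::vector<float>> int_ws;  // per depth, + 1 for ragged kv
      std::vector<float> float_ws;             // shared across depths
    };

    AttnWorkspaces AllocAttnWorkspaces(int max_block_depth, int64_t int_bytes,
                                       int64_t float_bytes) {
      AttnWorkspaces ws;
      for (int d = 0; d < max_block_depth + 1; ++d) {
        ws.int_ws.emplace_back(int_bytes / sizeof(float), 0.0f);
      }
      ws.float_ws.assign(float_bytes / sizeof(float), 0.0f);
      return ws;
    }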
Expand Down Expand Up @@ -2324,7 +2329,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
if (!append_before_attn_) {
if (is_chain_on_depths_[0]) {
f_attention_prefill_ragged_begin_forward_.value()(
temp_attn_workspace_[0], cur_append_lengths_indptr_host_.as_ndarray(),
temp_float_attn_workspace_, temp_int_attn_workspace_[0],
cur_append_lengths_indptr_host_.as_ndarray(),
cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, num_qo_heads_,
num_kv_heads_, head_dim_, copy_stream_);
}
@@ -2336,14 +2342,15 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       CHECK(!support_sliding_window_) << "Kernel BeginForward doesn't support sliding window.";
       if (use_decode_kernel_[d]) {
         f_attention_decode_begin_forward_.value()(
-            d, temp_attn_workspace_[d + 1], page_indptr_on_depths_host_[d].as_ndarray(),
+            d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            page_indptr_on_depths_host_[d].as_ndarray(),
             last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, num_kv_heads_, head_dim_,
             page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
-            /*depth=*/d, temp_attn_workspace_[d + 1], qo_indptr_on_depths_host_[d].as_ndarray(),
-            page_indptr_on_depths_host_[d].as_ndarray(),
+            /*depth=*/d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
+            qo_indptr_on_depths_host_[d].as_ndarray(), page_indptr_on_depths_host_[d].as_ndarray(),
             static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, num_qo_heads_,
             num_kv_heads_, head_dim_, page_size_, copy_stream_);
       }
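Every begin-forward call site changes the same way: the single temp_attn_workspace_[...] argument becomes a pair, the shared float workspace followed by the per-depth integer workspace. A hedged C++ sketch of the calling convention these call sites imply (the function name and parameter names are hypothetical; the real entry points are PackedFuncs provided by the FlashInfer bindings):

    #include <tvm/runtime/c_runtime_api.h>  // TVMStreamHandle
    #include <tvm/runtime/ndarray.h>

    using tvm::runtime::NDArray;

    // Hypothetical signature matching the argument order used above: the shared
    // float workspace comes first, then the per-depth integer workspace, then
    // the host-side indptr arrays and the shape parameters.
    void AttentionPrefillBeginForward(
        int depth,
        NDArray float_workspace_buffer,  // sized by kFloatAttnWorkspaceByte
        NDArray int_workspace_buffer,    // sized by kIntAttnWorkspaceByte
        NDArray qo_indptr_host, NDArray page_indptr_host, int batch_size,
        int num_qo_heads, int num_kv_heads, int head_dim, int page_size,
        TVMStreamHandle copy_stream);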
@@ -324,7 +324,7 @@ def set_global_func():
     )
     fattention_merge_state = tvm.get_global_func("flashinfer.merge_state_in_place")
 
-    target = tvm.target.Target("nvidia/geforce-rtx-3090-ti")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         kv_cache_transpose_append,
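In the tests, the hard-coded "nvidia/geforce-rtx-3090-ti" target tag is replaced by Target.from_device, which derives the compilation target from whatever GPU the test actually runs on. A minimal usage sketch (assuming a CUDA device is available):

    import tvm

    # Query the running GPU and build a matching compilation target from its
    # attributes (arch, max threads, ...) instead of hard-coding a card model.
    device = tvm.cuda(0)
    target = tvm.target.Target.from_device(device)
    print(target)  # e.g. cuda -keys=cuda,gpu -arch=sm_86 ...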
@@ -111,7 +111,7 @@ def set_global_func(head_dim, dtype):
     fis_empty = tvm.get_global_func("vm.builtin.attention_kv_cache_empty")
     fdebug_get_kv = tvm.get_global_func("vm.builtin.attention_kv_cache_debug_get_kv")
 
-    target = tvm.target.Target("cuda")
+    target = tvm.target.Target.from_device(device)
     builts = []
     for tir_func in [
         _kv_cache_transpose_append(num_kv_heads, head_dim, dtype),