Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -1 +1,15 @@
// TODO: Not implemented yet
// Explicit template instantiations for the SM90 FP8 ragged-KV prefill kernel.
// This file is a Jinja template: {{ head_dim_qk }}, {{ mask_mode }},
// {{ use_sliding_window }}, and {{ variant_name }} are substituted at
// codegen time; the loop below emits one instantiation per scheduler mode.
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"

namespace flashinfer {

{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}

}  // namespace flashinfer
Comment on lines +1 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

Fix namespace closing syntax.

Line 15 uses }; to close the namespace, but namespaces should be closed with just } (no semicolon).

-};  // namespace flashinfer
+}  // namespace flashinfer
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"
namespace flashinfer {
{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}
}; // namespace flashinfer
#include <flashinfer/attention/hopper/quantization/prefill_sm90.cuh>
#include "batch_prefill_sm90_config.inc"
namespace flashinfer {
{% for same_scheduler_for_all_heads in ["true", "false"] %}
template cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched
<{{ head_dim_qk }},
{{ mask_mode }},
/*USE_SLIDING_WINDOW=*/{{ use_sliding_window }},
/*SAME_SCHEDULER_FOR_ALL_HEADS=*/{{ same_scheduler_for_all_heads }},
{{ variant_name }}, RaggedParams>(RaggedParams& params, bool enable_pdl, cudaStream_t stream);
{% endfor %}
} // namespace flashinfer
πŸ€– Prompt for AI Agents
In csrc/batch_prefill_fp8_ragged_sm90_kernel_inst.jinja around lines 1 to 15,
the namespace is closed using "};" but C++ namespace blocks should be closed
with a plain "}" (no semicolon); remove the trailing semicolon after the closing
brace so the file ends with "}" to correctly close the flashinfer namespace.

100 changes: 99 additions & 1 deletion csrc/batch_prefill_fp8_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
cudaError_t BatchFP8PrefillWithPagedKVCacheDispatched(Params& params, bool enable_pdl,
cudaStream_t stream);

template <uint32_t HEAD_DIM, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
cudaError_t BatchFP8PrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl,
cudaStream_t stream);

} // namespace flashinfer

using namespace flashinfer;
Expand Down Expand Up @@ -78,7 +83,94 @@ void BatchPrefillWithRaggedKVCacheSM90Run(ffi::TensorView float_workspace_buffer
int64_t window_left,
bool enable_pdl // placeholder
ADDITIONAL_FUNC_PARAMS) {
return; // TODO: Implement this function
PrefillPlanSM90Info plan_info;
plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

if (maybe_lse.has_value()) {
const auto& lse = maybe_lse.value();
TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0));
TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1));
}

void* float_buffer_ptr = float_workspace_buffer.data_ptr();
void* int_buffer_ptr = int_workspace_buffer.data_ptr();

int64_t head_dim_qk = q.size(2);
int64_t head_dim_vo = v.size(2);

QKVLayout kv_layout = static_cast<QKVLayout>(layout);

cudaSetDevice(float_workspace_buffer.device().device_id);
const cudaStream_t stream = get_stream(float_workspace_buffer.device());
const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);
bool use_swa = window_left != -1;

DISPATCH_context(
DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW,
USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] {
RaggedParams params;

params.q_ptr = static_cast<DTypeQ*>(q.data_ptr());
params.k_ptr = static_cast<DTypeKV*>(k.data_ptr());
params.v_ptr = static_cast<DTypeKV*>(v.data_ptr());
params.o_ptr = static_cast<DTypeO*>(o.data_ptr());
params.lse_ptr = maybe_lse ? static_cast<float*>(maybe_lse.value().data_ptr()) : nullptr;
params.q_stride_n = q.stride(0);
params.q_stride_h = q.stride(1);
params.o_stride_n = o.stride(0);
params.o_stride_h = o.stride(1);
if (kv_layout == QKVLayout::kNHD) {
params.k_stride_n = k.stride(0);
params.k_stride_h = k.stride(1);
params.v_stride_n = v.stride(0);
params.v_stride_h = v.stride(1);
} else {
params.k_stride_h = k.stride(0);
params.k_stride_n = k.stride(1);
params.v_stride_h = v.stride(0);
params.v_stride_n = v.stride(1);
}
params.nnz_qo = q.size(0);
params.nnz_kv = k.size(0);
params.num_qo_heads = q.size(1);
params.num_kv_heads = k.size(1);
params.group_size = params.num_qo_heads / params.num_kv_heads;
params.window_left = window_left;
params.causal = mask_mode_code == 1;
params.qo_tile_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_tile_indices_offset);
params.qo_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_indptr_offset);
params.kv_indptr = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_indptr_offset);
params.qo_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.qo_len_offset);
params.kv_lens = GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.kv_len_offset);
params.head_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.head_indices_offset);
params.work_indptr =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.work_indptr_offset);
params.batch_indices =
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.batch_indices_offset);

ADDITIONAL_PARAMS_SETTER

// Not support various head_dim for now
static_assert(HEAD_DIM_QK == HEAD_DIM_VO, "head_dim_qk and head_dim_vo should be the same");
// Currently only support same quantization precision
static_assert(std::is_same_v<DTypeQ, DTypeKV>);

bool same_schedule_for_all_heads = plan_info.same_schedule_for_all_heads;
DISPATCH_BOOL(same_schedule_for_all_heads, SAME_SCHEDULER_FOR_ALL_HEADS, [&] {
cudaError_t status =
BatchFP8PrefillWithRaggedKVCacheDispatched<HEAD_DIM_QK, MASK_MODE, USE_SLIDING_WINDOW,
SAME_SCHEDULER_FOR_ALL_HEADS,
AttentionVariant>(params, enable_pdl,
stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "BatchPrefillWithRaggedKVCacheSM90Run failed with error: "
<< cudaGetErrorString(status);
return true;
});
});
}

void BatchPrefillWithPagedKVCacheSM90Run(
Expand Down Expand Up @@ -136,12 +228,18 @@ void BatchPrefillWithPagedKVCacheSM90Run(
params.k_stride_h = paged_k_cache.stride(2);
params.v_stride_n = paged_v_cache.stride(1);
params.v_stride_h = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
} else {
// (num_pages, num_heads, page_size, head_dim)
params.k_stride_h = paged_k_cache.stride(1);
params.k_stride_n = paged_k_cache.stride(2);
params.v_stride_h = paged_v_cache.stride(1);
params.v_stride_n = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
}
params.nnz_qo = q.size(0);
params.num_qo_heads = q.size(1);
Expand Down
11 changes: 11 additions & 0 deletions csrc/batch_prefill_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,24 @@ void BatchPrefillWithPagedKVCacheSM90Run(
params.k_stride_h = paged_k_cache.stride(2);
params.v_stride_n = paged_v_cache.stride(1);
params.v_stride_h = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
} else {
// (num_pages, num_heads, page_size, head_dim)
params.k_stride_h = paged_k_cache.stride(1);
params.k_stride_n = paged_k_cache.stride(2);
params.v_stride_h = paged_v_cache.stride(1);
params.v_stride_n = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
}
// Sparse mainloop assumes K and V have same strides for efficiency
TVM_FFI_ICHECK_EQ(params.k_page_stride, params.v_page_stride)
<< "K and V must have same page stride for sparse attention";
TVM_FFI_ICHECK_EQ(params.k_stride_n, params.v_stride_n)
<< "K and V must have same stride_n for sparse attention";
params.nnz_qo = q.size(0);
params.num_qo_heads = q.size(1);
params.num_kv_heads = num_kv_heads;
Expand Down
5 changes: 5 additions & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ struct PagedParams {
int64_t o_stride_h;
int64_t nnz_qo;

// NOTE: For sparse paged KV cache, we need the stride between pages
// This is paged_k_cache.stride(0), not the layout stride
int64_t k_page_stride; // Stride between pages for K
int64_t v_page_stride; // Stride between pages for V

int head_dim;
int num_qo_heads;
int num_kv_heads;
Expand Down
7 changes: 0 additions & 7 deletions csrc/flashinfer_page_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,5 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe,
TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr,
TensorView kv_last_page_len);

void block_sparse_indices_to_vector_sparse_offsets(
TensorView block_sparse_indices, TensorView block_sparse_indptr,
TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr,
int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_kv_cache, append_paged_kv_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_mla_kv_cache, append_paged_mla_kv_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(block_sparse_indices_to_vector_sparse_offsets,
block_sparse_indices_to_vector_sparse_offsets);
25 changes: 0 additions & 25 deletions csrc/page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,6 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso
<< paged_k_cache.dtype();
}

// FFI wrapper: converts block-sparse KV indices into per-element (vector)
// sparse offsets by launching BlockSparseIndicesToVectorSparseOffset on the
// stream associated with the input tensors' device.
//
// Parameters (all int32 device tensors — see the casts below):
//   block_sparse_indices/indptr  — input block-sparse CSR structure
//   vector_sparse_offsets/indptr — output per-element offsets/indptr,
//                                  presumably filled in place by the kernel
//                                  (TODO confirm against kernel definition)
//   kv_len_arr                   — per-request KV lengths
//   stride_block, stride_n       — element strides used to expand block
//                                  indices into flat offsets
//   batch_size, block_size       — CSR batch count and page/block size
void block_sparse_indices_to_vector_sparse_offsets(
    TensorView block_sparse_indices, TensorView block_sparse_indptr,
    TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr,
    int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size) {
  // Validate tensors (CHECK_INPUT is a project macro; presumably checks
  // device residency/contiguity — verify its definition).
  CHECK_INPUT(block_sparse_indices);
  CHECK_INPUT(block_sparse_indptr);
  CHECK_INPUT(vector_sparse_offsets);
  CHECK_INPUT(vector_sparse_indptr);
  CHECK_INPUT(kv_len_arr);

  // Run on the device/stream that owns the input tensors.
  cudaSetDevice(block_sparse_indices.device().device_id);
  const cudaStream_t stream = get_stream(block_sparse_indices.device());

  // All tensors are reinterpreted as int32; the Python caller asserts the
  // dtypes before reaching this point.
  cudaError_t status = BlockSparseIndicesToVectorSparseOffset(
      static_cast<int32_t*>(block_sparse_indices.data_ptr()),
      static_cast<int32_t*>(block_sparse_indptr.data_ptr()),
      static_cast<int32_t*>(vector_sparse_offsets.data_ptr()),
      static_cast<int32_t*>(vector_sparse_indptr.data_ptr()),
      static_cast<int32_t*>(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size,
      stream);

  // Surface kernel launch/runtime failures through the FFI check macro.
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "BlockSparseIndicesToVectorSparseOffset failed with error: " << cudaGetErrorString(status);
}

void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe,
TensorView batch_indices, TensorView positions, TensorView ckv_cache,
TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr,
Expand Down
36 changes: 0 additions & 36 deletions flashinfer/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,6 @@ def get_page_module():
return gen_page_module().build_and_load()


def block_sparse_indices_to_vector_sparse_offsets(
    block_sparse_indices: torch.Tensor,
    block_sparse_indptr: torch.Tensor,
    vector_sparse_offsets: torch.Tensor,
    vector_sparse_indptr: torch.Tensor,
    kv_lens: torch.Tensor,
    stride_block: int,
    stride_n: int,
    block_size: int,
) -> torch.Tensor:
    """Expand block-sparse indices into per-element (vector) sparse offsets.

    Dispatches to the CUDA kernel exposed by the page module, which writes
    the expanded offsets into ``vector_sparse_offsets`` in place.

    Args:
        block_sparse_indices: int32 tensor of block indices (CSR data).
        block_sparse_indptr: int32 CSR indptr; batch size is its length - 1.
        vector_sparse_offsets: int32 output buffer, filled in place.
        vector_sparse_indptr: int32 indptr for the expanded representation.
        kv_lens: int32 per-request KV lengths.
        stride_block: element stride between consecutive blocks/pages.
        stride_n: element stride within a block.
        block_size: number of tokens per block/page.

    Returns:
        ``block_sparse_indices`` (possibly scaled) when ``block_size == 1``,
        otherwise ``vector_sparse_offsets`` after the kernel fills it.
    """
    if block_size == 1:
        # Fast path: with one token per block, offsets are just the indices
        # scaled by the block stride; no kernel launch needed.
        if stride_block == 1:
            return block_sparse_indices
        else:
            return block_sparse_indices * stride_block

    # The CUDA kernel reinterprets all buffers as int32.
    assert block_sparse_indices.dtype == torch.int32
    assert block_sparse_indptr.dtype == torch.int32
    assert vector_sparse_offsets.dtype == torch.int32
    assert vector_sparse_indptr.dtype == torch.int32
    assert kv_lens.dtype == torch.int32
    # CSR convention: indptr has batch_size + 1 entries.
    batch_size = block_sparse_indptr.size(0) - 1
    get_page_module().block_sparse_indices_to_vector_sparse_offsets(
        block_sparse_indices,
        block_sparse_indptr,
        vector_sparse_offsets,
        vector_sparse_indptr,
        kv_lens,
        stride_block,
        stride_n,
        batch_size,
        block_size,
    )
    return vector_sparse_offsets


@register_custom_op(
"flashinfer::append_paged_mla_kv_cache",
mutates_args=("ckv_cache", "kpe_cache"),
Expand Down
Loading