Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions csrc/batch_prefill_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,24 @@ void BatchPrefillWithPagedKVCacheSM90Run(
params.k_stride_h = paged_k_cache.stride(2);
params.v_stride_n = paged_v_cache.stride(1);
params.v_stride_h = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
} else {
// (num_pages, num_heads, page_size, head_dim)
params.k_stride_h = paged_k_cache.stride(1);
params.k_stride_n = paged_k_cache.stride(2);
params.v_stride_h = paged_v_cache.stride(1);
params.v_stride_n = paged_v_cache.stride(2);
// For sparse paged KV cache, store the stride between pages
params.k_page_stride = paged_k_cache.stride(0);
params.v_page_stride = paged_v_cache.stride(0);
}
// Sparse mainloop assumes K and V have same strides for efficiency
TVM_FFI_ICHECK_EQ(params.k_page_stride, params.v_page_stride)
<< "K and V must have same page stride for sparse attention";
TVM_FFI_ICHECK_EQ(params.k_stride_n, params.v_stride_n)
<< "K and V must have same stride_n for sparse attention";
params.nnz_qo = q.size(0);
params.num_qo_heads = q.size(1);
params.num_kv_heads = num_kv_heads;
Expand Down
5 changes: 5 additions & 0 deletions csrc/batch_prefill_sm90_customize_config.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ struct PagedParams {
int64_t o_stride_h;
int64_t nnz_qo;

// NOTE: For sparse paged KV cache, we need the stride between pages
// This is paged_k_cache.stride(0), not the layout stride
int64_t k_page_stride; // Stride between pages for K
int64_t v_page_stride; // Stride between pages for V

int head_dim;
int num_qo_heads;
int num_kv_heads;
Expand Down
7 changes: 0 additions & 7 deletions csrc/flashinfer_page_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,5 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe,
TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr,
TensorView kv_last_page_len);

void block_sparse_indices_to_vector_sparse_offsets(
TensorView block_sparse_indices, TensorView block_sparse_indptr,
TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr,
int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_kv_cache, append_paged_kv_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(append_paged_mla_kv_cache, append_paged_mla_kv_cache);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(block_sparse_indices_to_vector_sparse_offsets,
block_sparse_indices_to_vector_sparse_offsets);
25 changes: 0 additions & 25 deletions csrc/page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -112,31 +112,6 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso
<< paged_k_cache.dtype();
}

// Convert block-granularity sparse KV indices into per-element (vector) sparse
// offsets by launching the BlockSparseIndicesToVectorSparseOffset kernel on the
// device that owns the input tensors.
//
// Parameters:
//   block_sparse_indices  - int32 page/block indices (input).
//   block_sparse_indptr   - int32 CSR-style indptr over the block indices (input).
//   vector_sparse_offsets - int32 output buffer for per-element offsets.
//   vector_sparse_indptr  - int32 indptr over the per-element offsets.
//   kv_len_arr            - int32 per-request KV lengths.
//   stride_block          - stride (in elements) between consecutive blocks.
//   stride_n              - stride (in elements) between rows within a block.
//   batch_size            - number of requests (indptr length minus one).
//   block_size            - number of KV rows per block/page.
//
// All tensor arguments must be contiguous CUDA tensors (enforced by
// CHECK_INPUT). Raises via TVM_FFI_ICHECK on any CUDA failure.
void block_sparse_indices_to_vector_sparse_offsets(
    TensorView block_sparse_indices, TensorView block_sparse_indptr,
    TensorView vector_sparse_offsets, TensorView vector_sparse_indptr, TensorView kv_len_arr,
    int64_t stride_block, int64_t stride_n, int64_t batch_size, int64_t block_size) {
  CHECK_INPUT(block_sparse_indices);
  CHECK_INPUT(block_sparse_indptr);
  CHECK_INPUT(vector_sparse_offsets);
  CHECK_INPUT(vector_sparse_indptr);
  CHECK_INPUT(kv_len_arr);

  // Check the cudaSetDevice result instead of discarding it: a sticky CUDA
  // error here would otherwise surface at a later, unrelated call with a
  // misleading message.
  cudaError_t dev_status = cudaSetDevice(block_sparse_indices.device().device_id);
  TVM_FFI_ICHECK(dev_status == cudaSuccess)
      << "cudaSetDevice failed with error: " << cudaGetErrorString(dev_status);
  const cudaStream_t stream = get_stream(block_sparse_indices.device());

  // Launch the project conversion kernel on the tensors' stream. All index
  // buffers are int32, matching the dtype asserts on the Python side.
  cudaError_t status = BlockSparseIndicesToVectorSparseOffset(
      static_cast<int32_t*>(block_sparse_indices.data_ptr()),
      static_cast<int32_t*>(block_sparse_indptr.data_ptr()),
      static_cast<int32_t*>(vector_sparse_offsets.data_ptr()),
      static_cast<int32_t*>(vector_sparse_indptr.data_ptr()),
      static_cast<int32_t*>(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size,
      stream);

  TVM_FFI_ICHECK(status == cudaSuccess)
      << "BlockSparseIndicesToVectorSparseOffset failed with error: " << cudaGetErrorString(status);
}

void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe,
TensorView batch_indices, TensorView positions, TensorView ckv_cache,
TensorView kpe_cache, TensorView kv_indices, TensorView kv_indptr,
Expand Down
36 changes: 0 additions & 36 deletions flashinfer/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,42 +34,6 @@ def get_page_module():
return gen_page_module().build_and_load()


def block_sparse_indices_to_vector_sparse_offsets(
    block_sparse_indices: torch.Tensor,
    block_sparse_indptr: torch.Tensor,
    vector_sparse_offsets: torch.Tensor,
    vector_sparse_indptr: torch.Tensor,
    kv_lens: torch.Tensor,
    stride_block: int,
    stride_n: int,
    block_size: int,
) -> torch.Tensor:
    """Expand block-granularity sparse indices into per-element sparse offsets.

    For ``block_size == 1`` the expansion is a pure scaling and is computed
    directly in Python; otherwise the work is delegated to the compiled page
    module, which writes into ``vector_sparse_offsets``.

    All tensor arguments must be int32 when the compiled path is taken.
    Returns the tensor holding the per-element offsets (``block_sparse_indices``
    itself, a scaled copy, or ``vector_sparse_offsets``).
    """
    # Fast path: one row per block, so each offset is just a scaled index.
    if block_size == 1:
        return (
            block_sparse_indices
            if stride_block == 1
            else block_sparse_indices * stride_block
        )

    # The compiled kernel only handles int32 index buffers.
    for tensor in (
        block_sparse_indices,
        block_sparse_indptr,
        vector_sparse_offsets,
        vector_sparse_indptr,
        kv_lens,
    ):
        assert tensor.dtype == torch.int32

    batch_size = block_sparse_indptr.size(0) - 1
    get_page_module().block_sparse_indices_to_vector_sparse_offsets(
        block_sparse_indices,
        block_sparse_indptr,
        vector_sparse_offsets,
        vector_sparse_indptr,
        kv_lens,
        stride_block,
        stride_n,
        batch_size,
        block_size,
    )
    return vector_sparse_offsets


@register_custom_op(
"flashinfer::append_paged_mla_kv_cache",
mutates_args=("ckv_cache", "kpe_cache"),
Expand Down
55 changes: 4 additions & 51 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
gen_trtllm_gen_fmha_module,
)
from .cudnn import cudnn_batch_prefill_with_kv_cache
from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens
from .page import get_seq_lens
from .quantization import packbits, segment_packbits
from .utils import (
FP4Tensor,
Expand Down Expand Up @@ -1424,16 +1424,6 @@ def __init__(
* self._float_workspace_buffer.element_size()
)
self.device = float_workspace_buffer.device
self._vector_sparse_indptr_buffer: Optional[torch.Tensor] = None
if backend in ["fa3", "auto", "trtllm-gen"]:
# NOTE(Zihao): assume maximum accumulate kv length is 16M
self._vector_sparse_indices_buffer = torch.empty(
(16 * 1024 * 1024,), dtype=torch.int32, device=self.device
)
# NOTE(Zihao): assume maximum batch size is 32768
self._vector_sparse_indptr_buffer = torch.empty(
(32768,), dtype=torch.int32, device=self.device
)

self._kv_lens_buffer = torch.empty(
(32768,), dtype=torch.int32, device=self.device
Expand Down Expand Up @@ -1839,22 +1829,6 @@ def plan(
self._backend, *get_module_args
)

if self._backend == "fa3" or self._backend == "trtllm-gen":
if page_size != 1:
vector_sparse_indptr_host = torch.cat(
[
torch.tensor(
[0], dtype=torch.int32, device=kv_lens_arr_host.device
),
torch.cumsum(kv_lens_arr_host, dim=0, dtype=torch.int32),
],
dim=0,
)
self._vector_sparse_indptr_buffer[
: len(vector_sparse_indptr_host)
].copy_(vector_sparse_indptr_host, non_blocking=non_blocking)
paged_kv_indptr_host = vector_sparse_indptr_host

self._block_tables = block_tables
if self._backend == "trtllm-gen":
assert logits_soft_cap == 0.0
Expand Down Expand Up @@ -2042,13 +2016,10 @@ def run(
q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
)

stride_block = k_cache.stride(0)
if self._kv_layout == "NHD":
page_size = k_cache.shape[1]
stride_n = k_cache.stride(1)
else:
page_size = k_cache.shape[2]
stride_n = k_cache.stride(2)
window_left = self._window_left if window_left is None else window_left
if self._backend != "trtllm-gen":
# NOTE(Siyuan): since window_left is appeared in the plan function, we need to make sure it is the same as the one in the plan function.
Expand Down Expand Up @@ -2106,24 +2077,6 @@ def run(
if self._prefix_len_ptr is not None:
mask_mode = MaskMode.MULTIITEMSCORING.value

if self._backend == "fa3":
# NOTE(Zihao): we divide both stride_block and stride_n by stride_n
# because we will multiply stride_n back in the kernel
sparse_indices = block_sparse_indices_to_vector_sparse_offsets(
self._paged_kv_indices_buf,
self._paged_kv_indptr_buf,
self._vector_sparse_indices_buffer, # output
self._vector_sparse_indptr_buffer,
self._kv_lens_buffer,
stride_block // stride_n,
1, # stride_n // stride_n
page_size,
)
sparse_indptr = self._vector_sparse_indptr_buffer
else:
sparse_indices = self._paged_kv_indices_buf
sparse_indptr = self._paged_kv_indptr_buf

if self._backend == "cudnn":
if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1:
self._seq_lens_q = self._seq_lens_q.reshape(self._batch_size, 1, 1, 1)
Expand Down Expand Up @@ -2160,8 +2113,8 @@ def run(
k_cache,
v_cache,
self._qo_indptr_buf,
sparse_indptr,
sparse_indices,
self._paged_kv_indptr_buf,
self._paged_kv_indices_buf,
self._paged_kv_last_page_len_buf,
out,
lse,
Expand Down Expand Up @@ -2198,7 +2151,7 @@ def run(
self._max_kv_len,
self._batch_size,
self._qo_indptr_buf,
self._vector_sparse_indptr_buffer,
self._paged_kv_indptr_buf,
sinks,
]

Expand Down
Loading