Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/batch_prefill.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ Array<int64_t> BatchPrefillWithKVCachePlan(
TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size,
bool disable_split_kv, int64_t num_colocated_ctas = 0) {
bool disable_split_kv, int64_t fixed_cta_tile_q = -1, int64_t num_colocated_ctas = 0) {
size_t float_workspace_size_in_bytes =
float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer);
size_t int_workspace_size_in_bytes =
Expand All @@ -66,8 +66,8 @@ Array<int64_t> BatchPrefillWithKVCachePlan(
int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
static_cast<IdType*>(kv_indptr.data_ptr()), total_num_rows, batch_size, num_qo_heads,
num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph,
/*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, num_colocated_ctas,
stream);
/*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, fixed_cta_tile_q,
num_colocated_ctas, stream);

TVM_FFI_ICHECK(status == cudaSuccess)
<< "Failed to plan prefill with error: " << cudaGetErrorString(status);
Expand Down
2 changes: 1 addition & 1 deletion csrc/batch_prefill_jit_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Array<int64_t> BatchPrefillWithKVCachePlan(
TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads,
int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk,
int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size,
bool disable_split_kv, int64_t num_colocated_ctas);
bool disable_split_kv, int64_t fixed_cta_tile_q, int64_t num_colocated_ctas);

void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer,
TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
Expand Down
30 changes: 29 additions & 1 deletion flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
TensorLayout,
_check_block_tables_shape,
_check_cached_qkv_data_type,
_validate_fixed_cta_tile_q,
_check_kv_layout,
_check_pos_encoding_mode,
check_shape_dtype_device,
Expand Down Expand Up @@ -861,6 +862,7 @@ def plan(
seq_lens: Optional[torch.Tensor] = None,
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
fixed_cta_tile_q: Optional[int] = None,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> None:
r"""Plan batch decode for given problem specification.

Expand Down Expand Up @@ -919,6 +921,9 @@ def plan(
and lead to a varied number of launched CTAs.
disable_split_kv : bool,
Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``.
fixed_cta_tile_q : Optional[int]
Fixed CTA tile size for FA2 attention planning. Supported values are ``16``, ``64``, and
``128``. Defaults to ``None`` (auto heuristic).
Note
----
The :meth:`plan` method should be called before any :meth:`run` or
Expand Down Expand Up @@ -995,8 +1000,13 @@ def plan(
raise ValueError(
"fixed_split_size is only supported by tensor core decode for now."
)
if fixed_cta_tile_q is not None and not self.use_tensor_cores:
raise ValueError(
"fixed_cta_tile_q is only supported by tensor core decode for now."
)
if fixed_split_size is None:
fixed_split_size = -1
fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim)

self._cached_q_data_type = q_data_type
self._cached_kv_data_type = kv_data_type
Expand Down Expand Up @@ -1075,6 +1085,11 @@ def plan(
)
else:
self._backend = "fa2"
if fixed_cta_tile_q != -1 and self._backend != "fa2":
raise ValueError(
f"fixed_cta_tile_q is only supported for the fa2 backend, "
f"got backend={self._backend!r}"
)
self._cached_module = get_batch_prefill_module(
self._backend,
q_data_type,
Expand Down Expand Up @@ -1110,6 +1125,7 @@ def plan(
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(fixed_cta_tile_q)
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down Expand Up @@ -2857,6 +2873,7 @@ def fast_decode_plan(
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
global_override_indptr_cpu: Optional[torch.Tensor] = None,
fixed_cta_tile_q: Optional[int] = None,
) -> None:
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
Expand All @@ -2880,11 +2897,21 @@ def fast_decode_plan(
if kv_data_type is None:
kv_data_type = q_data_type

if fixed_cta_tile_q is not None and not self.use_tensor_cores:
raise ValueError(
"fixed_cta_tile_q is only supported by tensor core decode for now."
)
if self.use_tensor_cores:
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
# Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function
if fixed_split_size is None:
fixed_split_size = -1
fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim)
if fixed_cta_tile_q != -1 and self._backend != "fa2":
raise ValueError(
f"fixed_cta_tile_q is only supported for the fa2 backend, "
f"got backend={self._backend!r}"
)

if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
Expand Down Expand Up @@ -2947,7 +2974,7 @@ def fast_decode_plan(
kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

try:
# Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
# Make sure we pass exactly 20 arguments for fa2 backend and 16 arguments for fa3 backend
args = [
self._float_workspace_buffer,
self._int_workspace_buffer,
Expand All @@ -2969,6 +2996,7 @@ def fast_decode_plan(
if self._backend == "fa2":
args.append(fixed_split_size)
args.append(disable_split_kv)
args.append(fixed_cta_tile_q)
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down
3 changes: 3 additions & 0 deletions flashinfer/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def plan(
window_left,
-1, # fixed_split_size
False, # disable_split_kv
-1, # fixed_cta_tile_q
0, # num_colocated_ctas
)

Expand Down Expand Up @@ -981,6 +982,7 @@ def plan(
window_left,
-1, # fixed_split_size
False, # disable_split_kv
-1, # fixed_cta_tile_q
0, # num_colocated_ctas
)

Expand All @@ -1007,6 +1009,7 @@ def plan(
window_left,
-1, # fixed_split_size
False, # disable_split_kv
-1, # fixed_cta_tile_q
num_colocated_ctas,
)
self._indptr_type = kv_indptr_p.dtype
Expand Down
23 changes: 23 additions & 0 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
TensorLayout,
_check_block_tables_shape,
_check_cached_qkv_data_type,
_validate_fixed_cta_tile_q,
_check_kv_layout,
_check_pos_encoding_mode,
check_shape_dtype_device,
Expand Down Expand Up @@ -1731,6 +1732,7 @@ def plan(
max_sequence_kv: Optional[int] = None,
fixed_split_size: Optional[int] = None,
disable_split_kv: bool = False,
fixed_cta_tile_q: Optional[int] = None,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> None:
r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification.

Expand Down Expand Up @@ -1841,6 +1843,9 @@ def plan(
and lead to a varied number of launched CTAs.
disable_split_kv : bool,
Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``.
fixed_cta_tile_q : Optional[int]
Fixed CTA tile size for FA2 prefill. Supported values are ``16``, ``64``, and ``128``.
Defaults to ``None`` (auto heuristic).
Note
----
The :meth:`plan` method should be called before any :meth:`run` or
Expand All @@ -1867,6 +1872,7 @@ def plan(
head_dim_vo = head_dim_qk
if fixed_split_size is None:
fixed_split_size = -1
fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim_vo)

batch_size = len(qo_indptr) - 1
self._batch_size = batch_size
Expand Down Expand Up @@ -2010,6 +2016,11 @@ def plan(
q_data_type,
kv_data_type,
)
if fixed_cta_tile_q != -1 and self._backend != "fa2":
raise ValueError(
f"fixed_cta_tile_q is only supported for the fa2 backend, "
f"got backend={self._backend!r}"
)
if self._backend != "cudnn":
get_module_args = (
q_data_type,
Expand Down Expand Up @@ -2083,6 +2094,7 @@ def plan(
if self._backend == "fa2":
args.append(fixed_split_size or -1) # fixed_split_size
args.append(disable_split_kv) # disable_split_kv
args.append(fixed_cta_tile_q) # fixed_cta_tile_q
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down Expand Up @@ -2819,6 +2831,7 @@ def plan(
max_sequence_kv: Optional[int] = None,
v_indptr: Optional[torch.Tensor] = None,
o_indptr: Optional[torch.Tensor] = None,
fixed_cta_tile_q: Optional[int] = None,
) -> None:
r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification.

Expand Down Expand Up @@ -2913,6 +2926,9 @@ def plan(
and lead to a varied number of launched CTAs.
disable_split_kv : bool,
Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``.
fixed_cta_tile_q : Optional[int]
Fixed CTA tile size for FA2 prefill. Supported values are ``16``, ``64``, and ``128``.
Defaults to ``None`` (auto heuristic).
seq_lens: Optional[torch.Tensor]
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``.
seq_lens_q: Optional[torch.Tensor]
Expand Down Expand Up @@ -2949,6 +2965,7 @@ def plan(
head_dim_vo = head_dim_qk
if fixed_split_size is None:
fixed_split_size = -1
fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim_vo)
if logits_soft_cap is None:
logits_soft_cap = 0.0

Expand Down Expand Up @@ -3104,6 +3121,11 @@ def plan(
q_data_type,
kv_data_type,
)
if fixed_cta_tile_q != -1 and self._backend != "fa2":
raise ValueError(
f"fixed_cta_tile_q is only supported for the fa2 backend, "
f"got backend={self._backend!r}"
)

get_module_args = (
q_data_type,
Expand Down Expand Up @@ -3156,6 +3178,7 @@ def plan(
if self._backend == "fa2":
args.append(fixed_split_size or -1) # fixed_split_size
args.append(disable_split_kv) # disable_split_kv
args.append(fixed_cta_tile_q) # fixed_cta_tile_q
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down
2 changes: 2 additions & 0 deletions flashinfer/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ def plan(
if self._backend == "fa2":
args.append(-1) # fixed_split_size
args.append(False) # disable_split_kv
args.append(-1) # fixed_cta_tile_q
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down Expand Up @@ -1000,6 +1001,7 @@ def _block_mask_map_to_expanded_indices(
if self._backend == "fa2":
args.append(-1) # fixed_split_size
args.append(False) # disable_split_kv
args.append(-1) # fixed_cta_tile_q
args.append(0) # num_colocated_ctas
self._plan_info = self._cached_module.plan(
*args,
Expand Down
16 changes: 16 additions & 0 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,22 @@ def _check_kv_layout(kv_layout: str) -> None:
raise KeyError("Invalid kv_layout {}".format(kv_layout))


def _validate_fixed_cta_tile_q(fixed_cta_tile_q: Optional[int], head_dim: int) -> int:
"""Validate fixed_cta_tile_q and return the integer value to pass to the kernel
(-1 means auto-select)."""
if fixed_cta_tile_q is None:
return -1
if fixed_cta_tile_q not in (16, 64, 128):
raise ValueError(
f"fixed_cta_tile_q should be one of {{16, 64, 128}}, got {fixed_cta_tile_q}"
)
if fixed_cta_tile_q == 128 and head_dim >= 256:
raise ValueError(
f"fixed_cta_tile_q=128 is not supported with head_dim={head_dim} (requires head_dim < 256)"
)
return fixed_cta_tile_q


def is_float8(x: torch.Tensor) -> bool:
return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]

Expand Down
36 changes: 27 additions & 9 deletions include/flashinfer/attention/scheduler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, uint32_t max_batch_size_if_split,
bool enable_cuda_graph, int32_t window_left,
int32_t fixed_split_size, bool disable_split_kv) {
int32_t fixed_split_size, bool disable_split_kv,
int32_t fixed_cta_tile_q) {
std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
merge_indptr.push_back(0);
o_indptr.push_back(0);
Expand Down Expand Up @@ -527,27 +528,42 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U);
uint32_t cta_tile_q;
uint32_t total_num_tiles_q;
if (enable_cuda_graph) {
if (fixed_cta_tile_q > 0) {
if (fixed_cta_tile_q != 16 && fixed_cta_tile_q != 64 && fixed_cta_tile_q != 128) {
std::ostringstream err_msg;
err_msg << "fixed_cta_tile_q should be one of {16, 64, 128}, but got " << fixed_cta_tile_q;
FLASHINFER_ERROR(err_msg.str());
}
if (fixed_cta_tile_q == 128 && head_dim >= 256) {
std::ostringstream err_msg;
err_msg << "fixed_cta_tile_q=128 is not supported with head_dim=" << head_dim
<< " (requires head_dim < 256)";
FLASHINFER_ERROR(err_msg.str());
}
cta_tile_q = fixed_cta_tile_q;
} else if (enable_cuda_graph) {
// When CUDA graphs are enabled, the lengths of sequences determined by
// qo_indptr_h can vary. We assume that the dummy data based on which
// the CUDA graph is created fixes the maximum number of tokens.
const uint64_t max_seq_len = total_num_rows - batch_size + 1;
uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);

// Find an upper bound for the number of tiles, derived from the total
// number of rows and the batch size. The sum of qo lengths rounded
// up to cta_tile_q will not exceed this number derived from the total
// number of rows.
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1;
} else {
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
sum_packed_qo_len += packed_qo_len_arr[i];
}
const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);
}

if (enable_cuda_graph) {
// Find an upper bound for the number of tiles, derived from the total
// number of rows and the batch size. The sum of qo lengths rounded
// up to cta_tile_q will not exceed this number derived from the total
// number of rows.
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1;
} else {
total_num_tiles_q = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
Expand Down Expand Up @@ -699,6 +715,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size,
bool enable_cuda_graph, uint32_t sizeof_dtype_o, int32_t window_left,
int32_t fixed_split_size, bool disable_split_kv,
int32_t fixed_cta_tile_q,
int64_t num_colocated_ctas, // for POD attention, limit prefill
// splits by #colocated decode CTAs
cudaStream_t stream) {
Expand All @@ -724,7 +741,8 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split,
enable_cuda_graph, window_left, fixed_split_size, disable_split_kv);
enable_cuda_graph, window_left, fixed_split_size, disable_split_kv,
fixed_cta_tile_q);

plan_info.cta_tile_q = cta_tile_q;
plan_info.total_num_rows = total_num_rows;
Expand Down
Loading
Loading