diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 6011ba2063..2051a9258a 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -50,7 +50,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv, int64_t num_colocated_ctas = 0) { + bool disable_split_kv, int64_t fixed_cta_tile_q = -1, int64_t num_colocated_ctas = 0) { size_t float_workspace_size_in_bytes = float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = @@ -66,8 +66,8 @@ Array BatchPrefillWithKVCachePlan( int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), static_cast(kv_indptr.data_ptr()), total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, - /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, num_colocated_ctas, - stream); + /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, fixed_cta_tile_q, + num_colocated_ctas, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to plan prefill with error: " << cudaGetErrorString(status); diff --git a/csrc/batch_prefill_jit_binding.cu b/csrc/batch_prefill_jit_binding.cu index 3dda0f115a..aad0ca8c53 100644 --- a/csrc/batch_prefill_jit_binding.cu +++ b/csrc/batch_prefill_jit_binding.cu @@ -25,7 +25,7 @@ Array BatchPrefillWithKVCachePlan( TensorView kv_len_arr, int64_t total_num_rows, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, - bool disable_split_kv, int64_t num_colocated_ctas); + bool disable_split_kv, int64_t fixed_cta_tile_q, int64_t num_colocated_ctas); void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer, Array plan_info_vec, diff --git a/flashinfer/decode.py b/flashinfer/decode.py index dff3aafb7c..d0ba852d54 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -62,6 +62,7 @@ TensorLayout, _check_block_tables_shape, _check_cached_qkv_data_type, + _validate_fixed_cta_tile_q, _check_kv_layout, _check_pos_encoding_mode, check_shape_dtype_device, @@ -861,6 +862,7 @@ def plan( seq_lens: Optional[torch.Tensor] = None, fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, + fixed_cta_tile_q: Optional[int] = None, ) -> None: r"""Plan batch decode for given problem specification. @@ -919,6 +921,9 @@ def plan( and lead to a varied number of launched CTAs. disable_split_kv : bool, Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. + fixed_cta_tile_q : Optional[int] + Fixed CTA tile size for FA2 attention planning. Supported values are ``16``, ``64``, and + ``128``. Defaults to ``None`` (auto heuristic). Note ---- The :meth:`plan` method should be called before any :meth:`run` or @@ -995,8 +1000,13 @@ def plan( raise ValueError( "fixed_split_size is only supported by tensor core decode for now." ) + if fixed_cta_tile_q is not None and not self.use_tensor_cores: + raise ValueError( + "fixed_cta_tile_q is only supported by tensor core decode for now." + ) if fixed_split_size is None: fixed_split_size = -1 + fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim) self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type @@ -1075,6 +1085,11 @@ def plan( ) else: self._backend = "fa2" + if fixed_cta_tile_q != -1 and self._backend != "fa2": + raise ValueError( + f"fixed_cta_tile_q is only supported for the fa2 backend, " + f"got backend={self._backend!r}" + ) self._cached_module = get_batch_prefill_module( self._backend, q_data_type, @@ -1110,6 +1125,7 @@ def plan( if self._backend == "fa2": args.append(fixed_split_size) args.append(disable_split_kv) + args.append(fixed_cta_tile_q) args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, @@ -2857,6 +2873,7 @@ def fast_decode_plan( fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, global_override_indptr_cpu: Optional[torch.Tensor] = None, + fixed_cta_tile_q: Optional[int] = None, ) -> None: """ A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend. @@ -2880,11 +2897,21 @@ def fast_decode_plan( if kv_data_type is None: kv_data_type = q_data_type + if fixed_cta_tile_q is not None and not self.use_tensor_cores: + raise ValueError( + "fixed_cta_tile_q is only supported by tensor core decode for now." + ) if self.use_tensor_cores: qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") # Here we set fixed_split_size to -1 to avoid the assertion error in flashinfer's plan function if fixed_split_size is None: fixed_split_size = -1 + fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim) + if fixed_cta_tile_q != -1 and self._backend != "fa2": + raise ValueError( + f"fixed_cta_tile_q is only supported for the fa2 backend, " + f"got backend={self._backend!r}" + ) if self.is_cuda_graph_enabled: if batch_size != self._fixed_batch_size: @@ -2947,7 +2974,7 @@ def fast_decode_plan( kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) try: - # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend + # Make sure we pass exactly 20 arguments for fa2 backend and 16 arguments for fa3 backend args = [ self._float_workspace_buffer, self._int_workspace_buffer, @@ -2969,6 +2996,7 @@ def fast_decode_plan( if self._backend == "fa2": args.append(fixed_split_size) args.append(disable_split_kv) + args.append(fixed_cta_tile_q) args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 4fa2d9bf0d..29687f7918 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -426,6 +426,7 @@ def plan( window_left, -1, # fixed_split_size False, # disable_split_kv + -1, # fixed_cta_tile_q 0, # num_colocated_ctas ) @@ -981,6 +982,7 @@ def plan( window_left, -1, # fixed_split_size False, # disable_split_kv + -1, # fixed_cta_tile_q 0, # num_colocated_ctas ) @@ -1007,6 +1009,7 @@ def plan( window_left, -1, # fixed_split_size False, # disable_split_kv + -1, # fixed_cta_tile_q num_colocated_ctas, ) self._indptr_type = kv_indptr_p.dtype diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 8b1baa33df..30dea989d4 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -57,6 +57,7 @@ TensorLayout, _check_block_tables_shape, _check_cached_qkv_data_type, + _validate_fixed_cta_tile_q, _check_kv_layout, _check_pos_encoding_mode, check_shape_dtype_device, @@ -1731,6 +1732,7 @@ def plan( max_sequence_kv: Optional[int] = None, fixed_split_size: Optional[int] = None, disable_split_kv: bool = False, + fixed_cta_tile_q: Optional[int] = None, ) -> None: r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. @@ -1841,6 +1843,9 @@ def plan( and lead to a varied number of launched CTAs. disable_split_kv : bool, Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. + fixed_cta_tile_q : Optional[int] + Fixed CTA tile size for FA2 prefill. Supported values are ``16``, ``64``, and ``128``. + Defaults to ``None`` (auto heuristic). Note ---- The :meth:`plan` method should be called before any :meth:`run` or @@ -1867,6 +1872,7 @@ def plan( head_dim_vo = head_dim_qk if fixed_split_size is None: fixed_split_size = -1 + fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim_vo) batch_size = len(qo_indptr) - 1 self._batch_size = batch_size @@ -2010,6 +2016,11 @@ def plan( q_data_type, kv_data_type, ) + if fixed_cta_tile_q != -1 and self._backend != "fa2": + raise ValueError( + f"fixed_cta_tile_q is only supported for the fa2 backend, " + f"got backend={self._backend!r}" + ) if self._backend != "cudnn": get_module_args = ( q_data_type, @@ -2083,6 +2094,7 @@ def plan( if self._backend == "fa2": args.append(fixed_split_size or -1) # fixed_split_size args.append(disable_split_kv) # disable_split_kv + args.append(fixed_cta_tile_q) # fixed_cta_tile_q args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, @@ -2819,6 +2831,7 @@ def plan( max_sequence_kv: Optional[int] = None, v_indptr: Optional[torch.Tensor] = None, o_indptr: Optional[torch.Tensor] = None, + fixed_cta_tile_q: Optional[int] = None, ) -> None: r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. @@ -2913,6 +2926,9 @@ def plan( and lead to a varied number of launched CTAs. disable_split_kv : bool, Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. + fixed_cta_tile_q : Optional[int] + Fixed CTA tile size for FA2 prefill. Supported values are ``16``, ``64``, and ``128``. + Defaults to ``None`` (auto heuristic). seq_lens: Optional[torch.Tensor] A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``. seq_lens_q: Optional[torch.Tensor] @@ -2949,6 +2965,7 @@ def plan( head_dim_vo = head_dim_qk if fixed_split_size is None: fixed_split_size = -1 + fixed_cta_tile_q = _validate_fixed_cta_tile_q(fixed_cta_tile_q, head_dim_vo) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -3104,6 +3121,11 @@ def plan( q_data_type, kv_data_type, ) + if fixed_cta_tile_q != -1 and self._backend != "fa2": + raise ValueError( + f"fixed_cta_tile_q is only supported for the fa2 backend, " + f"got backend={self._backend!r}" + ) get_module_args = ( q_data_type, @@ -3156,6 +3178,7 @@ def plan( if self._backend == "fa2": args.append(fixed_split_size or -1) # fixed_split_size args.append(disable_split_kv) # disable_split_kv + args.append(fixed_cta_tile_q) # fixed_cta_tile_q args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, diff --git a/flashinfer/sparse.py b/flashinfer/sparse.py index 7e0f3d90cb..86d7056941 100644 --- a/flashinfer/sparse.py +++ b/flashinfer/sparse.py @@ -452,6 +452,7 @@ def plan( if self._backend == "fa2": args.append(-1) # fixed_split_size args.append(False) # disable_split_kv + args.append(-1) # fixed_cta_tile_q args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, @@ -1000,6 +1001,7 @@ def _block_mask_map_to_expanded_indices( if self._backend == "fa2": args.append(-1) # fixed_split_size args.append(False) # disable_split_kv + args.append(-1) # fixed_cta_tile_q args.append(0) # num_colocated_ctas self._plan_info = self._cached_module.plan( *args, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 6decfe1989..19365b3530 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -155,6 +155,22 @@ def _check_kv_layout(kv_layout: str) -> None: raise KeyError("Invalid kv_layout {}".format(kv_layout)) +def _validate_fixed_cta_tile_q(fixed_cta_tile_q: Optional[int], head_dim: int) -> int: + """Validate fixed_cta_tile_q and return the integer value to pass to the kernel + (-1 means auto-select).""" + if fixed_cta_tile_q is None: + return -1 + if fixed_cta_tile_q not in (16, 64, 128): + raise ValueError( + f"fixed_cta_tile_q should be one of {{16, 64, 128}}, got {fixed_cta_tile_q}" + ) + if fixed_cta_tile_q == 128 and head_dim >= 256: + raise ValueError( + f"fixed_cta_tile_q=128 is not supported with head_dim={head_dim} (requires head_dim < 256)" + ) + return fixed_cta_tile_q + + def is_float8(x: torch.Tensor) -> bool: return x.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index 286023e204..6859b33453 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -497,7 +497,8 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, bool enable_cuda_graph, int32_t window_left, - int32_t fixed_split_size, bool disable_split_kv) { + int32_t fixed_split_size, bool disable_split_kv, + int32_t fixed_cta_tile_q) { std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; merge_indptr.push_back(0); o_indptr.push_back(0); @@ -527,19 +528,26 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); uint32_t cta_tile_q; uint32_t total_num_tiles_q; - if (enable_cuda_graph) { + if (fixed_cta_tile_q > 0) { + if (fixed_cta_tile_q != 16 && fixed_cta_tile_q != 64 && fixed_cta_tile_q != 128) { + std::ostringstream err_msg; + err_msg << "fixed_cta_tile_q should be one of {16, 64, 128}, but got " << fixed_cta_tile_q; + FLASHINFER_ERROR(err_msg.str()); + } + if (fixed_cta_tile_q == 128 && head_dim >= 256) { + std::ostringstream err_msg; + err_msg << "fixed_cta_tile_q=128 is not supported with head_dim=" << head_dim + << " (requires head_dim < 256)"; + FLASHINFER_ERROR(err_msg.str()); + } + cta_tile_q = fixed_cta_tile_q; + } else if (enable_cuda_graph) { // When CUDA graphs are enabled, the lengths of sequences determined by // qo_indptr_h can vary. We assume that the dummy data based on which // the CUDA graph is created fixes the maximum number of tokens. const uint64_t max_seq_len = total_num_rows - batch_size + 1; uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); - - // Find an upper bound for the number of tiles, derived from the total - // number of rows and the batch size. The sum of qo lengths rounded - // up to cta_tile_q will not exceed this number derived from the total - // number of rows. - total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; } else { int64_t sum_packed_qo_len = 0; for (uint32_t i = 0; i < batch_size; ++i) { @@ -547,7 +555,15 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, } const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); + } + if (enable_cuda_graph) { + // Find an upper bound for the number of tiles, derived from the total + // number of rows and the batch size. The sum of qo lengths rounded + // up to cta_tile_q will not exceed this number derived from the total + // number of rows. + total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; + } else { total_num_tiles_q = 0; for (uint32_t i = 0; i < batch_size; ++i) { total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); @@ -699,6 +715,7 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, int32_t window_left, int32_t fixed_split_size, bool disable_split_kv, + int32_t fixed_cta_tile_q, int64_t num_colocated_ctas, // for POD attention, limit prefill // splits by #colocated decode CTAs cudaStream_t stream) { @@ -724,7 +741,8 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph, window_left, fixed_split_size, disable_split_kv); + enable_cuda_graph, window_left, fixed_split_size, disable_split_kv, + fixed_cta_tile_q); plan_info.cta_tile_q = cta_tile_q; plan_info.total_num_rows = total_num_rows; diff --git a/tests/attention/test_batch_invariant_fa2.py b/tests/attention/test_batch_invariant_fa2.py index ea7abeb2c7..f6453f0ac8 100644 --- a/tests/attention/test_batch_invariant_fa2.py +++ b/tests/attention/test_batch_invariant_fa2.py @@ -58,6 +58,7 @@ def warmup_jit(): @pytest.mark.parametrize("kv_len", [4096, 8192, 5000]) @pytest.mark.parametrize("fixed_split_size", [2048]) @pytest.mark.parametrize("disable_split_kv", [True, False]) +@pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("group_size", [1, 4, 8]) @@ -70,6 +71,7 @@ def test_batch_decode_tensor_cores( kv_len: int, fixed_split_size: int, disable_split_kv: bool, + fixed_cta_tile_q: int, page_size: int, num_kv_heads: int, group_size: int, @@ -77,6 +79,8 @@ def test_batch_decode_tensor_cores( kv_layout: str, pos_encoding_mode: str, ): + if head_dim >= 256 and fixed_cta_tile_q == 128: + pytest.skip("fixed_cta_tile_q=128 is not supported with head_dim >= 256") num_qo_heads = num_kv_heads * group_size q = torch.randn( batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 @@ -133,6 +137,7 @@ def test_batch_decode_tensor_cores( q_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, + fixed_cta_tile_q=fixed_cta_tile_q, ) o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( q, kv_data, return_lse=True @@ -153,6 +158,7 @@ def test_batch_decode_tensor_cores( q_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, + fixed_cta_tile_q=fixed_cta_tile_q, ) o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run( q[:invariant_bs], kv_data, return_lse=True @@ -195,6 +201,7 @@ def test_batch_decode_tensor_cores( @pytest.mark.parametrize("qo_len", [128, 256]) @pytest.mark.parametrize("fixed_split_size", [2048]) @pytest.mark.parametrize("disable_split_kv", [True, False]) +@pytest.mark.parametrize("fixed_cta_tile_q", [16, 64, 128]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("group_size", [1, 4, 8]) @@ -208,6 +215,7 @@ def test_batch_prefill_tensor_cores( qo_len: int, fixed_split_size: int, disable_split_kv: bool, + fixed_cta_tile_q: int, page_size: int, num_kv_heads: int, group_size: int, @@ -215,6 +223,8 @@ def test_batch_prefill_tensor_cores( kv_layout: str, pos_encoding_mode: str, ): + if head_dim >= 256 and fixed_cta_tile_q == 128: + pytest.skip("fixed_cta_tile_q=128 is not supported with head_dim >= 256") num_qo_heads = num_kv_heads * group_size q = torch.randn( batch_size * qo_len, @@ -281,6 +291,7 @@ def test_batch_prefill_tensor_cores( kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, + fixed_cta_tile_q=fixed_cta_tile_q, ) o_tensor_cores, lse_tensor_cores = wrapper_tensor_cores.run( q, kv_data, return_lse=True @@ -304,6 +315,7 @@ def test_batch_prefill_tensor_cores( kv_data_type=torch.float16, fixed_split_size=fixed_split_size if not disable_split_kv else None, disable_split_kv=disable_split_kv, + fixed_cta_tile_q=fixed_cta_tile_q, ) o_tensor_cores_invariant, lse_tensor_cores_invariant = wrapper_tensor_cores.run( q[: invariant_bs * qo_len], kv_data, return_lse=True diff --git a/tests/attention/test_batch_prefill.py b/tests/attention/test_batch_prefill.py index 61211b629e..fd1aaa5775 100644 --- a/tests/attention/test_batch_prefill.py +++ b/tests/attention/test_batch_prefill.py @@ -118,3 +118,137 @@ def test_kv_scale_forwarding_math_property(dtype: torch.dtype): ) out3_ref, _ = wrapper.forward_return_lse(q * k_scale, paged_kv_cache) torch.testing.assert_close(out3, out3_ref * v_scale, rtol=1e-2, atol=1e-3) + + +def test_batch_prefill_invalid_fixed_cta_tile_q(): + batch_size = 2 + qo_len = 8 + kv_len = 128 + page_size = 16 + num_kv_heads = 2 + group_size = 2 + num_qo_heads = num_kv_heads * group_size + head_dim = 64 + + q_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + ) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + with pytest.raises(ValueError, match="fixed_cta_tile_q should be one of"): + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + kv_data_type=torch.float16, + fixed_split_size=2048, + disable_split_kv=True, + fixed_cta_tile_q=32, + ) + + +def test_batch_prefill_fixed_cta_tile_q_incompatible_head_dim(): + batch_size = 2 + qo_len = 8 + kv_len = 128 + page_size = 16 + num_kv_heads = 2 + group_size = 2 + num_qo_heads = num_kv_heads * group_size + head_dim = 256 # fixed_cta_tile_q=128 is invalid for head_dim >= 256 + + q_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + ) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + with pytest.raises( + ValueError, match="fixed_cta_tile_q=128 is not supported with head_dim" + ): + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + kv_data_type=torch.float16, + fixed_split_size=2048, + disable_split_kv=True, + fixed_cta_tile_q=128, + ) + + +def test_batch_prefill_fixed_cta_tile_q_rejected_for_non_fa2_backend(): + """fixed_cta_tile_q must raise when the resolved backend is not fa2.""" + batch_size, qo_len, kv_len, page_size, num_kv_heads, head_dim = ( + 2, + 64, + 512, + 16, + 4, + 128, + ) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + q_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len + ) + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), page_size, dtype=torch.int32, device="cuda:0" + ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD") + wrapper._backend = "fa3" + with pytest.raises( + ValueError, match="fixed_cta_tile_q is only supported for the fa2 backend" + ): + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + num_kv_heads, + num_kv_heads, + head_dim, + page_size, + q_data_type=torch.float16, + kv_data_type=torch.float16, + fixed_cta_tile_q=64, + ) diff --git a/tests/attention/test_tensor_cores_decode.py b/tests/attention/test_tensor_cores_decode.py index 19db15a640..2eeb3b986a 100644 --- a/tests/attention/test_tensor_cores_decode.py +++ b/tests/attention/test_tensor_cores_decode.py @@ -651,6 +651,74 @@ def test_batch_fast_decode_tensor_cores_cuda_graph( torch.testing.assert_close(lse, lse_tensor_cores, rtol=1e-3, atol=1e-3) +def test_batch_decode_fixed_cta_tile_q_rejected_without_tensor_cores(): + """fixed_cta_tile_q must raise when use_tensor_cores=False.""" + batch_size, kv_len, page_size, num_kv_heads, head_dim = 2, 512, 16, 4, 128 + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), page_size, dtype=torch.int32, device="cuda:0" + ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, use_tensor_cores=False + ) + with pytest.raises( + ValueError, match="fixed_cta_tile_q is only supported by tensor core decode" + ): + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + num_kv_heads, + num_kv_heads, + head_dim, + page_size, + data_type=torch.float16, + fixed_cta_tile_q=64, + ) + + +def test_batch_decode_fixed_cta_tile_q_rejected_for_non_fa2_backend(): + """fixed_cta_tile_q must raise when the resolved tensor-core backend is not fa2.""" + batch_size, kv_len, page_size, num_kv_heads, head_dim = 2, 512, 16, 4, 128 + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_indptr = ( + torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), page_size, dtype=torch.int32, device="cuda:0" + ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, use_tensor_cores=True + ) + # Patch _backend to simulate a resolved non-fa2 tensor-core backend without needing fa3 hardware. + wrapper._backend = "fa3" + with pytest.raises( + ValueError, match="fixed_cta_tile_q is only supported for the fa2 backend" + ): + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + num_kv_heads, + num_kv_heads, + head_dim, + page_size, + data_type=torch.float16, + fixed_cta_tile_q=64, + ) + + if __name__ == "__main__": test_batch_decode_tensor_cores_with_fast_plan( 5, 4, 4096, 2048, True, 1, 4, 1, 128, "HND", "NONE"