diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index c77d8643e2..a2bfdf6727 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -86,7 +86,8 @@ void trtllm_paged_attention_launcher( const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale, int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q, int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax, - int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) { + bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size, + cudaStream_t stream) { if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads @@ -145,6 +146,7 @@ void trtllm_paged_attention_launcher( window_left == -1 ? INT_MAX : window_left + 1; // disable window attention by INT_MAX runner_params.mMaxSeqLenQ = max_q_len; runner_params.mSumOfSeqLensQ = sum_seq_q; + runner_params.mUsesSharedPagedKvIdx = uses_shared_paged_kv_idx; runner_params.ptrAttentionSinks = attention_sinks; runner_params.enable_pdl = enable_pdl; @@ -236,7 +238,8 @@ void trtllm_paged_attention_decode( int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks, Optional cum_seq_lens_q, Optional key_block_scales, - Optional value_block_scales, Optional skip_softmax_threshold_scale_factor) { + Optional value_block_scales, Optional skip_softmax_threshold_scale_factor, + Optional uses_shared_paged_kv_idx) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); @@ -269,6 +272,10 @@ void trtllm_paged_attention_decode( bool is_fp4_kv = is_4bit(kv_data_type); int stride_idx_factor = is_fp4_kv ? 2 : 1; + // FlashInfer/vLLM layout -> true; TRT-LLM layout -> false. + // Default to flashinfer/vLLM layout. + bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true); + // Assume HND layout after Python-side transpose: [..., H, N, D] int page_size = key_cache.size(-2); int num_kv_heads = key_cache.size(-3); @@ -337,8 +344,8 @@ void trtllm_paged_attention_decode( q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, - sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, sm_count, - enable_pdl, workspace_size, stream); + sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, + uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream); } void trtllm_paged_attention_context( @@ -350,7 +357,7 @@ void trtllm_paged_attention_context( int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks, Optional key_block_scales, Optional value_block_scales, - Optional skip_softmax_threshold_scale_factor) { + Optional skip_softmax_threshold_scale_factor, Optional uses_shared_paged_kv_idx) { auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); @@ -373,6 +380,10 @@ void trtllm_paged_attention_context( bool is_fp4_kv = is_4bit(kv_data_type); int stride_idx_factor = is_fp4_kv ? 2 : 1; + // FlashInfer/vLLM layout -> true; TRT-LLM layout -> false. + // Default to flashinfer/vLLM layout. + bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true); + // Assume HND layout after Python-side transpose: [..., H, N, D] int page_size = key_cache.size(-2); int num_kv_heads = key_cache.size(-3); @@ -444,7 +455,7 @@ void trtllm_paged_attention_context( kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax, - sm_count, enable_pdl, workspace_size, stream); + uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream); } void trtllm_ragged_attention_launcher( diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 04bb7a1112..6964f71319 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -54,6 +54,7 @@ MaskMode, PosEncodingMode, TensorLayout, + _check_block_tables_shape, _check_cached_qkv_data_type, _check_kv_layout, _check_pos_encoding_mode, @@ -1429,6 +1430,7 @@ def run( key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor, + True, # uses_shared_paged_kv_idx ] self._cached_module.paged_run(*run_args) @@ -1982,6 +1984,7 @@ def _paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> torch.Tensor: if out is None: out = torch.empty_like(query) @@ -2026,6 +2029,7 @@ def _paged_run( key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, ) return out @@ -2090,6 +2094,7 @@ def paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> None: assert maybe_lse is None assert paged_kv_cache is not None @@ -2119,6 +2124,7 @@ def paged_run( key_block_scales=key_block_scales, value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @register_fake_op(f"flashinfer::{uri}_paged_run") @@ -2161,6 +2167,7 @@ def _fake_paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> None: pass @@ -2203,6 +2210,7 @@ def trtllm_batch_decode_with_kv_cache( ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + uses_shared_paged_kv_idx: bool = True, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -2221,7 +2229,10 @@ def trtllm_batch_decode_with_kv_cache( workspace block_tables : torch.Tensor - page_table of kv cache, [batch_size, num_pages] + Page table of kv cache. + When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``. + When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]`` + where dim 1 distinguishes K (0) and V (1) page indices. seq_lens : torch.Tensor A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` @@ -2293,6 +2304,11 @@ def trtllm_batch_decode_with_kv_cache( Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length. + uses_shared_paged_kv_idx : bool = True + Whether the K and V page indices are shared as a unified index. + True (default) uses vLLM/FlashInfer layout with a 2D page table. + False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + Returns ------- out : Union[torch.Tensor, FP4Tensor] @@ -2353,6 +2369,10 @@ def trtllm_batch_decode_with_kv_cache( raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size") if max_q_len is not None or cum_seq_lens_q is not None: raise ValueError("xqa backend does not support cum_seq_lens_q") + if not uses_shared_paged_kv_idx: + raise ValueError( + "xqa backend does not support uses_shared_paged_kv_idx=False" + ) # Handle out and out_dtype if out_dtype is None: @@ -2486,6 +2506,8 @@ def trtllm_batch_decode_with_kv_cache( assert max_q_len is not None batch_size = cum_seq_lens_q.size(0) - 1 + _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) + run_func( out, out_scale_factor, @@ -2513,6 +2535,7 @@ def trtllm_batch_decode_with_kv_cache( k_block_scales, v_block_scales, skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, ) return ( diff --git a/flashinfer/mla.py b/flashinfer/mla.py index 58764bc0aa..a73d090837 100644 --- a/flashinfer/mla.py +++ b/flashinfer/mla.py @@ -25,6 +25,7 @@ from .jit.mla import gen_mla_module from .utils import ( MaskMode, + _check_block_tables_shape, check_shape_dtype_device, determine_mla_backend, device_support_pdl, @@ -134,6 +135,7 @@ def _check_trtllm_gen_mla_shape( sparse_mla_top_k: int, page_table: torch.Tensor, page_size: int, + uses_shared_paged_kv_idx: bool = True, ) -> torch.Tensor: if query.ndim != 4: raise ValueError(f"Expected query.ndim == 4, got {query.ndim}") @@ -173,7 +175,9 @@ def _check_trtllm_gen_mla_shape( f"Expected page_table.shape == (num_seqs, num_tokens, sparse_mla_top_k), got {page_table_shape}" ) else: - B_block_table, block_num = page_table.shape + _check_block_tables_shape(page_table, uses_shared_paged_kv_idx) + B_block_table = page_table.shape[0] + block_num = page_table.shape[-1] block_size = page_size if num_seqs != B_block_table: raise ValueError( @@ -603,6 +607,7 @@ def trtllm_batch_decode_with_kv_cache_mla( skip_softmax_threshold_scale_factor: Optional[float] = None, enable_pdl: bool | None = None, backend: str = "auto", + uses_shared_paged_kv_idx: bool = True, ) -> torch.Tensor: """ Parameters @@ -614,7 +619,11 @@ def trtllm_batch_decode_with_kv_cache_mla( kv_lora_rank: kv_lora_rank, must be 512 or 256 qk_rope_head_dim: qk_rope_head_dim, must be 64 sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. - block_tables: page_table of kv cache, [batch_size, num_pages] + block_tables: page table of kv cache. + When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``. + When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]`` + where dim 1 distinguishes K (0) and V (1) page indices. For MLA both rows will + typically be identical since K and V share the same compressed representation. seq_lens: query_len max_seq_len: max sequence length for kv_cache out: output tensor, if not provided, will be allocated internally @@ -633,6 +642,11 @@ def trtllm_batch_decode_with_kv_cache_mla( When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability. For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend. For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend. + uses_shared_paged_kv_idx : bool = True + Whether the K and V page indices are shared as a unified index. + True (default) uses vLLM/FlashInfer layout with a 2D page table. + False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. + False is only supported for trtllm-gen backend. Note ---- @@ -679,6 +693,10 @@ def trtllm_batch_decode_with_kv_cache_mla( ) if skip_softmax_threshold_scale_factor is not None: raise ValueError("skip_softmax is not supported for XQA backend") + if not uses_shared_paged_kv_idx: + raise ValueError( + "XQA MLA does not support separate KV page indices (uses_shared_paged_kv_idx=False)" + ) return xqa_batch_decode_with_kv_cache_mla( query, kv_cache, @@ -721,6 +739,7 @@ def trtllm_batch_decode_with_kv_cache_mla( sparse_mla_top_k, block_tables, block_size, + uses_shared_paged_kv_idx, ) if out is None: @@ -767,6 +786,7 @@ def trtllm_batch_decode_with_kv_cache_mla( None, # key_block_scales None, # value_block_scales skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, ) return out @@ -848,6 +868,7 @@ def xqa_batch_decode_with_kv_cache_mla( 0, # sparse_mla_top_k block_tables, block_size, + True, # XQA always uses shared paged KV index layout ) if out is None: diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 6dcd78305b..8aa6727ab5 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -44,6 +44,7 @@ MaskMode, PosEncodingMode, TensorLayout, + _check_block_tables_shape, _check_cached_qkv_data_type, _check_kv_layout, _check_pos_encoding_mode, @@ -264,6 +265,7 @@ def _paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> torch.Tensor: sm_count = get_device_sm_count(query.device) if out is None: @@ -300,6 +302,7 @@ def _paged_run( key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, ) return out @@ -670,6 +673,7 @@ def paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> None: if backend == "trtllm-gen": assert maybe_lse is None @@ -706,6 +710,7 @@ def paged_run( key_block_scales=key_block_scales, value_block_scales=value_block_scales, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) elif backend == "fa2": assert not is_float8(q) @@ -844,6 +849,7 @@ def _fake_paged_run( key_block_scales: Optional[torch.Tensor] = None, value_block_scales: Optional[torch.Tensor] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> None: pass @@ -2373,6 +2379,7 @@ def run( key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor, + True, # uses_shared_paged_kv_idx ] assert self._cached_module is not None, "cached module is not initialized" @@ -3696,6 +3703,7 @@ def trtllm_batch_context_with_kv_cache( Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ] = None, skip_softmax_threshold_scale_factor: Optional[float] = None, + uses_shared_paged_kv_idx: bool = True, ) -> Union[torch.Tensor, FP4Tensor]: """ Parameters @@ -3711,7 +3719,10 @@ def trtllm_batch_context_with_kv_cache( workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use. workspace block_tables : torch.Tensor - page_table of kv cache, [batch_size, num_pages] + Page table of kv cache. + When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``. + When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]`` + where dim 1 distinguishes K (0) and V (1) page indices. seq_lens : torch.Tensor A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]`` max_q_len : int @@ -3757,6 +3768,10 @@ def trtllm_batch_context_with_kv_cache( If no value is provided, then standard attention is used. Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation. The actual threshold value equals the provided threshold_scale_factor divided by the context length. + uses_shared_paged_kv_idx : bool = True + Whether the K and V page indices are shared as a unified index. + True (default) uses vLLM/FlashInfer layout with a 2D page table. + False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``. Returns ------- out: Union[torch.Tensor, FP4Tensor] @@ -3889,6 +3904,7 @@ def trtllm_batch_context_with_kv_cache( bmm1_scale = bmm1_scale * log2e if isinstance(bmm2_scale, torch.Tensor): assert bmm2_scale.dtype == torch.float32 + _check_block_tables_shape(block_tables, uses_shared_paged_kv_idx) workspace_size = workspace_buffer.numel() * workspace_buffer.element_size() run_func( out, @@ -3917,6 +3933,7 @@ def trtllm_batch_context_with_kv_cache( key_block_scales, value_block_scales, skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, ) return ( out diff --git a/flashinfer/utils.py b/flashinfer/utils.py index bbcd3f3b96..e56098976f 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -603,6 +603,31 @@ def determine_mla_backend(device: torch.device) -> str: return "fa3" if is_sm90a_supported(device) else "fa2" +def _check_block_tables_shape( + block_tables: torch.Tensor, + uses_shared_paged_kv_idx: bool, +) -> None: + """Validate ``block_tables`` rank against the paged KV index layout. + + Shared layout (``uses_shared_paged_kv_idx=True``) expects a 2-D tensor + ``[batch_size, max_num_pages_per_seq]``. Separate layout expects a 3-D + tensor ``[batch_size, 2, max_num_pages_per_seq]`` where dim1 distinguishes + K (0) and V (1) page indices. + """ + expected_ndim = 2 if uses_shared_paged_kv_idx else 3 + if block_tables.ndim != expected_ndim: + layout = "shared" if uses_shared_paged_kv_idx else "separate" + raise ValueError( + f"block_tables must be {expected_ndim}D for {layout} paged KV layout, " + f"got ndim={block_tables.ndim}" + ) + if not uses_shared_paged_kv_idx and block_tables.shape[1] != 2: + raise ValueError( + f"block_tables must have shape[1]==2 for separate KV indices, " + f"got shape={block_tables.shape}" + ) + + def check_shape_dtype_device( x: torch.Tensor, expected_shape: Optional[Sequence[int]], diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h index 41f6e541db..44c7a749c1 100644 --- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h +++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h @@ -300,6 +300,9 @@ struct TllmGenFmhaRunnerParams { bool mSparseMla; // The top k value for sparse MLA. int mSparseMlaTopK; + // Whether the indices for K & V pages are shared as unified index. + // true -> vLLM/FlashInfer; false -> TRT-LLM. + bool mUsesSharedPagedKvIdx; // The cuda stream. cudaStream_t stream; // Whether to enable PDL (Programmatic Dependent Launch). diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index 86ce33f737..2bbdb1800a 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -203,8 +203,9 @@ struct KernelParams { int32_t mSparseMlaTopK; // The flag to use block sparse attention. bool mUseBlockSparseAttention; - // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). - bool mUsesSharedPagedKvIdx; + // Whether the indices for K & V pages are shared as unified index. + // true -> vLLM/FlashInfer; false -> TRT-LLM. + bool mUsesSharedPagedKvIdx{true}; // Create the TMA shape/stride for Q. template @@ -831,7 +832,7 @@ struct KernelParams { // TODO: Integrate trtllm block-sparse attention kernels when needed. params.mUseBlockSparseAttention = false; // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer). - params.mUsesSharedPagedKvIdx = true; + params.mUsesSharedPagedKvIdx = options.mUsesSharedPagedKvIdx; return params; } }; diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py index b2e1306e43..efe5981dd3 100755 --- a/tests/attention/test_trtllm_gen_attention.py +++ b/tests/attention/test_trtllm_gen_attention.py @@ -1,4 +1,5 @@ import math +from typing import Union import pytest import torch @@ -220,6 +221,55 @@ def create_page_table(batch_size: int, seq_lens: torch.Tensor, page_size: int): return page_tables, all_page_ids, page_per_seq +def prepare_paged_kv_for_kernel( + kv_cache: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + page_table: torch.Tensor, + uses_shared_paged_kv_idx: bool, + kv_block_scales: Union[tuple[torch.Tensor, torch.Tensor], None] = None, +) -> tuple[ + Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], + torch.Tensor, + Union[tuple[torch.Tensor, torch.Tensor], None], +]: + """Convert shared-page KV cache layout to separate-page layout for TRT-LLM. + + When uses_shared_paged_kv_idx is True (FlashInfer/vLLM style), returns the + original tensors unchanged. + + When False (TRT-LLM style), interleaves K and V pages so original page p + becomes K at index 2*p and V at 2*p+1, and builds a + [batch_size, 2, maxPages] page table where dim 1 distinguishes K (0) and V (1). + Returns the reshaped cache as a (cache, cache) tuple so both K and V share + the same base pointer. Block scales, if provided, are interleaved the same + way since the kernel uses the same page indices to access them. + + Returns: + (kv_cache_arg, page_table, kv_block_scales) ready to pass to the kernel. + """ + if uses_shared_paged_kv_idx: + return kv_cache, page_table, kv_block_scales + + def _interleave_kv(k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Stack [num_pages,...] K and V along dim-1 then flatten to [2*num_pages,...].""" + return torch.stack([k, v], dim=1).reshape(k.shape[0] * 2, *k.shape[1:]) + + if isinstance(kv_cache, tuple): + k_cache, v_cache = kv_cache + interleaved = _interleave_kv(k_cache, v_cache) + else: + num_pages = kv_cache.shape[0] + interleaved = kv_cache.reshape(num_pages * 2, *kv_cache.shape[2:]) + + trtllm_page_table = torch.stack([2 * page_table, 2 * page_table + 1], dim=1) + + if kv_block_scales is not None: + k_sf, v_sf = kv_block_scales + interleaved_sf = _interleave_kv(k_sf, v_sf) + kv_block_scales = (interleaved_sf, interleaved_sf) + + return (interleaved, interleaved), trtllm_page_table, kv_block_scales + + def flatten_paged_kv( ref_kv_cache: torch.Tensor, page_table: torch.Tensor, @@ -444,6 +494,7 @@ def _test_trtllm_batch_prefill( head_dim: int, non_contiguous_query: bool = False, skips_softmax: bool = False, + uses_shared_paged_kv_idx: bool = True, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if compute_capability[0] != 10: @@ -489,6 +540,12 @@ def _test_trtllm_batch_prefill( kv_indptr = generate_cumsum_lens(page_per_seq) kv_last_page_len = get_last_page_len(seq_lens, page_size) + kv_cache_kernel, page_table_kernel, kv_block_scales_kernel = ( + prepare_paged_kv_for_kernel( + kv_cache, page_table, uses_shared_paged_kv_idx, kv_block_scales + ) + ) + workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE) # Create output tensor and related data @@ -576,9 +633,9 @@ def _test_trtllm_batch_prefill( output = flashinfer.prefill.trtllm_batch_context_with_kv_cache( q_input, - kv_cache, + kv_cache_kernel, workspace_buffer, - page_table, + page_table_kernel, seq_lens.to(GPU_DEVICE), torch.max(q_lens).item(), torch.max(seq_lens).item(), @@ -595,8 +652,9 @@ def _test_trtllm_batch_prefill( kv_layout=kv_layout, enable_pdl=enable_pdl, sinks=(sink if enable_sink else None), - kv_block_scales=kv_block_scales, + kv_block_scales=kv_block_scales_kernel, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -635,8 +693,8 @@ def _test_trtllm_batch_prefill( ) if ( - o_dtype != "nvfp4" and kv_dtype != "nvfp4" - ): # wrapper api does not support fp4 output/kv yet. + o_dtype != "nvfp4" and kv_dtype != "nvfp4" and uses_shared_paged_kv_idx + ): # wrapper api does not support fp4 output/kv or separate KV page indices yet. # test wrapper with trtllm-gen backend wrapper_trtllm_gen = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="trtllm-gen" @@ -701,6 +759,7 @@ def _test_trtllm_batch_prefill( @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("non_contiguous_query", [False, True]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_prefill( kv_layout: str, batch_size: int, @@ -718,6 +777,7 @@ def test_trtllm_batch_prefill( head_dim: int, non_contiguous_query: bool, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ): _test_trtllm_batch_prefill( kv_layout, @@ -737,6 +797,7 @@ def test_trtllm_batch_prefill( head_dim, non_contiguous_query=non_contiguous_query, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -760,6 +821,7 @@ def test_trtllm_batch_prefill( @pytest.mark.parametrize("max_kv_len", [8192]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_prefill_bs1( kv_layout: str, batch_size: int, @@ -776,6 +838,7 @@ def test_trtllm_batch_prefill_bs1( max_kv_len: int, head_dim: int, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ): _test_trtllm_batch_prefill( kv_layout, @@ -794,6 +857,7 @@ def test_trtllm_batch_prefill_bs1( False, head_dim, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -817,6 +881,7 @@ def _test_trtllm_batch_decode( max_q_len: int | None = None, non_contiguous_query: bool = False, skips_softmax: bool = False, + uses_shared_paged_kv_idx: bool = True, ) -> None: """ Common function for testing trtllm-gen decode. @@ -846,6 +911,10 @@ def _test_trtllm_batch_decode( if backend == "xqa" and q_dtype == "fp8": pytest.skip("xqa backend only supports fp16 and bf16 query") + # XQA backend doesn't support non-shared page indices + if backend == "xqa" and not uses_shared_paged_kv_idx: + pytest.skip("xqa backend does not support non-shared page indices") + if o_dtype == "nvfp4" and ( q_len_per_req is not None and q_len_per_req > 1 @@ -898,6 +967,12 @@ def _test_trtllm_batch_decode( kv_indptr = generate_cumsum_lens(page_per_seq) kv_last_page_len = get_last_page_len(seq_lens, page_size) + kv_cache_arg, page_table_kernel, kv_block_scales_kernel = ( + prepare_paged_kv_for_kernel( + kv_cache, page_table, uses_shared_paged_kv_idx, kv_block_scales + ) + ) + workspace_buffer, workspace_buffer_ref = create_workspace_buffers(GPU_DEVICE) # Create output tensor and related data @@ -1009,9 +1084,9 @@ def _test_trtllm_batch_decode( output = flashinfer.decode.trtllm_batch_decode_with_kv_cache( q_input, - kv_cache, + kv_cache_arg, workspace_buffer, - page_table, + page_table_kernel, seq_lens.to(GPU_DEVICE), torch.max(seq_lens).item(), bmm1_scale, @@ -1031,7 +1106,8 @@ def _test_trtllm_batch_decode( max_q_len=max_q_len if max_q_len is not None else None, cum_seq_lens_q=q_indptr if max_q_len is not None else None, skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, - kv_block_scales=kv_block_scales, + kv_block_scales=kv_block_scales_kernel, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) if backend == "trtllm-gen": # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero @@ -1085,7 +1161,8 @@ def _test_trtllm_batch_decode( and backend == "trtllm-gen" and q_len_per_req is not None # only test for the case all requests have the same q_len - ): # wrapper api does not support fp4 output/kv yet. + and uses_shared_paged_kv_idx + ): # wrapper api does not support fp4 output/kv or separate KV page indices yet. # test wrapper with trtllm-gen backend wrapper_trtllm_gen = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout, backend="trtllm-gen" @@ -1184,6 +1261,7 @@ def _test_trtllm_batch_decode( @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("non_contiguous_query", [False, True]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_decode( backend: str, kv_layout: str, @@ -1202,6 +1280,7 @@ def test_trtllm_batch_decode( head_dim: int, non_contiguous_query: bool, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ): # xqa backend does not support non-contiguous query yet if backend == "xqa" and non_contiguous_query: @@ -1227,6 +1306,7 @@ def test_trtllm_batch_decode( kv_dtype in ("fp8", "nvfp4"), non_contiguous_query=non_contiguous_query, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -1251,6 +1331,7 @@ def test_trtllm_batch_decode( @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("device_scale", [True, False]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_decode_bs1( kv_layout: str, batch_size: int, @@ -1268,6 +1349,7 @@ def test_trtllm_batch_decode_bs1( head_dim: int, device_scale: bool, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ) -> None: # Small number of test cases for batch size 1 _test_trtllm_batch_decode( @@ -1288,6 +1370,7 @@ def test_trtllm_batch_decode_bs1( head_dim, device_scale, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -1323,6 +1406,7 @@ def test_trtllm_batch_decode_bs1( @pytest.mark.parametrize("head_dim", [256]) @pytest.mark.parametrize("device_scale", [True, False]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_decode_head_dim_256( kv_layout: str, batch_size: int, @@ -1340,6 +1424,7 @@ def test_trtllm_batch_decode_head_dim_256( head_dim: int, device_scale: bool, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ): # Small number of test cases for head_dim = 256 _test_trtllm_batch_decode( @@ -1360,6 +1445,7 @@ def test_trtllm_batch_decode_head_dim_256( head_dim, device_scale, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -1390,6 +1476,7 @@ def test_trtllm_batch_decode_head_dim_256( @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("device_scale", [True, False]) @pytest.mark.parametrize("skips_softmax", [False]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_decode_long_sequence_length( kv_layout: str, batch_size: int, @@ -1407,6 +1494,7 @@ def test_trtllm_batch_decode_long_sequence_length( head_dim: int, device_scale: bool, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ) -> None: # Small number of test cases for long sequence length _test_trtllm_batch_decode( @@ -1427,6 +1515,7 @@ def test_trtllm_batch_decode_long_sequence_length( head_dim, device_scale, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) @@ -1674,6 +1763,7 @@ def make_query_non_contiguous( @pytest.mark.parametrize("max_in_kv_len", [110]) @pytest.mark.parametrize("head_dim", [128]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [False, True]) def test_trtllm_batch_decode_spec( backend: str, kv_layout: str, @@ -1691,6 +1781,7 @@ def test_trtllm_batch_decode_spec( max_in_kv_len: int, head_dim: int, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ) -> None: _test_trtllm_batch_decode( backend, @@ -1710,4 +1801,5 @@ def test_trtllm_batch_decode_spec( head_dim, max_q_len=max_q_len, skips_softmax=skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 8a1628d2a9..a6e8253049 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -226,6 +226,7 @@ def trtllm_batch_decode_mla( backend: str, MAX_SEQ_LEN: int, skips_softmax: bool, + uses_shared_paged_kv_idx: bool = True, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if backend == "xqa": @@ -235,6 +236,8 @@ def trtllm_batch_decode_mla( pytest.skip( "XQA MLA only supports q_len_per_request == 1 and dtype == torch.float8_e4m3fn" ) + if not uses_shared_paged_kv_idx: + pytest.skip("xqa backend does not support separate KV page indices") if backend == "trtllm-gen": if compute_capability[0] != 10: pytest.skip("TRTLLM-GEN MLA only supports SM100 and SM103 GPUs") @@ -292,6 +295,13 @@ def trtllm_batch_decode_mla( ] block_id += num_blocks_needed + # For separate KV page indices, duplicate the page table rows since + # MLA K and V share the same compressed representation. + if not uses_shared_paged_kv_idx: + block_tables_kernel = torch.stack([block_tables, block_tables], dim=1) + else: + block_tables_kernel = block_tables + # Create interleaved KV cache # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases kv_cache = torch.randn( @@ -329,7 +339,7 @@ def trtllm_batch_decode_mla( qk_nope_head_dim=layer_dimensions.head_dimensions.qk_nope_head_dim, kv_lora_rank=layer_dimensions.head_dimensions.kv_lora_rank, qk_rope_head_dim=layer_dimensions.head_dimensions.qk_rope_head_dim, - block_tables=block_tables, + block_tables=block_tables_kernel, seq_lens=seq_lens_tensor, max_seq_len=max_seq_len, bmm1_scale=scale / ((128 + 64) ** 0.5), @@ -337,6 +347,7 @@ def trtllm_batch_decode_mla( skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor, enable_pdl=enable_pdl, backend=backend, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, ) # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future @@ -714,6 +725,7 @@ def trtllm_batch_decode_mla_sparse( @pytest.mark.parametrize("enable_pdl", [True, False, None]) @pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"]) @pytest.mark.parametrize("skips_softmax", [False, True]) +@pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) def test_trtllm_batch_decode_mla( layer_dimensions: MLALayerDimensions, batch_size: int, @@ -725,6 +737,7 @@ def test_trtllm_batch_decode_mla( enable_pdl: bool, backend: str, skips_softmax: bool, + uses_shared_paged_kv_idx: bool, ): if backend == "xqa" and layer_dimensions.head_dimensions == smaller_mla_dimensions: pytest.skip("XQA MLA does not support smaller MLA dimensions yet.") @@ -743,6 +756,7 @@ def test_trtllm_batch_decode_mla( backend, 1024, skips_softmax, + uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, )