Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ void trtllm_paged_attention_launcher(
const float* bmm1_scale_log2_ptr, const float* bmm2_scale_ptr, double o_sf_scale,
int64_t o_sf_vec_size, int64_t o_sf_start_index, int64_t window_left, int64_t sum_seq_q,
int64_t sparse_mla_top_k, float skip_softmax_threshold_scale_factor, bool skips_softmax,
int64_t sm_count, bool enable_pdl, int64_t workspace_size, cudaStream_t stream) {
bool uses_shared_paged_kv_idx, int64_t sm_count, bool enable_pdl, int64_t workspace_size,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads must be a multiple of num_kv_heads, got num_kv_heads: " << num_kv_heads
Expand Down Expand Up @@ -145,6 +146,7 @@ void trtllm_paged_attention_launcher(
window_left == -1 ? INT_MAX : window_left + 1; // disable window attention by INT_MAX
runner_params.mMaxSeqLenQ = max_q_len;
runner_params.mSumOfSeqLensQ = sum_seq_q;
runner_params.mUsesSharedPagedKvIdx = uses_shared_paged_kv_idx;
runner_params.ptrAttentionSinks = attention_sinks;
runner_params.enable_pdl = enable_pdl;

Expand Down Expand Up @@ -236,7 +238,8 @@ void trtllm_paged_attention_decode(
int64_t window_left, int64_t sparse_mla_top_k, int64_t sm_count, bool enable_pdl,
int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> cum_seq_lens_q, Optional<TensorView> key_block_scales,
Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor) {
Optional<TensorView> value_block_scales, Optional<float> skip_softmax_threshold_scale_factor,
Optional<bool> uses_shared_paged_kv_idx) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim());
Expand Down Expand Up @@ -269,6 +272,10 @@ void trtllm_paged_attention_decode(
bool is_fp4_kv = is_4bit(kv_data_type);
int stride_idx_factor = is_fp4_kv ? 2 : 1;

// FlashInfer/vLLM layout -> true; TRT-LLM layout -> false.
// Default to flashinfer/vLLM layout.
bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true);

// Assume HND layout after Python-side transpose: [..., H, N, D]
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);
Expand Down Expand Up @@ -337,8 +344,8 @@ void trtllm_paged_attention_decode(
q_stride_heads, kv_stride_keys_values, kv_stride_heads, kv_stride_batch,
max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value, bmm1_scale_log2_ptr,
bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left, sum_seq_q,
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax, sm_count,
enable_pdl, workspace_size, stream);
sparse_mla_top_k, skip_softmax_threshold_scale_factor_value, skips_softmax,
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_paged_attention_context(
Expand All @@ -350,7 +357,7 @@ void trtllm_paged_attention_context(
int64_t window_left, TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count,
bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
Optional<TensorView> key_block_scales, Optional<TensorView> value_block_scales,
Optional<float> skip_softmax_threshold_scale_factor) {
Optional<float> skip_softmax_threshold_scale_factor, Optional<bool> uses_shared_paged_kv_idx) {
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype());
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
Expand All @@ -373,6 +380,10 @@ void trtllm_paged_attention_context(
bool is_fp4_kv = is_4bit(kv_data_type);
int stride_idx_factor = is_fp4_kv ? 2 : 1;

// FlashInfer/vLLM layout -> true; TRT-LLM layout -> false.
// Default to flashinfer/vLLM layout.
bool const uses_shared_paged_kv_idx_value = uses_shared_paged_kv_idx.value_or(true);

// Assume HND layout after Python-side transpose: [..., H, N, D]
int page_size = key_cache.size(-2);
int num_kv_heads = key_cache.size(-3);
Expand Down Expand Up @@ -444,7 +455,7 @@ void trtllm_paged_attention_context(
kv_stride_heads, kv_stride_batch, max_num_blocks_per_seq, bmm1_scale_value, bmm2_scale_value,
bmm1_scale_log2_ptr, bmm2_scale_ptr, o_sf_scale, o_sf_vec_size, o_sf_start_index, window_left,
sum_seq_q, /*sparse_mla_top_k=*/0, skip_softmax_threshold_scale_factor_value, skips_softmax,
sm_count, enable_pdl, workspace_size, stream);
uses_shared_paged_kv_idx_value, sm_count, enable_pdl, workspace_size, stream);
}

void trtllm_ragged_attention_launcher(
Expand Down
25 changes: 24 additions & 1 deletion flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
MaskMode,
PosEncodingMode,
TensorLayout,
_check_block_tables_shape,
_check_cached_qkv_data_type,
_check_kv_layout,
_check_pos_encoding_mode,
Expand Down Expand Up @@ -1429,6 +1430,7 @@ def run(
key_block_scales,
value_block_scales,
skip_softmax_threshold_scale_factor,
True, # uses_shared_paged_kv_idx
Comment thread
DomBrown marked this conversation as resolved.
]

self._cached_module.paged_run(*run_args)
Expand Down Expand Up @@ -1982,6 +1984,7 @@ def _paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
if out is None:
out = torch.empty_like(query)
Expand Down Expand Up @@ -2026,6 +2029,7 @@ def _paged_run(
key_block_scales,
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
)
return out

Expand Down Expand Up @@ -2090,6 +2094,7 @@ def paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
assert maybe_lse is None
assert paged_kv_cache is not None
Expand Down Expand Up @@ -2119,6 +2124,7 @@ def paged_run(
key_block_scales=key_block_scales,
value_block_scales=value_block_scales,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
)

@register_fake_op(f"flashinfer::{uri}_paged_run")
Expand Down Expand Up @@ -2161,6 +2167,7 @@ def _fake_paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
pass

Expand Down Expand Up @@ -2203,6 +2210,7 @@ def trtllm_batch_decode_with_kv_cache(
] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
kv_cache_sf: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
uses_shared_paged_kv_idx: bool = True,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
) -> Union[torch.Tensor, FP4Tensor]:
"""
Parameters
Expand All @@ -2221,7 +2229,10 @@ def trtllm_batch_decode_with_kv_cache(
workspace

block_tables : torch.Tensor
page_table of kv cache, [batch_size, num_pages]
Page table of kv cache.
When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``.
When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]``
where dim 1 distinguishes K (0) and V (1) page indices.
Comment thread
DomBrown marked this conversation as resolved.

seq_lens : torch.Tensor
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
Expand Down Expand Up @@ -2293,6 +2304,11 @@ def trtllm_batch_decode_with_kv_cache(
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.

uses_shared_paged_kv_idx : bool = True
Whether the K and V page indices are shared as a unified index.
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.

Returns
-------
out : Union[torch.Tensor, FP4Tensor]
Expand Down Expand Up @@ -2353,6 +2369,10 @@ def trtllm_batch_decode_with_kv_cache(
raise ValueError("xqa backend does not support o_sf_scale or o_sf_vec_size")
if max_q_len is not None or cum_seq_lens_q is not None:
raise ValueError("xqa backend does not support cum_seq_lens_q")
if not uses_shared_paged_kv_idx:
raise ValueError(
"xqa backend does not support uses_shared_paged_kv_idx=False"
)

# Handle out and out_dtype
if out_dtype is None:
Expand Down Expand Up @@ -2486,6 +2506,8 @@ def trtllm_batch_decode_with_kv_cache(
assert max_q_len is not None
batch_size = cum_seq_lens_q.size(0) - 1

_check_block_tables_shape(block_tables, uses_shared_paged_kv_idx)

run_func(
out,
out_scale_factor,
Expand Down Expand Up @@ -2513,6 +2535,7 @@ def trtllm_batch_decode_with_kv_cache(
k_block_scales,
v_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
)

return (
Expand Down
25 changes: 23 additions & 2 deletions flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .jit.mla import gen_mla_module
from .utils import (
MaskMode,
_check_block_tables_shape,
check_shape_dtype_device,
determine_mla_backend,
device_support_pdl,
Expand Down Expand Up @@ -134,6 +135,7 @@ def _check_trtllm_gen_mla_shape(
sparse_mla_top_k: int,
page_table: torch.Tensor,
page_size: int,
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
if query.ndim != 4:
raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
Expand Down Expand Up @@ -173,7 +175,9 @@ def _check_trtllm_gen_mla_shape(
f"Expected page_table.shape == (num_seqs, num_tokens, sparse_mla_top_k), got {page_table_shape}"
)
else:
B_block_table, block_num = page_table.shape
_check_block_tables_shape(page_table, uses_shared_paged_kv_idx)
B_block_table = page_table.shape[0]
block_num = page_table.shape[-1]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
block_size = page_size
if num_seqs != B_block_table:
raise ValueError(
Expand Down Expand Up @@ -603,6 +607,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
skip_softmax_threshold_scale_factor: Optional[float] = None,
enable_pdl: bool | None = None,
backend: str = "auto",
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
"""
Parameters
Expand All @@ -614,7 +619,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
kv_lora_rank: kv_lora_rank, must be 512 or 256
qk_rope_head_dim: qk_rope_head_dim, must be 64
sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
block_tables: page_table of kv cache, [batch_size, num_pages]
block_tables: page table of kv cache.
When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``.
When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]``
where dim 1 distinguishes K (0) and V (1) page indices. For MLA both rows will
typically be identical since K and V share the same compressed representation.
seq_lens: query_len
max_seq_len: max sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
Expand All @@ -633,6 +642,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
For sm_100 and sm_103 (blackwell architecture), ``auto`` will choose ``trtllm-gen`` backend.
For sm_120 (blackwell architecture), ``auto`` will choose ``xqa`` backend.
uses_shared_paged_kv_idx : bool = True
Whether the K and V page indices are shared as a unified index.
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.
False is only supported for trtllm-gen backend.

Note
----
Expand Down Expand Up @@ -679,6 +693,10 @@ def trtllm_batch_decode_with_kv_cache_mla(
)
if skip_softmax_threshold_scale_factor is not None:
raise ValueError("skip_softmax is not supported for XQA backend")
if not uses_shared_paged_kv_idx:
raise ValueError(
"XQA MLA does not support separate KV page indices (uses_shared_paged_kv_idx=False)"
)
return xqa_batch_decode_with_kv_cache_mla(
query,
kv_cache,
Expand Down Expand Up @@ -721,6 +739,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
sparse_mla_top_k,
block_tables,
block_size,
uses_shared_paged_kv_idx,
)

if out is None:
Expand Down Expand Up @@ -767,6 +786,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
None, # key_block_scales
None, # value_block_scales
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
)

return out
Expand Down Expand Up @@ -848,6 +868,7 @@ def xqa_batch_decode_with_kv_cache_mla(
0, # sparse_mla_top_k
block_tables,
block_size,
True, # XQA always uses shared paged KV index layout
)

if out is None:
Expand Down
19 changes: 18 additions & 1 deletion flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
MaskMode,
PosEncodingMode,
TensorLayout,
_check_block_tables_shape,
_check_cached_qkv_data_type,
_check_kv_layout,
_check_pos_encoding_mode,
Expand Down Expand Up @@ -264,6 +265,7 @@ def _paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> torch.Tensor:
sm_count = get_device_sm_count(query.device)
if out is None:
Expand Down Expand Up @@ -300,6 +302,7 @@ def _paged_run(
key_block_scales,
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
)
return out

Expand Down Expand Up @@ -670,6 +673,7 @@ def paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
if backend == "trtllm-gen":
assert maybe_lse is None
Expand Down Expand Up @@ -706,6 +710,7 @@ def paged_run(
key_block_scales=key_block_scales,
value_block_scales=value_block_scales,
skip_softmax_threshold_scale_factor=skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx=uses_shared_paged_kv_idx,
)
elif backend == "fa2":
assert not is_float8(q)
Expand Down Expand Up @@ -844,6 +849,7 @@ def _fake_paged_run(
key_block_scales: Optional[torch.Tensor] = None,
value_block_scales: Optional[torch.Tensor] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> None:
pass

Expand Down Expand Up @@ -2373,6 +2379,7 @@ def run(
key_block_scales,
value_block_scales,
skip_softmax_threshold_scale_factor,
True, # uses_shared_paged_kv_idx
]

assert self._cached_module is not None, "cached module is not initialized"
Expand Down Expand Up @@ -3696,6 +3703,7 @@ def trtllm_batch_context_with_kv_cache(
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
uses_shared_paged_kv_idx: bool = True,
) -> Union[torch.Tensor, FP4Tensor]:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""
Parameters
Expand All @@ -3711,7 +3719,10 @@ def trtllm_batch_context_with_kv_cache(
workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
workspace
block_tables : torch.Tensor
page_table of kv cache, [batch_size, num_pages]
Page table of kv cache.
When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``.
When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]``
where dim 1 distinguishes K (0) and V (1) page indices.
seq_lens : torch.Tensor
A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``
max_q_len : int
Expand Down Expand Up @@ -3757,6 +3768,10 @@ def trtllm_batch_context_with_kv_cache(
If no value is provided, then standard attention is used.
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.
uses_shared_paged_kv_idx : bool = True
Whether the K and V page indices are shared as a unified index.
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.
Returns
-------
out: Union[torch.Tensor, FP4Tensor]
Expand Down Expand Up @@ -3889,6 +3904,7 @@ def trtllm_batch_context_with_kv_cache(
bmm1_scale = bmm1_scale * log2e
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
_check_block_tables_shape(block_tables, uses_shared_paged_kv_idx)
workspace_size = workspace_buffer.numel() * workspace_buffer.element_size()
run_func(
out,
Expand Down Expand Up @@ -3917,6 +3933,7 @@ def trtllm_batch_context_with_kv_cache(
key_block_scales,
value_block_scales,
skip_softmax_threshold_scale_factor,
uses_shared_paged_kv_idx,
)
return (
out
Expand Down
Loading
Loading