Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Benchmark MLA paged-attention decode on the trtllm-native backend (bf16).
# head_dim_ckv=256 + head_dim_kpe=64 gives an absorb-mode qk head dim of 320
# and output dim 256 — the reduced MLA geometry (vs. DeepSeek's 512+64=576),
# which this PR's kernel-launcher checks newly allow. s_qo=1 => pure decode.
bench-mla-decode:
python benchmarks/flashinfer_benchmark.py \
--routine BatchMLAPagedAttentionWrapper \
--batch_size 1024 \
--s_kv 8192 \
--num_qo_heads 32 \
--num_kv_heads 1 \
--head_dim_ckv 256 \
--head_dim_kpe 64 \
--page_size 64 \
--backends trtllm-native \
--q_dtype bfloat16 \
--kv_dtype bfloat16 \
--s_qo 1 \
--num_iters 500

# Benchmark ragged-KV prefill on the trtllm-native backend (bf16).
# Uses head_dim_qk == head_dim_vo == 128, the new head-dim combination this
# PR enables for trtllm-native (previously only 192/128 was accepted).
# NOTE(review): target is named "mla" but runs the ragged prefill routine
# with equal q/kv head counts — confirm the name reflects the intent.
bench-mla-prefill-bf16:
python benchmarks/flashinfer_benchmark.py \
--routine BatchPrefillWithRaggedKVCacheWrapper \
--batch_size 2 \
--s_kv 8192 \
--s_qo 8192 \
--num_qo_heads 128 \
--num_kv_heads 128 \
--head_dim_qk 128 \
--head_dim_vo 128 \
--page_size 64 \
--backends trtllm-native \
--q_dtype bfloat16 \
--kv_dtype bfloat16 \
--num_iters 100

# Same shapes as bench-mla-prefill-bf16, but with FP8 (e4m3) query/KV inputs.
# Relies on this PR removing the blanket FP8 skip for trtllm-native; the
# kernel emits bf16 output for FP8 inputs (see prefill.py out_dtype handling).
bench-mla-prefill-fp8:
python benchmarks/flashinfer_benchmark.py \
--routine BatchPrefillWithRaggedKVCacheWrapper \
--batch_size 2 \
--s_kv 8192 \
--s_qo 8192 \
--num_qo_heads 128 \
--num_kv_heads 128 \
--head_dim_qk 128 \
--head_dim_vo 128 \
--page_size 64 \
--backends trtllm-native \
--q_dtype fp8_e4m3 \
--kv_dtype fp8_e4m3 \
--num_iters 100
24 changes: 13 additions & 11 deletions benchmarks/routines/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ def run_backend_wrapper(
def testBatchPrefillWithRaggedKVCacheWrapper(args):
"""
Test BatchPrefillWithRaggedKVCacheWrapper API and equivalent cuDNN API.
Supports fa2, fa3, cutlass, and cudnn backends.
Supports fa2, fa3, cutlass, cudnn and trtllm-gen backends.

This test:
1. Creates ragged KV cache and query tensors for prefill
Expand Down Expand Up @@ -1460,15 +1460,11 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
backends.remove("trtllm-gen")
if "trtllm-native" in backends:
remove_trtllm_native = False
if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] trtllm-native backend does not support FP8. Skipping.")
remove_trtllm_native = True
if not (head_dim_qk == 192 and head_dim_vo == 128):
if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
head_dim_qk == 128 and head_dim_vo == 128
):
print(
"[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
"[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128 or head_dim_qk == 128 and head_dim_vo == 128. Skipping."
)
remove_trtllm_native = True
if remove_trtllm_native:
Expand Down Expand Up @@ -1733,6 +1729,12 @@ def run_backend_wrapper(
is_cuda_graph_compatible=True,
)[0]
elif backend == "trtllm-native":
# For FP8: bmm1_scale = q_scale * k_scale * sm_scale,
# bmm2_scale = v_scale
_k_scale = k_scale if k_scale is not None else 1.0
_v_scale = v_scale if v_scale is not None else 1.0
_bmm1_scale = scale * _k_scale
_bmm2_scale = _v_scale
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
query=q,
key=k,
Expand All @@ -1741,8 +1743,8 @@ def run_backend_wrapper(
seq_lens=actual_seq_lens_kv_device,
max_q_len=s_qo,
max_kv_len=s_kv,
bmm1_scale=scale,
bmm2_scale=1.0,
bmm1_scale=_bmm1_scale,
bmm2_scale=_bmm2_scale,
o_sf_scale=-1,
batch_size=batch_size,
window_left=-1,
Expand Down
4 changes: 2 additions & 2 deletions csrc/trtllm_fmha_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void trtllm_paged_attention_launcher(
// The sparse MLA parameters.
runner_params.mSparseMla = sparse_mla_top_k > 0;
runner_params.mSparseMlaTopK = sparse_mla_top_k;
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || (head_dim_qk == 320 && head_dim_vo == 256) || sparse_mla_top_k <= 0)
<< "Only decode MLA supports sparse MLA";

AlignedAllocator float_allocator(workspace_buffer, workspace_size);
Expand Down Expand Up @@ -251,7 +251,7 @@ void trtllm_paged_attention_decode(
TVM_FFI_ICHECK_EQ(head_dim_k, head_dim_q)
<< "head_dim_k and head_dim_q must be the same, got " << std::to_string(head_dim_k) << " and "
<< std::to_string(head_dim_q);
TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || head_dim_v == head_dim_o)
TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || (head_dim_v == 320 && head_dim_o == 256) || head_dim_v == head_dim_o)
<< "head_dim_v and head_dim_o must be the same for non-MLA attention, got "
<< std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o);
int max_num_blocks_per_seq = block_tables.size(-1);
Expand Down
3 changes: 2 additions & 1 deletion flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class ArtifactPath:
When compiling new cubins for backend directories, update the corresponding path.
"""

TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
TRTLLM_GEN_FMHA: str = "ac86d4cb8196b7686b32cd74598f71e28625d4c3/fmha/trtllm-gen/"
# TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Please remove the commented-out old TRTLLM_GEN_FMHA path before merging to keep the code clean.

TRTLLM_GEN_BMM: str = (
"456b1ae890d436c794b17e4435b41b849d3e5950/batched_gemm-2a674db-3a84a12"
)
Expand Down
131 changes: 94 additions & 37 deletions flashinfer/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
limitations under the License.
"""

from dataclasses import dataclass
import functools
from typing import List, Literal, Optional, Tuple, Union, overload

Expand Down Expand Up @@ -63,16 +64,70 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
)


@dataclass(frozen=True)
class MLAHeadDimensions:
    """Per-head channel counts for an MLA attention layer.

    Frozen so instances are hashable and can be compared against the
    module-level supported-dimension constants.

    Attributes:
        qk_nope_head_dim: Query/key channels without positional information
            (non-absorb mode).
        qk_rope_head_dim: Channels carrying RoPE positional information, used
            in both absorb and non-absorb modes.
        v_head_dim: Value channels; also the output head dimension in
            non-absorb mode.
        kv_lora_rank: Dimension of the compressed key-value representation
            shared across heads.
    """

    qk_nope_head_dim: int
    qk_rope_head_dim: int
    v_head_dim: int
    kv_lora_rank: int


# Canonical DeepSeek-style MLA head geometry: non-absorb qk head dim is
# 128 + 64 = 192; absorb-mode qk head dim is kv_lora_rank + rope = 512 + 64 = 576.
deepseek_mla_dimensions = MLAHeadDimensions(
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
    kv_lora_rank=512,
)

# Reduced MLA geometry: absorb-mode qk head dim is 256 + 64 = 320 with output
# dim 256 — matching the (320, 256) head-dim pair newly accepted by the
# trtllm fmha kernel launcher checks.
smaller_mla_dimensions = MLAHeadDimensions(
    qk_nope_head_dim=64,
    qk_rope_head_dim=64,
    v_head_dim=128,
    kv_lora_rank=256,
)

# Head geometries accepted by the trtllm-gen MLA shape check; referenced in
# its error message when an unsupported combination is passed.
supported_mla_head_dimensions = [deepseek_mla_dimensions, smaller_mla_dimensions]


@dataclass(frozen=True)
class MLALayerDimensions:
    """Full geometry of one MLA layer: per-head channel counts plus head count.

    Attributes:
        head_dimensions: Channel counts of a single MLA head.
        num_heads: Number of attention heads in the layer.
    """

    head_dimensions: MLAHeadDimensions
    num_heads: int


# Layer geometries paired with this module's MLA kernels: the DeepSeek head
# dims with 128 heads, and the reduced head dims with 32 heads.
# NOTE(review): no consumer of this constant is visible in this chunk —
# presumably consulted by shape/validation code; confirm before relying on it.
supported_mla_layer_dimensions = [
    MLALayerDimensions(head_dimensions=deepseek_mla_dimensions, num_heads=128),
    MLALayerDimensions(head_dimensions=smaller_mla_dimensions, num_heads=32),
]


def _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
page_table,
page_size,
):
query: torch.Tensor,
kv_cache: torch.Tensor,
kv_lora_rank: int,
qk_rope_head_dim: int,
sparse_mla_top_k: int,
page_table: torch.Tensor,
page_size: int,
) -> torch.Tensor:
if query.ndim != 4:
raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")

Expand All @@ -83,35 +138,39 @@ def _check_trtllm_gen_mla_shape(
elif kv_cache.ndim != 4:
raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")

if qk_nope_head_dim != 128:
raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
if kv_lora_rank != 512:
raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}")
if qk_rope_head_dim != 64:
raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}")

B_q, Q_len, H, D_q = query.shape
D_ckv = kv_cache.shape[3]
# if H != 128:
# raise ValueError(f"Expected 128 heads for query, got {H}")
# todo(Yingyi): should we check num_heads == 128? Is this deepseek only?
if D_q != D_ckv or D_q != 576:
is_deepseek_dimensions = (
kv_lora_rank == deepseek_mla_dimensions.kv_lora_rank
and qk_rope_head_dim == deepseek_mla_dimensions.qk_rope_head_dim
)
is_smaller_mla_dimensions = (
kv_lora_rank == smaller_mla_dimensions.kv_lora_rank
and qk_rope_head_dim == smaller_mla_dimensions.qk_rope_head_dim
)
if not (is_deepseek_dimensions or is_smaller_mla_dimensions):
raise ValueError(
f"Unsupported MLA dimensions, got kv_lora_rank={kv_lora_rank} and qk_rope_head_dim={qk_rope_head_dim}, supported dimensions are: {supported_mla_head_dimensions}"
)

num_seqs, num_tokens, _, qk_head_dim = query.shape
ckv_dim = kv_cache.shape[3]
expected_qk_head_dim = kv_lora_rank + qk_rope_head_dim
if qk_head_dim != expected_qk_head_dim or ckv_dim != expected_qk_head_dim:
raise ValueError(
f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
f"Expected head dim {expected_qk_head_dim} for query and kv_cache, got {qk_head_dim} and {ckv_dim}"
)

if sparse_mla_top_k > 0:
page_table_shape = page_table.shape
if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
if page_table_shape != (num_seqs, num_tokens, sparse_mla_top_k):
raise ValueError(
f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
f"Expected page_table.shape == (num_seqs, num_tokens, sparse_mla_top_k), got {page_table_shape}"
)
else:
B_block_table, block_num = page_table.shape
block_size = page_size
if B_q != B_block_table:
if num_seqs != B_block_table:
raise ValueError(
f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
f"Expected batch size {num_seqs} for query and block_table, got {num_seqs} and {B_block_table}"
)
if block_num % (128 / block_size) != 0:
raise ValueError(
Expand Down Expand Up @@ -523,7 +582,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
workspace_buffer: torch.Tensor,
qk_nope_head_dim: int,
qk_nope_head_dim: int, # TODO: remove in 1.0?
kv_lora_rank: int,
qk_rope_head_dim: int,
block_tables: torch.Tensor,
Expand All @@ -535,7 +594,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
skip_softmax_threshold_scale_factor: Optional[float] = None,
enable_pdl: bool = None,
enable_pdl: bool | None = None,
backend: str = "auto",
) -> torch.Tensor:
"""
Expand All @@ -544,8 +603,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe] or [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache. Both 3D and 4D formats are supported for backward compatibility.
workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
qk_nope_head_dim: qk_nope_head_dim, must be 128
kv_lora_rank: kv_lora_rank, must be 512
qk_nope_head_dim: qk_nope_head_dim, must be 128 or 64
kv_lora_rank: kv_lora_rank, must be 512 or 256
qk_rope_head_dim: qk_rope_head_dim, must be 64
sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
block_tables: page_table of kv cache, [batch_size, num_pages]
Expand Down Expand Up @@ -617,7 +676,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
query,
kv_cache,
workspace_buffer,
qk_nope_head_dim,
-1, # Unused, marked for removal.
kv_lora_rank,
qk_rope_head_dim,
block_tables,
Expand Down Expand Up @@ -650,7 +709,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
kv_cache = _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
sparse_mla_top_k,
Expand Down Expand Up @@ -712,17 +770,17 @@ def xqa_batch_decode_with_kv_cache_mla(
query: torch.Tensor,
kv_cache: torch.Tensor,
workspace_buffer: torch.Tensor,
qk_nope_head_dim: int,
qk_nope_head_dim: int, # TODO: remove in 1.0?
kv_lora_rank: int,
qk_rope_head_dim: int,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
max_seq_len: int, # TODO: remove in 1.0?
out: Optional[torch.Tensor] = None,
bmm1_scale: Union[float, torch.Tensor] = 1.0,
bmm2_scale: Union[float, torch.Tensor] = 1.0,
sinks: Optional[List[torch.Tensor]] = None,
enable_pdl: bool = None,
enable_pdl: bool | None = None,
) -> torch.Tensor:
"""
Parameters:
Expand Down Expand Up @@ -776,7 +834,6 @@ def xqa_batch_decode_with_kv_cache_mla(
kv_cache = _check_trtllm_gen_mla_shape(
query,
kv_cache,
qk_nope_head_dim,
kv_lora_rank,
qk_rope_head_dim,
0, # sparse_mla_top_k
Expand Down
14 changes: 11 additions & 3 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3495,8 +3495,10 @@ def trtllm_ragged_attention_deepseek(
If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
If return_lse is False, the output will be a single tensor.
"""
assert query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128, (
"currently only support deepseek r1 192 query and 128 value"
is_dsr1 = query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128
is_smaller_dimensions = query.shape[2] == 128 and key.shape[2] == 128 and value.shape[2] == 128
assert is_dsr1 or is_smaller_dimensions, (
"currently only support deepseek r1 192 query and 128 value or smaller dimensions 128 query and 128 value"
)

if enable_pdl is None:
Expand All @@ -3505,12 +3507,18 @@ def trtllm_ragged_attention_deepseek(
run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention
sm_count = get_device_sm_count(query.device)
if out is None:
# FP8 inputs produce bfloat16 output by default (TRT-LLM kernels
# do not support FP8 output for ragged attention)
if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
out_dtype = torch.bfloat16
else:
out_dtype = query.dtype
out = torch.empty(
query.shape[0],
query.shape[1],
value.shape[2],
device=query.device,
dtype=query.dtype,
dtype=out_dtype,
)
if return_lse and lse is None:
lse = torch.empty(
Expand Down
Loading
Loading