diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000000..7d0cc6112f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,47 @@
+bench-mla-decode:
+	python benchmarks/flashinfer_benchmark.py \
+		--routine BatchMLAPagedAttentionWrapper \
+		--batch_size 1024 \
+		--s_kv 8192 \
+		--num_qo_heads 32 \
+		--num_kv_heads 1 \
+		--head_dim_ckv 256 \
+		--head_dim_kpe 64 \
+		--page_size 64 \
+		--backends trtllm-native \
+		--q_dtype bfloat16 \
+		--kv_dtype bfloat16 \
+		--s_qo 1 \
+		--num_iters 500
+
+bench-mla-prefill-bf16:
+	python benchmarks/flashinfer_benchmark.py \
+		--routine BatchPrefillWithRaggedKVCacheWrapper \
+		--batch_size 2 \
+		--s_kv 8192 \
+		--s_qo 8192 \
+		--num_qo_heads 128 \
+		--num_kv_heads 128 \
+		--head_dim_qk 128 \
+		--head_dim_vo 128 \
+		--page_size 64 \
+		--backends trtllm-native \
+		--q_dtype bfloat16 \
+		--kv_dtype bfloat16 \
+		--num_iters 100
+
+bench-mla-prefill-fp8:
+	python benchmarks/flashinfer_benchmark.py \
+		--routine BatchPrefillWithRaggedKVCacheWrapper \
+		--batch_size 2 \
+		--s_kv 8192 \
+		--s_qo 8192 \
+		--num_qo_heads 128 \
+		--num_kv_heads 128 \
+		--head_dim_qk 128 \
+		--head_dim_vo 128 \
+		--page_size 64 \
+		--backends trtllm-native \
+		--q_dtype fp8_e4m3 \
+		--kv_dtype fp8_e4m3 \
+		--num_iters 100
\ No newline at end of file
diff --git a/benchmarks/routines/attention.py b/benchmarks/routines/attention.py
index 9b3e8cb206..393a10a9e7 100644
--- a/benchmarks/routines/attention.py
+++ b/benchmarks/routines/attention.py
@@ -1352,7 +1352,7 @@ def run_backend_wrapper(
 def testBatchPrefillWithRaggedKVCacheWrapper(args):
     """
     Test BatchPrefillWithRaggedKVCacheWrapper API and equivalent cuDNN API.
-    Supports fa2, fa3, cutlass, and cudnn backends.
+    Supports fa2, fa3, cutlass, cudnn, trtllm-gen, and trtllm-native backends.
 
     This test:
     1. Creates ragged KV cache and query tensors for prefill
@@ -1460,15 +1460,11 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         backends.remove("trtllm-gen")
     if "trtllm-native" in backends:
         remove_trtllm_native = False
-        if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]:
-            print("[INFO] trtllm-native backend does not support FP8. Skipping.")
-            remove_trtllm_native = True
-        if not (head_dim_qk == 192 and head_dim_vo == 128):
+        if not (head_dim_qk == 192 and head_dim_vo == 128) and not (
+            head_dim_qk == 128 and head_dim_vo == 128
+        ):
             print(
-                "[INFO] trtllm-native backend requires head_dim_qk == 192 and head_dim_vo == 128"
+                "[INFO] trtllm-native backend requires (head_dim_qk == 192 and head_dim_vo == 128) or (head_dim_qk == 128 and head_dim_vo == 128). Skipping."
             )
             remove_trtllm_native = True
     if remove_trtllm_native:
@@ -1733,6 +1729,12 @@ def run_backend_wrapper(
             is_cuda_graph_compatible=True,
         )[0]
     elif backend == "trtllm-native":
+        # For FP8: bmm1_scale = q_scale * k_scale * sm_scale and bmm2_scale = v_scale;
+        # here `scale` is assumed to already fold in q_scale * sm_scale.
+        _k_scale = k_scale if k_scale is not None else 1.0
+        _v_scale = v_scale if v_scale is not None else 1.0
+        _bmm1_scale = scale * _k_scale
+        _bmm2_scale = _v_scale
         return flashinfer.prefill.trtllm_ragged_attention_deepseek(
             query=q,
             key=k,
             value=v,
             workspace_buffer=workspace_buffer,
             seq_lens=actual_seq_lens_kv_device,
             max_q_len=s_qo,
             max_kv_len=s_kv,
-            bmm1_scale=scale,
-            bmm2_scale=1.0,
+            bmm1_scale=_bmm1_scale,
+            bmm2_scale=_bmm2_scale,
             o_sf_scale=-1,
             batch_size=batch_size,
             window_left=-1,
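Reviewer note: a minimal, CPU-runnable sketch of the scale plumbing above. `make_bmm_scales` is a hypothetical helper (not part of this patch); like the benchmark, it assumes any q_scale has already been folded into `sm_scale`:

import math

def make_bmm_scales(head_dim_qk, k_scale=None, v_scale=None, sm_scale=None):
    # Default softmax scale if the caller does not provide one.
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(head_dim_qk)
    # bf16 path: dequant scales stay at 1.0, so bmm1_scale == sm_scale.
    bmm1_scale = sm_scale * (k_scale if k_scale is not None else 1.0)
    bmm2_scale = v_scale if v_scale is not None else 1.0
    return bmm1_scale, bmm2_scale

print(make_bmm_scales(128))                                # bf16: (sm_scale, 1.0)
print(make_bmm_scales(128, k_scale=0.02, v_scale=0.03))    # fp8: dequant scales folded in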
diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu
index 6aea76ee58..41a36e72bb 100644
--- a/csrc/trtllm_fmha_kernel_launcher.cu
+++ b/csrc/trtllm_fmha_kernel_launcher.cu
@@ -146,7 +146,8 @@ void trtllm_paged_attention_launcher(
   // The sparse MLA parameters.
   runner_params.mSparseMla = sparse_mla_top_k > 0;
   runner_params.mSparseMlaTopK = sparse_mla_top_k;
-  TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
-      << "Only decode MLA supports sparse MLA";
+  TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) ||
+                 (head_dim_qk == 320 && head_dim_vo == 256) || sparse_mla_top_k <= 0)
+      << "Sparse MLA requires decode MLA head dims (qk 576 / vo 512 or qk 320 / vo 256)";
 
   AlignedAllocator float_allocator(workspace_buffer, workspace_size);
@@ -251,7 +252,8 @@ void trtllm_paged_attention_decode(
   TVM_FFI_ICHECK_EQ(head_dim_k, head_dim_q)
       << "head_dim_k and head_dim_q must be the same, got " << std::to_string(head_dim_k)
      << " and " << std::to_string(head_dim_q);
-  TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || head_dim_v == head_dim_o)
+  TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) ||
+                 (head_dim_v == 320 && head_dim_o == 256) || head_dim_v == head_dim_o)
       << "head_dim_v and head_dim_o must be the same for non-MLA attention, got "
       << std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o);
   int max_num_blocks_per_seq = block_tables.size(-1);
diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py
index cce73a6827..ede803931f 100644
--- a/flashinfer/artifacts.py
+++ b/flashinfer/artifacts.py
@@ -87,7 +87,7 @@ class ArtifactPath:
     When compiling new cubins for backend directories, update the corresponding path.
     """
 
-    TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "ac86d4cb8196b7686b32cd74598f71e28625d4c3/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "456b1ae890d436c794b17e4435b41b849d3e5950/batched_gemm-2a674db-3a84a12"
     )
diff --git a/flashinfer/mla.py b/flashinfer/mla.py
index 3adda40789..03fe784df3 100644
--- a/flashinfer/mla.py
+++ b/flashinfer/mla.py
@@ -14,6 +14,7 @@
 limitations under the License.
 """
 
+from dataclasses import dataclass
 import functools
 from typing import List, Literal, Optional, Tuple, Union, overload
 
@@ -63,16 +64,70 @@ def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
     )
 
 
+@dataclass(frozen=True)
+class MLAHeadDimensions:
+    """
+    The dimensions of a single MLA head.
+
+    Args:
+        qk_nope_head_dim (int): The number of input channels without positional information in non-absorb mode.
+        qk_rope_head_dim (int): The number of channels carrying positional information for both absorb and non-absorb modes.
+        v_head_dim (int): The number of value channels, which is also the output head dimension in non-absorb mode.
+        kv_lora_rank (int): The dimension of the compressed key-value representation shared across heads.
+    """
+
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+    kv_lora_rank: int
+
+
+deepseek_mla_dimensions = MLAHeadDimensions(
+    qk_nope_head_dim=128,
+    qk_rope_head_dim=64,
+    v_head_dim=128,
+    kv_lora_rank=512,
+)
+
+smaller_mla_dimensions = MLAHeadDimensions(
+    qk_nope_head_dim=64,
+    qk_rope_head_dim=64,
+    v_head_dim=128,
+    kv_lora_rank=256,
+)
+
+supported_mla_head_dimensions = [deepseek_mla_dimensions, smaller_mla_dimensions]
+
+
+@dataclass(frozen=True)
+class MLALayerDimensions:
+    """
+    The dimensions of an MLA layer.
+
+    Args:
+        head_dimensions (MLAHeadDimensions): The dimensions of a single MLA head.
+        num_heads (int): The number of heads in the MLA layer.
+    """
+
+    head_dimensions: MLAHeadDimensions
+    num_heads: int
+
+
+supported_mla_layer_dimensions = [
+    MLALayerDimensions(head_dimensions=deepseek_mla_dimensions, num_heads=128),
+    MLALayerDimensions(head_dimensions=smaller_mla_dimensions, num_heads=32),
+]
+
+
 def _check_trtllm_gen_mla_shape(
-    query,
-    kv_cache,
-    qk_nope_head_dim,
-    kv_lora_rank,
-    qk_rope_head_dim,
-    sparse_mla_top_k,
-    page_table,
-    page_size,
-):
+    query: torch.Tensor,
+    kv_cache: torch.Tensor,
+    kv_lora_rank: int,
+    qk_rope_head_dim: int,
+    sparse_mla_top_k: int,
+    page_table: torch.Tensor,
+    page_size: int,
+) -> torch.Tensor:
     if query.ndim != 4:
         raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
 
@@ -83,35 +138,39 @@ def _check_trtllm_gen_mla_shape(
     elif kv_cache.ndim != 4:
         raise ValueError(f"Expected kv_cache.ndim == 3 or 4, got {kv_cache.ndim}")
 
-    if qk_nope_head_dim != 128:
-        raise ValueError(f"Expected qk_nope_head_dim == 128, got {qk_nope_head_dim}")
-    if kv_lora_rank != 512:
-        raise ValueError(f"Expected kv_lora_rank == 512, got {kv_lora_rank}")
-    if qk_rope_head_dim != 64:
-        raise ValueError(f"Expected qk_rope_head_dim == 64, got {qk_rope_head_dim}")
-
-    B_q, Q_len, H, D_q = query.shape
-    D_ckv = kv_cache.shape[3]
-    # if H != 128:
-    #     raise ValueError(f"Expected 128 heads for query, got {H}")
-    # todo(Yingyi): should we check num_heads == 128? Is this deepseek only?
-    if D_q != D_ckv or D_q != 576:
+    is_deepseek_dimensions = (
+        kv_lora_rank == deepseek_mla_dimensions.kv_lora_rank
+        and qk_rope_head_dim == deepseek_mla_dimensions.qk_rope_head_dim
+    )
+    is_smaller_mla_dimensions = (
+        kv_lora_rank == smaller_mla_dimensions.kv_lora_rank
+        and qk_rope_head_dim == smaller_mla_dimensions.qk_rope_head_dim
+    )
+    if not (is_deepseek_dimensions or is_smaller_mla_dimensions):
+        raise ValueError(
+            f"Unsupported MLA dimensions, got kv_lora_rank={kv_lora_rank} and qk_rope_head_dim={qk_rope_head_dim}, supported dimensions are: {supported_mla_head_dimensions}"
+        )
+
+    num_seqs, num_tokens, _, qk_head_dim = query.shape
+    ckv_dim = kv_cache.shape[3]
+    expected_qk_head_dim = kv_lora_rank + qk_rope_head_dim
+    if qk_head_dim != expected_qk_head_dim or ckv_dim != expected_qk_head_dim:
         raise ValueError(
-            f"Expected head dim 576 for query and kv_cache, got {D_q} and {D_ckv}"
+            f"Expected head dim {expected_qk_head_dim} for query and kv_cache, got {qk_head_dim} and {ckv_dim}"
         )
 
     if sparse_mla_top_k > 0:
         page_table_shape = page_table.shape
-        if page_table_shape != (B_q, Q_len, sparse_mla_top_k):
+        if page_table_shape != (num_seqs, num_tokens, sparse_mla_top_k):
             raise ValueError(
-                f"Expected page_table.shape == (B_q, Q_len, sparse_mla_top_k), got {page_table_shape}"
+                f"Expected page_table.shape == (num_seqs, num_tokens, sparse_mla_top_k), got {page_table_shape}"
             )
     else:
         B_block_table, block_num = page_table.shape
         block_size = page_size
-        if B_q != B_block_table:
+        if num_seqs != B_block_table:
             raise ValueError(
-                f"Expected batch size {B_q} for query and block_table, got {B_q} and {B_block_table}"
+                f"Expected batch size {num_seqs} for query and block_table, got {num_seqs} and {B_block_table}"
             )
         if block_num % (128 / block_size) != 0:
             raise ValueError(
@@ -523,7 +582,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
     workspace_buffer: torch.Tensor,
-    qk_nope_head_dim: int,
+    qk_nope_head_dim: int,  # TODO: remove in 1.0?
     kv_lora_rank: int,
     qk_rope_head_dim: int,
     block_tables: torch.Tensor,
@@ -535,7 +594,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
     bmm2_scale: Union[float, torch.Tensor] = 1.0,
     sinks: Optional[List[torch.Tensor]] = None,
     skip_softmax_threshold_scale_factor: Optional[float] = None,
-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,
     backend: str = "auto",
 ) -> torch.Tensor:
     """
@@ -544,8 +603,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
         query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
         kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe] or [num_pages, 1, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache. Both 3D and 4D formats are supported for backward compatibility.
         workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
-        qk_nope_head_dim: qk_nope_head_dim, must be 128
-        kv_lora_rank: kv_lora_rank, must be 512
+        qk_nope_head_dim: qk_nope_head_dim, 128 or 64; no longer validated or used, slated for removal
+        kv_lora_rank: kv_lora_rank, must be 512 or 256
         qk_rope_head_dim: qk_rope_head_dim, must be 64
         sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
         block_tables: page_table of kv cache, [batch_size, num_pages]
@@ -617,7 +676,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
         query,
         kv_cache,
         workspace_buffer,
-        qk_nope_head_dim,
+        -1,  # Unused, marked for removal.
        kv_lora_rank,
         qk_rope_head_dim,
         block_tables,
@@ -650,7 +709,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
     kv_cache = _check_trtllm_gen_mla_shape(
         query,
         kv_cache,
-        qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
         sparse_mla_top_k,
@@ -712,17 +770,17 @@ def xqa_batch_decode_with_kv_cache_mla(
     query: torch.Tensor,
     kv_cache: torch.Tensor,
     workspace_buffer: torch.Tensor,
-    qk_nope_head_dim: int,
+    qk_nope_head_dim: int,  # TODO: remove in 1.0?
     kv_lora_rank: int,
     qk_rope_head_dim: int,
     block_tables: torch.Tensor,
     seq_lens: torch.Tensor,
-    max_seq_len: int,
+    max_seq_len: int,  # TODO: remove in 1.0?
     out: Optional[torch.Tensor] = None,
     bmm1_scale: Union[float, torch.Tensor] = 1.0,
     bmm2_scale: Union[float, torch.Tensor] = 1.0,
     sinks: Optional[List[torch.Tensor]] = None,
-    enable_pdl: bool = None,
+    enable_pdl: Optional[bool] = None,
 ) -> torch.Tensor:
     """
     Parameters:
@@ -776,7 +834,6 @@ def xqa_batch_decode_with_kv_cache_mla(
     kv_cache = _check_trtllm_gen_mla_shape(
         query,
         kv_cache,
-        qk_nope_head_dim,
         kv_lora_rank,
         qk_rope_head_dim,
         0,  # sparse_mla_top_k
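Reviewer note: the two supported head-dimension sets above imply all of the head dims checked elsewhere in this patch. A small sketch, assuming only the names added to flashinfer/mla.py:

from flashinfer.mla import deepseek_mla_dimensions, smaller_mla_dimensions

for dims in (deepseek_mla_dimensions, smaller_mla_dimensions):
    # Decode (absorb mode): query/kv head dim is kv_lora_rank + qk_rope_head_dim.
    decode_qk_dim = dims.kv_lora_rank + dims.qk_rope_head_dim      # 576 or 320
    # Prefill (non-absorb): qk head dim is qk_nope + qk_rope; output is v_head_dim.
    prefill_qk_dim = dims.qk_nope_head_dim + dims.qk_rope_head_dim  # 192 or 128
    print(decode_qk_dim, prefill_qk_dim, dims.v_head_dim)

These are exactly the (576, 512) / (320, 256) pairs accepted by the launcher checks and the (192, 128) / (128, 128) pairs accepted by the ragged prefill path.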
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index bf0fcfc822..d39bd7a44b 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -3495,8 +3495,12 @@ def trtllm_ragged_attention_deepseek(
         If return_lse is True, the output will be a tuple of two tensors, the first is the output tensor, the second is the lse tensor.
         If return_lse is False, the output will be a single tensor.
     """
-    assert query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128, (
-        "currently only support deepseek r1 192 query and 128 value"
+    is_dsr1 = query.shape[2] == 192 and key.shape[2] == 192 and value.shape[2] == 128
+    is_smaller_dimensions = (
+        query.shape[2] == 128 and key.shape[2] == 128 and value.shape[2] == 128
+    )
+    assert is_dsr1 or is_smaller_dimensions, (
+        "currently only supports DeepSeek-R1 head dims (q/k 192, v 128) or the smaller head dims (q/k/v 128)"
     )
 
     if enable_pdl is None:
@@ -3505,12 +3509,18 @@ def trtllm_ragged_attention_deepseek(
     run_func = get_trtllm_gen_fmha_module().trtllm_ragged_attention
     sm_count = get_device_sm_count(query.device)
     if out is None:
+        # FP8 inputs produce bfloat16 output by default (TRT-LLM kernels
+        # do not support FP8 output for ragged attention)
+        if query.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+            out_dtype = torch.bfloat16
+        else:
+            out_dtype = query.dtype
         out = torch.empty(
             query.shape[0],
             query.shape[1],
             value.shape[2],
             device=query.device,
-            dtype=query.dtype,
+            dtype=out_dtype,
         )
     if return_lse and lse is None:
         lse = torch.empty(
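Reviewer note: the default-output-dtype rule added above, restated as a tiny self-contained sketch. `default_out_dtype` is a hypothetical mirror of the branch in the diff, not a library function:

import torch

def default_out_dtype(q_dtype: torch.dtype) -> torch.dtype:
    # FP8 inputs fall back to bfloat16 output; everything else keeps the query dtype.
    if q_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return torch.bfloat16
    return q_dtype

assert default_out_dtype(torch.float8_e4m3fn) is torch.bfloat16
assert default_out_dtype(torch.bfloat16) is torch.bfloat16
assert default_out_dtype(torch.float16) is torch.float16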
diff --git a/tests/attention/test_trtllm_gen_attention.py b/tests/attention/test_trtllm_gen_attention.py
index 50f5bc9d1e..bc3e594193 100755
--- a/tests/attention/test_trtllm_gen_attention.py
+++ b/tests/attention/test_trtllm_gen_attention.py
@@ -2,6 +2,11 @@
 import pytest
 import torch
 
+from flashinfer.mla import (
+    MLAHeadDimensions,
+    deepseek_mla_dimensions,
+    smaller_mla_dimensions,
+)
 from tests.test_helpers.utils_fp4 import (
     cast_from_fp4,
     recover_swizzled_scales,
@@ -36,7 +41,7 @@ def flip_coin(*args, **kwargs):
     return (hash_value % 2) == 0
 
 
-def to_float8(x, dtype=torch.float8_e4m3fn):
+def to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -45,7 +50,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
     return x_scl_sat.to(dtype), scale.float().reciprocal()
 
 
-def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
+def generate_seq_lens_prefill(batch_size: int, max_q_len: int, max_in_kv_len: int):
     q_lens = torch.randint(1, max_q_len + 1, (batch_size,), dtype=torch.int32)
     q_lens[-1] = max_q_len
     in_kv_lens = torch.randint(0, max_in_kv_len + 1, (batch_size,), dtype=torch.int)
@@ -54,7 +59,9 @@ def generate_seq_lens_prefill(batch_size, max_q_len, max_in_kv_len):
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len):
+def generate_seq_lens_decode(
+    batch_size: int, q_len_per_req: int | None, max_in_kv_len: int, max_q_len: int
+):
     if q_len_per_req is not None:
         assert max_q_len is None, "Can not specify both q_len_per_req and max_q_len."
         q_lens = torch.full((batch_size,), q_len_per_req, dtype=torch.int32)
@@ -67,7 +74,7 @@ def generate_seq_lens_decode(batch_size, q_len_per_req, max_in_kv_len, max_q_len
     return q_lens, in_kv_lens, seq_lens
 
 
-def generate_cumsum_lens(lens):
+def generate_cumsum_lens(lens: torch.Tensor):
     return torch.cat(
         [
             torch.tensor([0], dtype=torch.int32, device=GPU_DEVICE),
@@ -76,9 +83,11 @@ def generate_cumsum_lens(lens):
     )
 
 
-def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype):
+def create_query_tensor(
+    q_lens: torch.Tensor, num_qo_heads: int, head_dim: int, q_dtype: str
+):
     q = torch.randn(
-        torch.sum(q_lens).item(),
+        int(torch.sum(q_lens).item()),
         num_qo_heads,
         head_dim,
         dtype=torch.bfloat16 if q_dtype == "fp8" else DTYPE_MAP[q_dtype],
@@ -96,14 +105,14 @@ def create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype):
 
 
 def create_kv_cache(
-    batch_size,
-    seq_lens,
-    page_size,
-    num_kv_heads,
-    head_dim,
-    kv_dtype,
-    ref_kv_dtype,
-    kv_layout="HND",
+    batch_size: int,
+    seq_lens: torch.Tensor,
+    page_size: int,
+    num_kv_heads: int,
+    head_dim: int,
+    kv_dtype: str,
+    ref_kv_dtype: str,
+    kv_layout: str = "HND",
 ):
     # Create separate K and V caches
     max_seq_len = torch.max(seq_lens).item()
@@ -174,12 +183,12 @@ def create_kv_cache(
     return kv_cache, k_scale, v_scale, ref_kv_cache
 
 
-def create_page_table(batch_size, seq_lens, page_size):
+def create_page_table(batch_size: int, seq_lens: torch.Tensor, page_size: int):
     page_per_seq = (seq_lens + page_size - 1) // page_size
-    max_num_pages_per_seq = torch.max(page_per_seq).item()
+    max_num_pages_per_seq = int(torch.max(page_per_seq).item())
 
     # Generate random but unique page IDs for all sequences
-    total_pages_needed = torch.sum(page_per_seq).item()
+    total_pages_needed = int(torch.sum(page_per_seq).item())
     all_page_ids = torch.randperm(
         total_pages_needed, dtype=torch.int32, device=GPU_DEVICE
     )
@@ -251,7 +260,7 @@ def flatten_paged_kv(
     return k_flat, v_flat, kv_indptr_tokens
 
 
-def create_workspace_buffers(device):
+def create_workspace_buffers(device: torch.device):
     # Lazily initialize and reuse global workspace buffers
     global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
@@ -265,7 +274,9 @@ def create_workspace_buffers(device):
     return global_trtllm_gen_fmha_workspace_buffer, global_workspace_buffer
 
 
-def create_output(q, o_dtype, create_out_tensor, create_out_dtype):
+def create_output(
+    q: torch.Tensor, o_dtype: str, create_out_tensor: bool, create_out_dtype: bool
+):
     if o_dtype == "fp8":
         o_scale = torch.rand(1).item() * 0.5 + 0.5  # Scale range: 0.5 ~ 1.0
     else:
@@ -313,12 +324,12 @@ def get_last_page_len(seq_lens, page_size):
 
 def unpack_compare_nvfp4(
     output: FP4Tensor,
-    output_ref,
-    o_sf_scale,
-    o_sf_vec_size,
-    sf_rtol=2e-1,
-    sf_atol=2e-1,
-    rmse_tol=0.3,
+    output_ref: torch.Tensor,
+    o_sf_scale: float,
+    o_sf_vec_size: int,
+    sf_rtol: float = 2e-1,
+    sf_atol: float = 2e-1,
+    rmse_tol: float = 0.3,
 ):
     output_ref, out_scale_factor_ref = ref_fp4_quant(
         output_ref, o_sf_scale, o_sf_vec_size
@@ -405,23 +416,23 @@ def generate_causal_mask(
 
 def _test_trtllm_batch_prefill(
-    kv_layout,
-    batch_size,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_q_len,
-    max_kv_len,
-    device_scale,
-    head_dim,
-    non_contiguous_query=False,
-    skips_softmax=False,
+    kv_layout: str,
+    batch_size: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_q_len: int,
+    max_kv_len: int,
+    device_scale: bool,
+    head_dim: int,
+    non_contiguous_query: bool = False,
+    skips_softmax: bool = False,
 ):
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
@@ -437,9 +448,7 @@ def _test_trtllm_batch_prefill(
 
     # Generate random sequence lengths
     num_qo_heads = num_kv_heads * head_grp_size
-    q_lens, in_kv_lens, seq_lens = generate_seq_lens_prefill(
-        batch_size, max_q_len, max_kv_len
-    )
+    q_lens, _, seq_lens = generate_seq_lens_prefill(batch_size, max_q_len, max_kv_len)
 
     # Create query tensor and related data
     q, q_scale, ref_q = create_query_tensor(q_lens, num_qo_heads, head_dim, q_dtype)
@@ -667,22 +676,22 @@ def _test_trtllm_batch_prefill(
 @pytest.mark.parametrize("non_contiguous_query", [False, True])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_prefill(
-    kv_layout,
-    batch_size,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_q_len,
-    max_kv_len,
-    head_dim,
-    non_contiguous_query,
-    skips_softmax,
+    kv_layout: str,
+    batch_size: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_q_len: int,
+    max_kv_len: int,
+    head_dim: int,
+    non_contiguous_query: bool,
+    skips_softmax: bool,
 ):
     _test_trtllm_batch_prefill(
         kv_layout,
@@ -726,21 +735,21 @@ def test_trtllm_batch_prefill(
 @pytest.mark.parametrize("head_dim", [128, 256])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_prefill_bs1(
-    kv_layout,
-    batch_size,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_q_len,
-    max_kv_len,
-    head_dim,
-    skips_softmax,
+    kv_layout: str,
+    batch_size: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_q_len: int,
+    max_kv_len: int,
+    head_dim: int,
+    skips_softmax: bool,
 ):
     _test_trtllm_batch_prefill(
         kv_layout,
@@ -763,26 +772,26 @@ def test_trtllm_batch_prefill_bs1(
 
 def _test_trtllm_batch_decode(
-    backend,
-    kv_layout,
-    batch_size,
-    q_len_per_req,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    device_scale=False,
-    max_q_len=None,
-    non_contiguous_query=False,
-    skips_softmax=False,
-):
+    backend: str,
+    kv_layout: str,
+    batch_size: int,
+    q_len_per_req: int | None,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    device_scale: bool = False,
+    max_q_len: int | None = None,
+    non_contiguous_query: bool = False,
+    skips_softmax: bool = False,
+) -> None:
     """
     Common function for testing trtllm-gen decode.
@@ -1133,23 +1142,23 @@ def _test_trtllm_batch_decode(
 @pytest.mark.parametrize("non_contiguous_query", [False, True])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_decode(
-    backend,
-    kv_layout,
-    batch_size,
-    q_len_per_req,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    non_contiguous_query,
-    skips_softmax,
+    backend: str,
+    kv_layout: str,
+    batch_size: int,
+    q_len_per_req: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    non_contiguous_query: bool,
+    skips_softmax: bool,
 ):
     # xqa backend does not support non-contiguous query yet
     if backend == "xqa" and non_contiguous_query:
@@ -1200,23 +1209,23 @@ def test_trtllm_batch_decode(
 @pytest.mark.parametrize("device_scale", [True, False])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_decode_bs1(
-    kv_layout,
-    batch_size,
-    q_len_per_req,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    device_scale,
-    skips_softmax,
-):
+    kv_layout: str,
+    batch_size: int,
+    q_len_per_req: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    device_scale: bool,
+    skips_softmax: bool,
+) -> None:
     # Small number of test cases for batch size 1
     _test_trtllm_batch_decode(
         "trtllm-gen",
@@ -1271,22 +1280,22 @@ def test_trtllm_batch_decode_bs1(
 @pytest.mark.parametrize("device_scale", [True, False])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_decode_head_dim_256(
-    kv_layout,
-    batch_size,
-    q_len_per_req,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    device_scale,
-    skips_softmax,
+    kv_layout: str,
+    batch_size: int,
+    q_len_per_req: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    device_scale: bool,
+    skips_softmax: bool,
 ):
     # Small number of test cases for head_dim = 256
     _test_trtllm_batch_decode(
@@ -1338,23 +1347,23 @@ def test_trtllm_batch_decode_head_dim_256(
 @pytest.mark.parametrize("device_scale", [True, False])
 @pytest.mark.parametrize("skips_softmax", [False])
 def test_trtllm_batch_decode_long_sequence_length(
-    kv_layout,
-    batch_size,
-    q_len_per_req,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    device_scale,
-    skips_softmax,
-):
+    kv_layout: str,
+    batch_size: int,
+    q_len_per_req: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    device_scale: bool,
+    skips_softmax: bool,
+) -> None:
     # Small number of test cases for long sequence length
     _test_trtllm_batch_decode(
         "trtllm-gen",
@@ -1377,6 +1386,9 @@ def test_trtllm_batch_decode_long_sequence_length(
     )
 
 
+@pytest.mark.parametrize(
+    "mla_dimensions", [deepseek_mla_dimensions, smaller_mla_dimensions]
+)
 @pytest.mark.parametrize("batch_size", [4, 128, 256])
 @pytest.mark.parametrize("s_qo", [32, 64, 87])
 @pytest.mark.parametrize("s_kv", [32, 64, 87])
@@ -1384,9 +1396,16 @@ def test_trtllm_batch_decode_long_sequence_length(
 @pytest.mark.parametrize("head_grp_size", [1, 5, 8])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("skips_softmax", [False, True])
-def test_trtllm_gen_prefill_deepseek(
-    batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
-):
+def test_trtllm_gen_prefill(
+    mla_dimensions: MLAHeadDimensions,
+    batch_size: int,
+    s_qo: int,
+    s_kv: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    causal: bool,
+    skips_softmax: bool,
+) -> None:
     compute_capability = get_compute_capability(torch.device(device="cuda"))
     if compute_capability[0] != 10:
         pytest.skip("These tests are only guaranteed to work on SM100 and SM103 GPUs.")
@@ -1394,8 +1413,8 @@ def test_trtllm_gen_prefill(
     if s_qo > s_kv:
         pytest.skip("s_qo > s_kv, skipping test as causal")
 
     num_qo_heads = num_kv_heads * head_grp_size
-    head_dim_qk = 192
-    head_dim_vo = 128
+    head_dim_qk = mla_dimensions.qk_nope_head_dim + mla_dimensions.qk_rope_head_dim
+    head_dim_vo = mla_dimensions.v_head_dim
 
     seed = 0
     torch.manual_seed(seed)
@@ -1409,8 +1428,8 @@ def test_trtllm_gen_prefill(
         s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
     )
 
-    cumsum_s_qo = torch.sum(actual_seq_lens_q)
-    cumsum_s_kv = torch.sum(actual_seq_lens_kv)
+    cumsum_s_qo = int(torch.sum(actual_seq_lens_q).item())
+    cumsum_s_kv = int(torch.sum(actual_seq_lens_kv).item())
 
     q = torch.randn(
         cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16
@@ -1516,6 +1535,9 @@ def test_trtllm_gen_prefill(
     assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
+@pytest.mark.parametrize(
+    "mla_dimensions", [deepseek_mla_dimensions, smaller_mla_dimensions]
+)
 @pytest.mark.parametrize("batch_size", [1])
 @pytest.mark.parametrize("s_qo", [1024])
 @pytest.mark.parametrize("s_kv", [1024])
@@ -1523,15 +1545,31 @@ def test_trtllm_gen_prefill(
 @pytest.mark.parametrize("head_grp_size", [1])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("skips_softmax", [False, True])
-def test_trtllm_gen_prefill_deepseek_bs1(
-    batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
+def test_trtllm_gen_prefill_bs1(
+    mla_dimensions: MLAHeadDimensions,
+    batch_size: int,
+    s_qo: int,
+    s_kv: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    causal: bool,
+    skips_softmax: bool,
 ):
-    test_trtllm_gen_prefill_deepseek(
-        batch_size, s_qo, s_kv, num_kv_heads, head_grp_size, causal, skips_softmax
+    test_trtllm_gen_prefill(
+        mla_dimensions,
+        batch_size,
+        s_qo,
+        s_kv,
+        num_kv_heads,
+        head_grp_size,
+        causal,
+        skips_softmax,
     )
 
 
-def make_query_non_contiguous(q, num_qo_heads, head_dim):
+def make_query_non_contiguous(
+    q: torch.Tensor, num_qo_heads: int, head_dim: int
+) -> torch.Tensor:
     """
     Create a non-contiguous version of the query tensor.
     Create a (N, H, 2*D) tensor and slice the first D dimensions: x[..., :D]
@@ -1593,23 +1631,23 @@ def make_query_non_contiguous(
 @pytest.mark.parametrize("head_dim", [128])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_decode_spec(
-    backend,
-    kv_layout,
-    batch_size,
-    max_q_len,
-    page_size,
-    num_kv_heads,
-    head_grp_size,
-    window_left,
-    q_dtype,
-    o_dtype,
-    kv_dtype,
-    enable_pdl,
-    enable_sink,
-    max_in_kv_len,
-    head_dim,
-    skips_softmax,
-):
+    backend: str,
+    kv_layout: str,
+    batch_size: int,
+    max_q_len: int,
+    page_size: int,
+    num_kv_heads: int,
+    head_grp_size: int,
+    window_left: int,
+    q_dtype: str,
+    o_dtype: str,
+    kv_dtype: str,
+    enable_pdl: bool,
+    enable_sink: bool,
+    max_in_kv_len: int,
+    head_dim: int,
+    skips_softmax: bool,
+) -> None:
     _test_trtllm_batch_decode(
         backend,
         kv_layout,
diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py
index bd0ee03684..fc8b663a01 100644
--- a/tests/attention/test_trtllm_gen_mla.py
+++ b/tests/attention/test_trtllm_gen_mla.py
@@ -3,6 +3,10 @@
 import random
 
 import flashinfer
+from flashinfer.mla import (
+    MLALayerDimensions,
+    supported_mla_layer_dimensions,
+)
 from flashinfer.utils import get_compute_capability
 
 global_workspace_buffer = None  # can.be empty initialized
@@ -210,6 +214,7 @@ def scaled_dot_product_attention(
 
 
 def trtllm_batch_decode_mla(
+    layer_dimensions: MLALayerDimensions,
     batch_size: int,
     scale: float,
     dtype: torch.dtype,
@@ -241,18 +246,13 @@ def trtllm_batch_decode_mla(
     torch.manual_seed(42)
     device = "cuda:0"
 
-    # Deepseek attention config (decode-MLA)
-    num_q_heads = 128
-    qk_nope_head_dim = 128
-    qk_rope_head_dim = 64
-    kv_lora_rank = 512
-
     # Initialize tensors
     query = torch.randn(
         batch_size,
         q_len_per_request,
-        num_q_heads,
-        kv_lora_rank + qk_rope_head_dim,
+        layer_dimensions.num_heads,
+        layer_dimensions.head_dimensions.kv_lora_rank
+        + layer_dimensions.head_dimensions.qk_rope_head_dim,
         device=device,
     ).to(dtype)
 
@@ -292,7 +292,13 @@ def trtllm_batch_decode_mla(
     # Create interleaved KV cache
     # Allocate more than needed blocks, block_id is just enough, to mimick real-world cases
     kv_cache = torch.randn(
-        size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device
+        size=(
+            num_blocks,
+            page_size,
+            layer_dimensions.head_dimensions.kv_lora_rank
+            + layer_dimensions.head_dimensions.qk_rope_head_dim,
+        ),
+        device=device,
     ).to(dtype)
 
     # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim)
@@ -318,9 +324,9 @@ def trtllm_batch_decode_mla(
         query=query,
         kv_cache=kv_cache.unsqueeze(1),
         workspace_buffer=workspace_buffer,
-        qk_nope_head_dim=qk_nope_head_dim,
-        kv_lora_rank=kv_lora_rank,
-        qk_rope_head_dim=qk_rope_head_dim,
+        qk_nope_head_dim=layer_dimensions.head_dimensions.qk_nope_head_dim,
+        kv_lora_rank=layer_dimensions.head_dimensions.kv_lora_rank,
+        qk_rope_head_dim=layer_dimensions.head_dimensions.qk_rope_head_dim,
         block_tables=block_tables,
         seq_lens=seq_lens_tensor,
         max_seq_len=max_seq_len,
@@ -361,25 +367,29 @@ def trtllm_batch_decode_mla(
         kv_indptr,
         kv_indices,
         seq_lens_tensor,
-        num_q_heads,
-        kv_lora_rank,
-        qk_rope_head_dim,
+        layer_dimensions.num_heads,
+        layer_dimensions.head_dimensions.kv_lora_rank,
+        layer_dimensions.head_dimensions.qk_rope_head_dim,
         page_size,
         True,
         sm_scale,
         query.dtype,
         kv_cache.dtype,
     )
-    q_nope = query[..., :kv_lora_rank].view(
-        batch_size * q_len_per_request, num_q_heads, kv_lora_rank
+    q_nope = query[..., : layer_dimensions.head_dimensions.kv_lora_rank].view(
+        batch_size * q_len_per_request,
+        layer_dimensions.num_heads,
+        layer_dimensions.head_dimensions.kv_lora_rank,
     )
-    q_pe = query[..., kv_lora_rank:].view(
-        batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim
+    q_pe = query[..., layer_dimensions.head_dimensions.kv_lora_rank :].view(
+        batch_size * q_len_per_request,
+        layer_dimensions.num_heads,
+        layer_dimensions.head_dimensions.qk_rope_head_dim,
     )
 
     # todo: fix kv_cache
-    ckv = kv_cache[..., :kv_lora_rank]
-    kpe = kv_cache[..., kv_lora_rank:]
+    ckv = kv_cache[..., : layer_dimensions.head_dimensions.kv_lora_rank]
+    kpe = kv_cache[..., layer_dimensions.head_dimensions.kv_lora_rank :]
 
     o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)
 
@@ -392,7 +402,9 @@ def trtllm_batch_decode_mla(
     try:
         torch.testing.assert_close(
             output,
-            o_ref.view(batch_size, q_len_per_request, num_q_heads, -1),
+            o_ref.view(
+                batch_size, q_len_per_request, layer_dimensions.num_heads, -1
+            ),
             rtol=1e-1,
             atol=1e-1,
         )  # todo: do reference with normal attention?
@@ -404,7 +416,9 @@ def trtllm_batch_decode_mla(
     try:
         torch.testing.assert_close(
             output,
-            o_ref.view(batch_size, q_len_per_request, num_q_heads, -1),
+            o_ref.view(
+                batch_size, q_len_per_request, layer_dimensions.num_heads, -1
+            ),
             rtol=1e-2,
             atol=1e-2,
         )
@@ -417,10 +431,16 @@ def trtllm_batch_decode_mla(
     rtol = 0.05
 
     diff_abs = torch.abs(
-        o_ref.view(batch_size, q_len_per_request, num_q_heads, -1) - output
+        o_ref.view(batch_size, q_len_per_request, layer_dimensions.num_heads, -1)
+        - output
     )
     diff_rel = diff_abs / (
-        torch.abs(o_ref.view(batch_size, q_len_per_request, num_q_heads, -1)) + 1e-8
+        torch.abs(
+            o_ref.view(
+                batch_size, q_len_per_request, layer_dimensions.num_heads, -1
+            )
+        )
+        + 1e-8
    )
 
     within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol)
@@ -434,6 +454,10 @@ def trtllm_batch_decode_mla(
     )
 
 
+@pytest.mark.parametrize(
+    "layer_dimensions",
+    supported_mla_layer_dimensions,
+)
 @pytest.mark.parametrize(
     "batch_size",
     [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
@@ -449,6 +473,7 @@ def trtllm_batch_decode_mla(
 @pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_trtllm_batch_decode_mla(
+    layer_dimensions: MLALayerDimensions,
     batch_size: int,
     scale: float,
     dtype: torch.dtype,
@@ -460,6 +485,7 @@ def test_trtllm_batch_decode_mla(
     skips_softmax: bool,
 ):
     trtllm_batch_decode_mla(
+        layer_dimensions,
         batch_size,
         scale,
         dtype,
@@ -473,6 +499,10 @@ def test_trtllm_batch_decode_mla(
     )
 
 
+@pytest.mark.parametrize(
+    "layer_dimensions",
+    supported_mla_layer_dimensions,
+)
 @pytest.mark.parametrize(
     "batch_size",
     [2, 4, 8],
@@ -487,6 +517,7 @@ def test_trtllm_batch_decode_mla(
 @pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
 @pytest.mark.parametrize("skips_softmax", [False, True])
 def test_dsr1_trtllm_mla(
+    layer_dimensions: MLALayerDimensions,
     batch_size: int,
     scale: float,
     dtype: torch.dtype,
@@ -499,6 +530,7 @@ def test_dsr1_trtllm_mla(
     skips_softmax: bool,
 ):
     trtllm_batch_decode_mla(
+        layer_dimensions,
         batch_size,
         scale,
         dtype,
@@ -706,7 +738,7 @@ def test_trtllm_batch_decode_mla_sparse(
 
     sm_scale = scale / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)
 
-    out_ref, lse_ref = sparse_mla_reference_torch(
+    out_ref, _ = sparse_mla_reference_torch(
         cache_seqlens=seq_lens_tensor,
         block_table=block_tables,
         q=query_ref,
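Reviewer note: an end-to-end decode sketch with the new smaller MLA configuration (kv_lora_rank=256, qk_rope_head_dim=64, so the query head dim is 320), assuming an SM100-class GPU. The workspace size/dtype and page-table layout are illustrative assumptions mirroring the tests, not prescribed values:

import torch
from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla

device = "cuda:0"
batch_size, page_size, num_pages, num_heads = 2, 64, 8, 32
kv_lora_rank, qk_rope_head_dim = 256, 64

# Query: [batch, q_len_per_request, num_heads, kv_lora_rank + qk_rope_head_dim]
query = torch.randn(
    batch_size, 1, num_heads, kv_lora_rank + qk_rope_head_dim,
    device=device, dtype=torch.bfloat16,
)
# 4D KV cache: [num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim]
kv_cache = torch.randn(
    num_pages, 1, page_size, kv_lora_rank + qk_rope_head_dim,
    device=device, dtype=torch.bfloat16,
)
# Workspace must be zero-initialized on first use; size is an assumption here.
workspace = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=device)
# 4 pages per sequence, a multiple of 128 // page_size as the shape check requires.
block_tables = torch.arange(num_pages, dtype=torch.int32, device=device).view(batch_size, -1)
seq_lens = torch.full((batch_size,), page_size, dtype=torch.int32, device=device)

out = trtllm_batch_decode_with_kv_cache_mla(
    query=query,
    kv_cache=kv_cache,
    workspace_buffer=workspace,
    qk_nope_head_dim=64,  # deprecated; no longer validated by the shape check
    kv_lora_rank=kv_lora_rank,
    qk_rope_head_dim=qk_rope_head_dim,
    block_tables=block_tables,
    seq_lens=seq_lens,
    max_seq_len=page_size,
    bmm1_scale=1.0 / (64 + 64) ** 0.5,  # 1/sqrt(qk_nope + qk_rope), as in the tests
)
print(out.shape)  # [2, 1, 32, 128] -- v_head_dim output channels per head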