From d18e11ce8029d9222719b3c6dba03f577f6f7094 Mon Sep 17 00:00:00 2001
From: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Date: Tue, 24 Mar 2026 00:02:30 -0700
Subject: [PATCH 1/2] fix issue 2856

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
---
 flashinfer/mla/_core.py                |  9 +--
 tests/attention/test_trtllm_gen_mla.py | 99 ++++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 4 deletions(-)

diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py
index d722abaeb6..f2238a739c 100644
--- a/flashinfer/mla/_core.py
+++ b/flashinfer/mla/_core.py
@@ -749,14 +749,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
         uses_shared_paged_kv_idx,
     )
 
+    expected_out_shape = query.shape[:-1] + (kv_lora_rank,)
     if out is None:
-        out_shape = query.shape[:-1] + (kv_lora_rank,)
-        out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
+        out = torch.empty(
+            expected_out_shape, dtype=torch.bfloat16, device=query.device
+        )
     else:
-        batch_size, _, num_q_heads, _ = query.shape
         check_shape_dtype_device(
             out,
-            [batch_size, num_q_heads, kv_lora_rank],
+            expected_out_shape,
             torch.bfloat16,
             query.device,
             "out",
diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py
index 19baa8d182..55587ec453 100755
--- a/tests/attention/test_trtllm_gen_mla.py
+++ b/tests/attention/test_trtllm_gen_mla.py
@@ -879,3 +879,102 @@ def test_trtllm_batch_decode_mla_sparse(
         qk_nope_head_dim,
         num_attn_heads,
     )
+
+
+@pytest.mark.parametrize("q_len_per_request", [1, 2, 4])
+@pytest.mark.parametrize("batch_size", [1, 4])
+def test_trtllm_batch_decode_mla_preallocated_out(
+    q_len_per_request: int,
+    batch_size: int,
+):
+    """Issue #2856: pre-allocated out tensor rejected when q_len_per_req > 1.
+    The shape check hardcoded 3D but query is 4D for multi-token generation."""
+    cc = get_compute_capability(torch.device("cuda"))
+    if cc[0] != 10:
+        pytest.skip("trtllm-gen MLA requires SM100/SM103")
+
+    device = "cuda:0"
+    layer_dim = supported_mla_layer_dimensions[0]
+    kv_lora_rank = layer_dim.head_dimensions.kv_lora_rank
+    qk_nope_head_dim = layer_dim.head_dimensions.qk_nope_head_dim
+    qk_rope_head_dim = layer_dim.head_dimensions.qk_rope_head_dim
+    num_heads = layer_dim.num_heads
+    head_dim_qk = kv_lora_rank + qk_rope_head_dim
+
+    page_size = 64
+    max_seq_len = 256
+    num_pages_per_seq = (max_seq_len + page_size - 1) // page_size
+    head_dim_ckv_kpe = kv_lora_rank + qk_rope_head_dim
+
+    kv_cache = torch.randn(
+        num_pages_per_seq * batch_size,
+        1,
+        page_size,
+        head_dim_ckv_kpe,
+        dtype=torch.bfloat16,
+        device=device,
+    )
+    block_tables = torch.arange(
+        num_pages_per_seq * batch_size,
+        device=device,
+        dtype=torch.int32,
+    ).reshape(batch_size, num_pages_per_seq)
+    seq_lens = torch.full((batch_size,), max_seq_len, device=device, dtype=torch.int32)
+
+    query = torch.randn(
+        batch_size,
+        q_len_per_request,
+        num_heads,
+        head_dim_qk,
+        dtype=torch.bfloat16,
+        device=device,
+    )
+
+    global global_trtllm_gen_fmha_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size,
+            dtype=torch.int8,
+            device=device,
+        )
+    workspace = global_trtllm_gen_fmha_workspace_buffer
+
+    bmm1_scale = 1.0 / (head_dim_qk**0.5)
+
+    # out=None should work
+    result_none = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+        query=query,
+        kv_cache=kv_cache,
+        workspace_buffer=workspace,
+        qk_nope_head_dim=qk_nope_head_dim,
+        kv_lora_rank=kv_lora_rank,
+        qk_rope_head_dim=qk_rope_head_dim,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        max_seq_len=max_seq_len,
+        bmm1_scale=bmm1_scale,
+        bmm2_scale=1.0,
+        backend="trtllm-gen",
+    )
+    expected_shape = (batch_size, q_len_per_request, num_heads, kv_lora_rank)
+    assert result_none.shape == expected_shape
+
+    # out=pre-allocated should also work (this was the bug)
+    out = torch.empty(expected_shape, dtype=torch.bfloat16, device=device)
+    result_pre = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+        query=query,
+        kv_cache=kv_cache,
+        workspace_buffer=workspace,
+        qk_nope_head_dim=qk_nope_head_dim,
+        kv_lora_rank=kv_lora_rank,
+        qk_rope_head_dim=qk_rope_head_dim,
+        block_tables=block_tables,
+        seq_lens=seq_lens,
+        max_seq_len=max_seq_len,
+        out=out,
+        bmm1_scale=bmm1_scale,
+        bmm2_scale=1.0,
+        backend="trtllm-gen",
+    )
+    assert result_pre.shape == expected_shape
+    torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)

From bcaeec66117bb983b8d5762bdc3589dbb7fc0d6b Mon Sep 17 00:00:00 2001
From: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Date: Tue, 24 Mar 2026 00:59:29 -0700
Subject: [PATCH 2/2] resolve coderabbit comment

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
---
 tests/attention/test_trtllm_gen_mla.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py
index 55587ec453..c1cf3d8a50 100755
--- a/tests/attention/test_trtllm_gen_mla.py
+++ b/tests/attention/test_trtllm_gen_mla.py
@@ -976,5 +976,8 @@ def test_trtllm_batch_decode_mla_preallocated_out(
         bmm2_scale=1.0,
         backend="trtllm-gen",
     )
+    assert result_pre.data_ptr() == out.data_ptr(), (
+        "Expected kernel to write into provided out tensor"
+    )
     assert result_pre.shape == expected_shape
     torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)