Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions flashinfer/mla/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,14 +749,15 @@ def trtllm_batch_decode_with_kv_cache_mla(
uses_shared_paged_kv_idx,
)

expected_out_shape = query.shape[:-1] + (kv_lora_rank,)
if out is None:
out_shape = query.shape[:-1] + (kv_lora_rank,)
out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
out = torch.empty(
expected_out_shape, dtype=torch.bfloat16, device=query.device
)
else:
batch_size, _, num_q_heads, _ = query.shape
check_shape_dtype_device(
out,
[batch_size, num_q_heads, kv_lora_rank],
expected_out_shape,
torch.bfloat16,
query.device,
"out",
Expand Down
102 changes: 102 additions & 0 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,105 @@ def test_trtllm_batch_decode_mla_sparse(
qk_nope_head_dim,
num_attn_heads,
)


@pytest.mark.parametrize("q_len_per_request", [1, 2, 4])
@pytest.mark.parametrize("batch_size", [1, 4])
def test_trtllm_batch_decode_mla_preallocated_out(
    q_len_per_request: int,
    batch_size: int,
):
    """Regression test for issue #2856.

    A caller-provided ``out`` tensor used to be rejected whenever
    ``q_len_per_request > 1``: the shape check hardcoded a 3D shape, while the
    query (and therefore the output) is 4D for multi-token generation.
    Verifies both the ``out=None`` path and the pre-allocated ``out`` path
    produce the same 4D result.
    """
    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] != 10:
        pytest.skip("trtllm-gen MLA requires SM100/SM103")

    device = "cuda:0"

    # Pull head dimensions from the first supported MLA layer configuration.
    dims = supported_mla_layer_dimensions[0]
    kv_lora_rank = dims.head_dimensions.kv_lora_rank
    qk_nope_head_dim = dims.head_dimensions.qk_nope_head_dim
    qk_rope_head_dim = dims.head_dimensions.qk_rope_head_dim
    num_heads = dims.num_heads
    head_dim_qk = kv_lora_rank + qk_rope_head_dim

    page_size = 64
    max_seq_len = 256
    num_pages_per_seq = (max_seq_len + page_size - 1) // page_size
    total_pages = num_pages_per_seq * batch_size
    head_dim_ckv_kpe = kv_lora_rank + qk_rope_head_dim

    # Paged KV cache: one contiguous run of pages per sequence.
    kv_cache = torch.randn(
        total_pages,
        1,
        page_size,
        head_dim_ckv_kpe,
        dtype=torch.bfloat16,
        device=device,
    )
    block_tables = torch.arange(
        total_pages,
        device=device,
        dtype=torch.int32,
    ).reshape(batch_size, num_pages_per_seq)
    seq_lens = torch.full((batch_size,), max_seq_len, device=device, dtype=torch.int32)

    # 4D query: (batch, q_len_per_request, num_heads, head_dim_qk).
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_heads,
        head_dim_qk,
        dtype=torch.bfloat16,
        device=device,
    )

    # NOTE(review): this lazy workspace initialization duplicates the logic in
    # the trtllm_batch_decode_mla helper in this file; a module-scoped pytest
    # fixture could encapsulate the global and its setup.
    global global_trtllm_gen_fmha_workspace_buffer
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size,
            dtype=torch.int8,
            device=device,
        )

    # Both calls share every argument except `out`.
    common_kwargs = dict(
        query=query,
        kv_cache=kv_cache,
        workspace_buffer=global_trtllm_gen_fmha_workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        bmm1_scale=1.0 / (head_dim_qk**0.5),
        bmm2_scale=1.0,
        backend="trtllm-gen",
    )

    # out=None has always worked; use it to establish the reference result.
    result_none = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        **common_kwargs
    )
    expected_shape = (batch_size, q_len_per_request, num_heads, kv_lora_rank)
    assert result_none.shape == expected_shape

    # Passing a pre-allocated 4D `out` tensor was the broken path (the bug).
    out = torch.empty(expected_shape, dtype=torch.bfloat16, device=device)
    result_pre = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        out=out,
        **common_kwargs,
    )
    assert result_pre.data_ptr() == out.data_ptr(), (
        "Expected kernel to write into provided out tensor"
    )
    assert result_pre.shape == expected_shape
    torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)
Loading