
[fix] bugfix 2856: Fix pre-allocated out shape check in trtllm_batch_decode_with_kv_cache_mla for q_len_per_req > 1#2876

Open
qsang-nv wants to merge 2 commits into flashinfer-ai:main from qsang-nv:issue_2856

Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Mar 24, 2026

📌 Description

This PR fixes #2856.

trtllm_batch_decode_with_kv_cache_mla rejects a correctly shaped pre-allocated out tensor when q_len_per_req > 1 (speculative decoding / MTP). The out=None path correctly infers a 4D output shape [B, q_len, H, kv_lora_rank] via query.shape[:-1] + (kv_lora_rank,), but the path with a user-provided out hardcodes a 3D expected shape [B, H, kv_lora_rank], missing the q_len dimension.

The fix unifies both paths to use query.shape[:-1] + (kv_lora_rank,) as the expected output shape.
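As a sketch, the unified logic looks roughly like this (illustrative only; the function name and the [B, q_len, H, head_dim] query layout are assumptions, not the actual flashinfer/mla.py source):

```python
from typing import Optional

import torch


def check_out_shape(
    query: torch.Tensor, kv_lora_rank: int, out: Optional[torch.Tensor]
) -> torch.Tensor:
    """Allocate or validate the MLA decode output buffer (illustrative sketch)."""
    # Works for any leading layout, e.g. [B, H, head_dim] or [B, q_len, H, head_dim],
    # because only the last (head) dimension is replaced.
    expected_out_shape = query.shape[:-1] + (kv_lora_rank,)
    if out is None:
        # Allocation path: infer the output shape from the query.
        return torch.empty(expected_out_shape, dtype=query.dtype, device=query.device)
    # Validation path: compare against the same shape, so a pre-allocated
    # [B, q_len, H, kv_lora_rank] buffer is accepted even when q_len > 1.
    if tuple(out.shape) != tuple(expected_out_shape):
        raise ValueError(
            f"out has shape {tuple(out.shape)}, expected {tuple(expected_out_shape)}"
        )
    return out
```

Because the expected shape is derived from the query in both branches, a 3D query keeps its old behavior while a 4D query (q_len > 1) is now accepted.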

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Refactor

    • Improved output-shape handling in MLA batch decode to ensure consistent allocation and validation.
  • Tests

    • Added tests for MLA batch decode with preallocated output buffers, verifying shape correctness, in-place result writing, and numerical agreement across batch sizes and sequence lengths.

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@coderabbitai
Contributor

coderabbitai bot commented Mar 24, 2026

📝 Walkthrough

Walkthrough

Refactor in trtllm_batch_decode_with_kv_cache_mla unifies expected output-shape computation to handle 4D query tensors consistently; adds a parametrized test validating pre-allocated 4D out behavior for the trtllm-gen MLA decode path.

Changes

  • flashinfer/mla.py (MLA Output Shape Validation): Unified output-shape computation: defines expected_out_shape = query.shape[:-1] + (kv_lora_rank,) and uses it for both allocation and validation of a user-provided out, fixing the rejection of pre-allocated 4D outputs when q_len_per_request > 1.
  • tests/attention/test_trtllm_gen_mla.py (Preallocated Output Test): Added test_trtllm_batch_decode_mla_preallocated_out (parametrized over q_len_per_request ∈ {1, 2, 4} and batch_size ∈ {1, 4}) that verifies auto-allocation and writing into a preallocated out for the trtllm-gen backend on SM100/SM103 GPUs.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • yzh119
  • cyx-6
  • bkryu

Poem

🐇 I hopped through shapes with care and glee,
Four dims now fit — no more "why me?"
A test was sown, the bytes align,
Preallocated buffers sing, "All fine!" 🎉

🚥 Pre-merge checks (5 passed)
  • Title check ✅ Passed: The title clearly and specifically identifies the bug being fixed (pre-allocated out shape check in trtllm_batch_decode_with_kv_cache_mla) and the condition under which it occurs (q_len_per_req > 1).
  • Description check ✅ Passed: The description includes the required components: a clear explanation of the bug, a link to issue #2856, identification of the root cause (3D vs 4D shape mismatch), and the proposed fix (unified shape computation).
  • Linked Issues check ✅ Passed: The PR directly addresses all coding requirements from issue #2856: unifying shape validation using query.shape[:-1] + (kv_lora_rank,) for both out=None and out-provided paths, preserving dtype/device checks, and enabling pre-allocated 4D buffers for q_len > 1.
  • Out of Scope Changes check ✅ Passed: All changes are directly scoped to the issue: mla.py refactors the shape computation for output validation, and test_trtllm_gen_mla.py adds a new test validating the pre-allocated out behavior with parametrized q_len and batch sizes.
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, above the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves a critical bug in the trtllm_batch_decode_with_kv_cache_mla function that previously prevented the correct use of pre-allocated output tensors during multi-token generation. By standardizing the output shape validation, the change ensures consistent and accurate handling of output buffers, improving the robustness of the system for speculative decoding and multi-token prefill scenarios.

Highlights

  • Bug Fix: Addressed a bug in the trtllm_batch_decode_with_kv_cache_mla function where pre-allocated output tensors were incorrectly rejected when q_len_per_req > 1 due to an incorrect hardcoded 3D shape expectation.
  • Shape Unification: Unified the output shape calculation logic for both out is None and out is not None paths, ensuring the correct 4D shape query.shape[:-1] + (kv_lora_rank,) is consistently used.
  • New Test Case: Introduced a new test test_trtllm_batch_decode_mla_preallocated_out to validate the fix, covering scenarios with q_len_per_req > 1 and pre-allocated output tensors.


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug in trtllm_batch_decode_with_kv_cache_mla where the shape check for a pre-allocated output tensor was incorrect for multi-token generation (q_len_per_req > 1). The fix unifies the output shape calculation for both the out=None and pre-allocated out cases, which also improves code clarity by removing duplication. A new test case has been added to verify the fix, ensuring that both paths produce identical results. The changes look good. I have one minor suggestion in the new test file to improve maintainability by reducing code duplication.

Comment on lines +856 to +863
global global_trtllm_gen_fmha_workspace_buffer
if global_trtllm_gen_fmha_workspace_buffer is None:
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
workspace_size,
dtype=torch.int8,
device=device,
)
workspace = global_trtllm_gen_fmha_workspace_buffer
Contributor


medium

This workspace buffer initialization logic is duplicated from the trtllm_batch_decode_mla helper function in this file (lines 324-328). To improve maintainability and reduce code duplication, consider creating a pytest fixture to provide the workspace buffer. This would encapsulate the global variable and its initialization logic, making the tests cleaner.

For example, you could add a fixture like this:

@pytest.fixture(scope="module")
def trtllm_gen_fmha_workspace():
    """Provides a zero-initialized workspace buffer for trtllm-gen MLA tests."""
    # pytest treats fixture parameters as fixture requests, so the device is
    # fixed here rather than passed as a default argument.
    global global_trtllm_gen_fmha_workspace_buffer
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device="cuda:0"
        )
    return global_trtllm_gen_fmha_workspace_buffer

And then use it in the test signature.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

887-903: Assert the preallocated buffer is actually reused.

Right now the test verifies shape/value equivalence, but it does not guarantee zero extra allocation on the out= path. Add a pointer check so the regression guard also enforces buffer reuse.

Proposed test hardening
     result_pre = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
         query=query,
         kv_cache=kv_cache,
         workspace_buffer=workspace,
         qk_nope_head_dim=qk_nope_head_dim,
         kv_lora_rank=kv_lora_rank,
         qk_rope_head_dim=qk_rope_head_dim,
         block_tables=block_tables,
         seq_lens=seq_lens,
         max_seq_len=max_seq_len,
         out=out,
         bmm1_scale=bmm1_scale,
         bmm2_scale=1.0,
         backend="trtllm-gen",
     )
+    assert result_pre.data_ptr() == out.data_ptr(), "Expected kernel to write into provided `out` tensor"
     assert result_pre.shape == expected_shape
     torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 887 - 903, The test
currently checks shape and values but not that the provided preallocated buffer
is reused; after calling
flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(...) with the out=
parameter, assert that the returned tensor reuses the same memory by comparing
pointers (e.g. result_pre.data_ptr() == out.data_ptr()); keep the existing
torch.testing.assert_close but add this pointer equality assertion (or capture
out_ptr before the call and compare to result_pre.data_ptr() after) to enforce
zero extra allocation.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 912c5542-2ff0-4ad2-98e9-ce6f7a04f0f8

📥 Commits

Reviewing files that changed from the base of the PR and between 19bbdd3 and abe810a.

📒 Files selected for processing (2)
  • flashinfer/mla.py
  • tests/attention/test_trtllm_gen_mla.py

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

856-864: Consider zeroing shared workspace on reuse to avoid test-order coupling.

At Line 857, the global buffer is only zero-initialized on first allocation. If a prior test mutated it, this test can become order-dependent. Re-zeroing on reuse makes this test self-contained.

Proposed tweak
     global global_trtllm_gen_fmha_workspace_buffer
     if global_trtllm_gen_fmha_workspace_buffer is None:
         global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size,
             dtype=torch.int8,
             device=device,
         )
+    else:
+        global_trtllm_gen_fmha_workspace_buffer.zero_()
     workspace = global_trtllm_gen_fmha_workspace_buffer
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 856 - 864, The global
shared buffer global_trtllm_gen_fmha_workspace_buffer may contain leftovers from
prior tests; when reusing it (after the existing allocation check using
workspace_size and device), explicitly zero it before assigning workspace (e.g.,
call its in-place zeroing method such as zero_() or fill_(0) on
global_trtllm_gen_fmha_workspace_buffer) so the test does not depend on prior
test mutations.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 125aae67-88e7-4de3-ab3d-021b4ba04c17

📥 Commits

Reviewing files that changed from the base of the PR and between abe810a and 94498c3.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_gen_mla.py

@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !460 has been created, and the CI pipeline #46925815 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46925815: 13/20 passed

@yzh119 yzh119 added the run-ci label Mar 25, 2026
Collaborator

@yzh119 yzh119 left a comment


LGTM, should be ready to merge as long as all CI passed.

@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !460 has been created, and the CI pipeline #46953779 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #46953779: 9/20 passed



Development

Successfully merging this pull request may close these issues.

[Bug] trtllm_batch_decode_with_kv_cache_mla rejects pre-allocated out when q_len_per_req > 1
