[fix] bugfix 2856: Fix pre-allocated out shape check in trtllm_batch_decode_with_kv_cache_mla for q_len_per_req > 1 (#2876)
qsang-nv wants to merge 2 commits into flashinfer-ai:main
Conversation
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
Code Review
This pull request correctly fixes a bug in `trtllm_batch_decode_with_kv_cache_mla` where the shape check for a pre-allocated output tensor was incorrect for multi-token generation (`q_len_per_req > 1`). The fix unifies the output shape calculation for the `out=None` and pre-allocated `out` cases, which also improves code clarity by removing duplication. A new test case has been added to verify the fix, ensuring that both paths produce identical results. The changes look good. I have one minor suggestion in the new test file to improve maintainability by reducing code duplication.
```python
global global_trtllm_gen_fmha_workspace_buffer
if global_trtllm_gen_fmha_workspace_buffer is None:
    global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
        workspace_size,
        dtype=torch.int8,
        device=device,
    )
workspace = global_trtllm_gen_fmha_workspace_buffer
```
This workspace buffer initialization logic is duplicated from the trtllm_batch_decode_mla helper function in this file (lines 324-328). To improve maintainability and reduce code duplication, consider creating a pytest fixture to provide the workspace buffer. This would encapsulate the global variable and its initialization logic, making the tests cleaner.
For example, you could add a fixture like this:
```python
@pytest.fixture(scope="module")
def trtllm_gen_fmha_workspace(device="cuda:0"):
    """Provides a zero-initialized workspace buffer for trtllm-gen MLA tests."""
    global global_trtllm_gen_fmha_workspace_buffer
    if global_trtllm_gen_fmha_workspace_buffer is None:
        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )
    return global_trtllm_gen_fmha_workspace_buffer
```

And then use it in the test signature.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)
887-903: Assert the preallocated buffer is actually reused.

Right now the test verifies shape/value equivalence, but it does not guarantee zero extra allocation on the `out=` path. Add a pointer check so the regression guard also enforces buffer reuse.

Proposed test hardening:

```diff
 result_pre = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
     query=query,
     kv_cache=kv_cache,
     workspace_buffer=workspace,
     qk_nope_head_dim=qk_nope_head_dim,
     kv_lora_rank=kv_lora_rank,
     qk_rope_head_dim=qk_rope_head_dim,
     block_tables=block_tables,
     seq_lens=seq_lens,
     max_seq_len=max_seq_len,
     out=out,
     bmm1_scale=bmm1_scale,
     bmm2_scale=1.0,
     backend="trtllm-gen",
 )
+assert result_pre.data_ptr() == out.data_ptr(), (
+    "Expected kernel to write into provided `out` tensor"
+)
 assert result_pre.shape == expected_shape
 torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3)
```
📒 Files selected for processing (2)
- flashinfer/mla.py
- tests/attention/test_trtllm_gen_mla.py
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)
856-864: Consider zeroing shared workspace on reuse to avoid test-order coupling.

At line 857, the global buffer is only zero-initialized on first allocation. If a prior test mutated it, this test can become order-dependent. Re-zeroing on reuse makes this test self-contained.

Proposed tweak:

```diff
 global global_trtllm_gen_fmha_workspace_buffer
 if global_trtllm_gen_fmha_workspace_buffer is None:
     global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
         workspace_size,
         dtype=torch.int8,
         device=device,
     )
+else:
+    global_trtllm_gen_fmha_workspace_buffer.zero_()
 workspace = global_trtllm_gen_fmha_workspace_buffer
```
📒 Files selected for processing (1)
- tests/attention/test_trtllm_gen_mla.py
/bot run

[FAILED] Pipeline #46925815: 13/20 passed
yzh119 left a comment:

LGTM, should be ready to merge as long as all CI passed.
|
/bot run |
|
[FAILED] Pipeline #46953779: 9/20 passed |
📌 Description
This PR fixes #2856.
`trtllm_batch_decode_with_kv_cache_mla` rejects a correctly shaped pre-allocated `out` tensor when `q_len_per_req > 1` (speculative decoding / MTP). The `out is None` path correctly infers a 4D output shape `[B, q_len, H, kv_lora_rank]` via `query.shape[:-1] + (kv_lora_rank,)`, but the `out is not None` path hardcodes a 3D expected shape `[B, H, kv_lora_rank]`, missing the `q_len` dimension.

The fix unifies both paths to use `query.shape[:-1] + (kv_lora_rank,)` as the expected output shape.
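The unified check can be illustrated with a minimal sketch in plain Python; the helper name is hypothetical, the real check lives in `flashinfer/mla.py`:

```python
def expected_out_shape(query_shape, kv_lora_rank):
    # query is [B, q_len_per_req, num_heads, head_dim]; the output keeps every
    # leading dimension and only swaps the last one for kv_lora_rank.
    return tuple(query_shape[:-1]) + (kv_lora_rank,)

# With q_len_per_req = 2, the old hardcoded 3D check [B, H, kv_lora_rank]
# would reject this correct 4D shape; the unified rule accepts it.
print(expected_out_shape((4, 2, 128, 576), 512))  # -> (4, 2, 128, 512)
```

Because the rule derives the expected shape from `query` itself, it degrades gracefully to the old 3D behavior when `q_len_per_req == 1` and the query is 3D.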
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g. `unittest`, etc.).

Reviewer Notes