[Bugfix][Backport][Hardware][AMD] Backport PR #31282 to v0.13.0: Fix last_page_len calculation in AITER MLA decode #38146
Closed
khairulkabir1661 wants to merge 2 commits into vllm-project:releases/v0.13.0 from
Conversation
Per review feedback from @ganyi1996ppo: initialize the persistent buffer to 1 during startup instead of 0, then just slice it during metadata preparation. This is more efficient than allocating a new tensor and filling it on each decode call.

Changes:
- Initialize paged_kv_last_page_len with torch.ones() instead of zeros()
- Remove the redundant fill_(1) call, since the buffer is pre-filled with 1s

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Per reviewer suggestion, eliminate the redundant torch.ones() allocation in _build_decode by reusing the pre-initialized buffer slice directly.

Changes:
- Move paged_kv_last_page_len initialization outside the cudagraph block
- Replace the per-call allocation with a buffer slice in _build_decode
- Remove the unnecessary copy operation in cudagraph mode

This is more efficient because it avoids an allocation on every decode call.

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
Code Review
This pull request refactors the handling of `paged_kv_last_page_len` in `rocm_aiter_mla.py`. It pre-initializes the tensor with ones, leveraging the fact that the kernel block size is always 1, so each page contains a single token. This eliminates redundant calculations and memory copies for `paged_kv_last_page_len` during decode, particularly benefiting cudagraph mode, which now reuses a pre-initialized buffer slice. There is no feedback to provide.
Summary
Backport of #31282 to v0.13.0: fixes an incorrect `paged_kv_last_page_len` calculation in the ROCm AITER MLA decode path.

The Bug
The AITER MLA kernel uses a block size of 1, meaning each page contains exactly 1 token. However, `paged_kv_last_page_len` was incorrectly set to the full sequence length. For a sequence of 127 tokens, this set `last_page_len=127`, telling the kernel the last page holds 127 tokens when it actually holds only 1.

Impact
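To make the size of the miscalculation concrete, here is a small plain-Python illustration (names are hypothetical; the real code operates on torch tensors inside the vLLM backend):

```python
# Hypothetical illustration in plain Python, not the actual vLLM code.
def last_page_len(seq_len: int, page_size: int) -> int:
    """Tokens held by the final (possibly partial) page of a paged KV cache."""
    return (seq_len - 1) % page_size + 1

# With the AITER MLA kernel's page size of 1, the last page always holds
# exactly one token, no matter how long the sequence is.
print(last_page_len(127, 1))   # 1 (correct)
print(last_page_len(131, 1))   # 1 (correct)

# The bug passed the full sequence length instead:
buggy_value = 127              # claims 127 tokens in a page that holds 1
```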
This bug could cause crashes, hangs, or incoherent (garbage) output during MLA decode.
The Fix
Per @ganyi1996ppo's suggestion, the persistent buffer is now pre-initialized to 1s for efficiency.
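A minimal sketch of the approach, using plain Python lists in place of torch tensors (the class and method names here are illustrative, not the actual vLLM identifiers):

```python
class DecodeMetadataBuilderSketch:
    """Illustrative stand-in for the backend's decode metadata builder."""

    def __init__(self, max_num_reqs: int):
        # Allocated once at startup and pre-filled with 1s (torch.ones() in
        # the PR), since page_size == 1 means every last page holds 1 token.
        self.paged_kv_last_page_len = [1] * max_num_reqs

    def build_decode(self, num_decode_reqs: int):
        # Take the first num_decode_reqs entries (a zero-copy view in the
        # torch version); no per-call allocation or fill_(1) is needed,
        # because the buffer already contains 1s.
        return self.paged_kv_last_page_len[:num_decode_reqs]


builder = DecodeMetadataBuilderSketch(max_num_reqs=8)
print(builder.build_decode(3))  # [1, 1, 1]
```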
Comparison with Other Backends
The FlashInfer backend correctly computes `last_page_len` using modulo arithmetic. For AITER MLA with `page_size=1`, the modulo would always yield 0, which is then clamped up to 1. Our fix initializes the value directly to 1, which is equivalent and more efficient.

Test Plan
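As a quick, unofficial sanity check (plain Python, not part of the PR's CI), the modulo-based computation described above and the constant-1 initialization can be confirmed to agree whenever the page size is 1:

```python
def modulo_style_last_page_len(seq_len: int, page_size: int) -> int:
    # seq_len % page_size is 0 whenever the last page is full; clamp that
    # 0 up to a full page, mirroring the behavior described above.
    rem = seq_len % page_size
    return rem if rem != 0 else page_size

# With page_size=1 the modulo always yields 0, which clamps to 1, matching
# the direct initialization to 1 in this PR.
assert all(modulo_style_last_page_len(n, 1) == 1 for n in (1, 2, 127, 131, 4096))
print("modulo path agrees with constant 1 for page_size=1")
```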
MLA-Specific Validation (DeepSeek-V2-Lite)
Test Criteria
- No crash or hang during MLA decode
- Output is coherent (not garbage)
- Works with prime-number token counts (127, 131)
- Pre-commit CI passes
- DCO check passes
Impact
Related
releases/v0.13.0