[Bugfix][Backport][Hardware][AMD] Backport PR #31282 to v0.13.0: Fix last_page_len calculation in AITER MLA decode#38146

Closed
khairulkabir1661 wants to merge 2 commits into vllm-project:releases/v0.13.0 from khairulkabir1661:v0.13.0-with-pr31282

Conversation

Contributor
@khairulkabir1661 khairulkabir1661 commented Mar 25, 2026

Summary

Backport of #31282 to v0.13.0: Fixes incorrect paged_kv_last_page_len calculation in the ROCm AITER MLA decode path.

The Bug

The AITER MLA kernel uses a block size of 1 (each page contains exactly 1 token). However, the paged_kv_last_page_len was incorrectly set to the full sequence length:

# BEFORE (buggy):
paged_kv_last_page_len = torch.where(seq_lens_device == 0, 1, seq_lens_device)

For a sequence of 127 tokens, this would set last_page_len=127, telling the kernel the last page has 127 tokens when it only has 1.
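The mismatch can be sketched in plain Python (torch-free; function names are hypothetical and chosen for illustration):

```python
# Illustration of the last_page_len bug when page_size == 1 (AITER MLA).
def buggy_last_page_len(seq_lens):
    # BEFORE: only the empty-sequence case is clamped to 1;
    # otherwise the full sequence length is reported for the last page.
    return [1 if s == 0 else s for s in seq_lens]

def fixed_last_page_len(seq_lens):
    # AFTER: with exactly one token per page, the last page
    # always holds exactly 1 token, regardless of sequence length.
    return [1 for _ in seq_lens]

print(buggy_last_page_len([127, 131, 0]))  # [127, 131, 1]
print(fixed_last_page_len([127, 131, 0]))  # [1, 1, 1]
```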

Impact

This bug could cause:

  • Incorrect attention score calculations for sequences with non-standard lengths
  • Potential out-of-bounds memory access in the MLA decode kernel
  • Unpredictable behavior for sequences with prime-number lengths

The Fix

Per @ganyi1996ppo's suggestion, the persistent buffer is now pre-initialized to 1s for efficiency:

# In __init__:
self.paged_kv_last_page_len = torch.ones(
    max_num_reqs, dtype=torch.int32, device=device
)

# In _build_decode (cudagraph path):
# Just slice the pre-initialized buffer - no fill needed
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
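The buffer-reuse pattern above can be sketched in plain Python (the class name here is hypothetical; in the PR the buffer is a `torch.int32` tensor on the device):

```python
# Sketch of the persistent-buffer pattern: initialize once, slice per step.
class DecodeMetadataBuilder:
    def __init__(self, max_num_reqs):
        # Pre-initialize at startup; with page_size == 1, every
        # last_page_len entry is always 1, so no per-step fill is needed.
        self.paged_kv_last_page_len = [1] * max_num_reqs

    def build_decode(self, num_reqs):
        # Per decode step: slice the existing buffer, no allocation.
        return self.paged_kv_last_page_len[:num_reqs]

builder = DecodeMetadataBuilder(max_num_reqs=8)
print(builder.build_decode(3))  # [1, 1, 1]
```

Reusing a stable buffer also matters for cudagraph capture, where the metadata tensors must keep the same storage across replays.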

Comparison with Other Backends

The FlashInfer backend correctly computes last_page_len using modulo:

paged_kv_last_page_len_np = seq_lens_np % page_size

For AITER MLA with page_size=1, this would always yield 0, which then becomes 1. Our fix directly initializes to 1, which is equivalent and more efficient.
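A small sketch shows why the modulo formula degenerates to all-ones at `page_size=1` (the zero-remainder-to-`page_size` adjustment here is an assumption about how a full last page is handled, not a quote of FlashInfer's code):

```python
# FlashInfer-style last_page_len via modulo (simplified illustration).
def last_page_len_modulo(seq_lens, page_size):
    # A full last page (remainder 0) actually holds page_size tokens.
    return [page_size if (s % page_size) == 0 else s % page_size
            for s in seq_lens]

print(last_page_len_modulo([127, 131, 256], page_size=16))  # [15, 3, 16]
print(last_page_len_modulo([127, 131, 256], page_size=1))   # [1, 1, 1]
```

With `page_size=1` every remainder is 0, so every entry becomes 1, matching the pre-initialized buffer of ones.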

Test Plan

MLA-Specific Validation (DeepSeek-V2-Lite)

VLLM_USE_V1=1 python -c "
from vllm import LLM, SamplingParams

# DeepSeek-V2-Lite uses MLA architecture
llm = LLM(
    model='deepseek-ai/DeepSeek-V2-Lite',
    max_model_len=512,
    trust_remote_code=True
)

# Test with prime-number token counts (most likely to expose the bug)
params = SamplingParams(max_tokens=127, temperature=0.0)
outputs = llm.generate(['Explain quantum computing.'], params)
print(f'Generated {len(outputs[0].outputs[0].token_ids)} tokens')
"

Test Criteria

  • No crash or hang during MLA decode
  • Output is coherent (not garbage)
  • Works with prime-number token counts (127, 131)
  • Pre-commit CI passes
  • DCO check passes

Impact

  • Fixes potential correctness issues in AITER MLA decode on v0.13.0
  • Required for DeepSeek-V2/V3 models using MLA architecture on ROCm
  • More efficient: pre-initialized buffer eliminates redundant fills
  • 1 file changed, 12 insertions(+), 9 deletions(-)

Related

c0de128 added 2 commits March 25, 2026 15:35
Per review feedback from @ganyi1996ppo: Initialize the persistent
buffer to 1 during startup instead of 0, then just slice it during
metadata preparation. This is more efficient than allocating a new
tensor and filling it each decode call.

Changes:
- Initialize paged_kv_last_page_len with torch.ones() instead of zeros()
- Remove redundant fill_(1) call since buffer is pre-filled with 1s

Signed-off-by: c0de128 <kevin.mckay@outlook.com>

Per reviewer suggestion, eliminate redundant torch.ones() allocation
in _build_decode by reusing the pre-initialized buffer slice directly.

Changes:
- Move paged_kv_last_page_len initialization outside cudagraph block
- Replace per-call allocation with buffer slice in _build_decode
- Remove unnecessary copy operation in cudagraph mode

This is more efficient since we avoid allocation on every decode call.

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
@claude claude bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@mergify mergify bot added rocm Related to AMD ROCm v1 bug Something isn't working labels Mar 25, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Mar 25, 2026
@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request refactors the handling of paged_kv_last_page_len in rocm_aiter_mla.py. It pre-initializes this tensor with ones, leveraging the fact that the kernel block size is always 1, meaning each page contains a single token. This change eliminates redundant calculations and memory copies for paged_kv_last_page_len during decode operations, particularly benefiting cudagraph mode by reusing a pre-initialized buffer slice. There is no feedback to provide.

@markmc markmc closed this Mar 26, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Mar 26, 2026
@khairulkabir1661 khairulkabir1661 deleted the v0.13.0-with-pr31282 branch March 26, 2026 17:13