From 6a2f8790dcfad4a4206f11c3f18d3dca9f4f3fc7 Mon Sep 17 00:00:00 2001 From: Frida Andersson Date: Wed, 6 May 2026 18:46:13 +0000 Subject: [PATCH] [ROCm][Bugfix] Add +256 col guard to preshuffle logits buffer (DSv3.2) The AITER gluon preshuffle kernel (_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle) performs unmasked buffer_store writes up to ~190 float32 elements past context_length in each logits row when block_size=64. With the previous exact-size allocation those stores corrupt the logits of the adjacent row, causing wrong top-k selection and degenerate output. Fix: introduce _get_paged_logits_buffer that allocates (rows, cols + _PAGED_LOGITS_COL_PADDING) where _PAGED_LOGITS_COL_PADDING=256. A non-contiguous [:rows, :cols] slice is intentionally avoided: deepgemm_fp8_paged_mqa_logits assumes contiguous output and would compute incorrect row offsets from a non-contiguous tensor. The full contiguous allocation ensures stride(0) = cols + 256 consistently; the padding columns absorb the OOB writes. top_k_per_row_decode takes logits.stride(0) and logits.stride(1) as explicit arguments and bounds iteration by seq_lens, so the extra columns are never read. A fresh allocation per call (no global cache) ensures each HIP graph bucket owns its own stable tensor pointer; a shared global reallocated for a larger bucket would leave earlier-captured graphs with dangling pointers on replay. Also fixes device="cuda" -> q_fp8.device so TP ranks > 0 allocate on the correct GPU. Validated: GSM8K 5-shot flexible-extract 0.9416 on TP4 with HIP graphs and block_size=64 (reference fork: 0.9409). Related: #40643 (maeehart: same padding with caching, draft pending MAF investigation at num_speculative_tokens=2). 
Co-authored-by: Markus Hartikainen Signed-off-by: Frida Andersson --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 51 +++++++++++++++++-- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 5d0343ffd607..4838b9514f90 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -20,6 +20,50 @@ else: _ON_GFX942 = False +# Over-allocate each logits row by this many float32 columns to absorb +# out-of-bounds stores from the AITER preshuffle kernel's unmasked +# buffer_store (up to ~190 elements past context_length). The downstream +# top_k_per_row_decode op takes stride(0)/stride(1) explicitly, so the +# widened stride is transparent. +_PAGED_LOGITS_COL_PADDING = 256 + + +def _get_paged_logits_buffer( + rows: int, cols: int, device: torch.device +) -> torch.Tensor: + """Return a fresh contiguous (rows, cols + _PAGED_LOGITS_COL_PADDING) + float32 tensor pre-filled with -inf. + + **Why the return shape is wider than `cols`:** + Returning a narrower ``[:rows, :cols]`` slice would produce a + non-contiguous tensor (stride(0) = cols + _PAGED_LOGITS_COL_PADDING, + shape[1] = cols). ``deepgemm_fp8_paged_mqa_logits`` writes into this + buffer and assumes a contiguous layout; passing a non-contiguous slice + causes it to compute incorrect row offsets, corrupting logits across rows. + Returning the full contiguous allocation ensures the kernel sees a + consistent stride(0) = cols + _PAGED_LOGITS_COL_PADDING, with the padding + columns absorbing the ~190 out-of-bounds float32 writes the AITER + preshuffle kernel can emit past context_length. + + The sole consumer of the returned tensor is ``top_k_per_row_decode``, + which takes ``logits.stride(0)`` and ``logits.stride(1)`` as explicit + arguments and bounds iteration via ``seq_lens``, so the extra columns are + never read. 
+ + **Why allocate fresh each call (no caching):** + vLLM captures a separate HIP graph for every batch-size bucket and records + the pointer of whatever tensor was live at capture time. A shared global + that gets reallocated for a larger bucket leaves earlier-captured graphs + with dangling pointers on replay. Each fresh allocation is owned by the + graph that captured it for the graph's lifetime. + """ + return torch.full( + (rows, cols + _PAGED_LOGITS_COL_PADDING), + float("-inf"), + device=device, + dtype=torch.float32, + ) + @triton.jit def _indexer_k_quant_and_cache_kernel( @@ -390,11 +434,8 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits ) batch_size, next_n, heads, _ = q_fp8.shape - out_logits = torch.full( - [batch_size * next_n, max_model_len], - float("-inf"), - device="cuda", - dtype=torch.float32, + out_logits = _get_paged_logits_buffer( + batch_size * next_n, max_model_len, q_fp8.device ) deepgemm_fp8_paged_mqa_logits( q_fp8,