diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 5d0343ffd607..4838b9514f90 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -20,6 +20,50 @@ else:
     _ON_GFX942 = False
 
 
+# Over-allocate each logits row by this many float32 columns to absorb
+# out-of-bounds stores from the AITER preshuffle kernel's unmasked
+# buffer_store (up to ~190 elements past context_length). The downstream
+# top_k_per_row_decode op takes stride(0)/stride(1) explicitly, so the
+# widened stride is transparent.
+_PAGED_LOGITS_COL_PADDING = 256
+
+
+def _get_paged_logits_buffer(
+    rows: int, cols: int, device: torch.device
+) -> torch.Tensor:
+    """Return a fresh contiguous (rows, cols + _PAGED_LOGITS_COL_PADDING)
+    float32 tensor pre-filled with -inf.
+
+    **Why the return shape is wider than `cols`:**
+    Returning a narrower ``[:rows, :cols]`` slice would produce a
+    non-contiguous tensor (stride(0) = cols + _PAGED_LOGITS_COL_PADDING,
+    shape[1] = cols). ``deepgemm_fp8_paged_mqa_logits`` writes into this
+    buffer and assumes a contiguous layout; passing a non-contiguous slice
+    causes it to compute incorrect row offsets, corrupting logits across rows.
+    Returning the full contiguous allocation ensures the kernel sees a
+    consistent stride(0) = cols + _PAGED_LOGITS_COL_PADDING, with the padding
+    columns absorbing the ~190 out-of-bounds float32 writes the AITER
+    preshuffle kernel can emit past context_length.
+
+    The sole consumer of the returned tensor is ``top_k_per_row_decode``,
+    which takes ``logits.stride(0)`` and ``logits.stride(1)`` as explicit
+    arguments and bounds iteration via ``seq_lens``, so the extra columns are
+    never read.
+
+    **Why allocate fresh each call (no caching):**
+    vLLM captures a separate HIP graph for every batch-size bucket and records
+    the pointer of whatever tensor was live at capture time. A shared global
+    that gets reallocated for a larger bucket leaves earlier-captured graphs
+    with dangling pointers on replay. Each fresh allocation is owned by the
+    graph that captured it for the graph's lifetime.
+    """
+    return torch.full(
+        (rows, cols + _PAGED_LOGITS_COL_PADDING),
+        float("-inf"),
+        device=device,
+        dtype=torch.float32,
+    )
+
 
 @triton.jit
 def _indexer_k_quant_and_cache_kernel(
@@ -390,11 +434,8 @@ def rocm_fp8_paged_mqa_logits(
         aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
     )
     batch_size, next_n, heads, _ = q_fp8.shape
-    out_logits = torch.full(
-        [batch_size * next_n, max_model_len],
-        float("-inf"),
-        device="cuda",
-        dtype=torch.float32,
+    out_logits = _get_paged_logits_buffer(
+        batch_size * next_n, max_model_len, q_fp8.device
     )
     deepgemm_fp8_paged_mqa_logits(
         q_fp8,
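
For readers who want to check the stride argument in the docstring without a ROCm setup, here is a minimal standalone sketch (not part of the patch; `rows`, `cols`, and `pad` are illustrative stand-ins for the helper's parameters) showing why a narrowed slice of the padded buffer is non-contiguous while the full allocation is:

```python
import torch

# Standalone sketch, not part of the patch: demonstrates the contiguity
# argument made in _get_paged_logits_buffer's docstring.
rows, cols, pad = 4, 10, 256
buf = torch.full((rows, cols + pad), float("-inf"), dtype=torch.float32)

# Column-slicing keeps the parent's row stride, so the narrowed view is
# non-contiguous; a kernel that assumes a contiguous layout would derive
# row offsets from shape[1] and write into the wrong rows.
narrow = buf[:rows, :cols]
assert narrow.stride(0) == cols + pad
assert not narrow.is_contiguous()

# The full allocation has a single consistent layout: contiguous, with
# stride(0) == cols + pad, which is what both the writing kernel and the
# stride-aware consumer observe.
assert buf.is_contiguous()
assert buf.stride(0) == cols + pad
```

The sketch runs on CPU; the same stride semantics apply to device tensors, which is why returning the full allocation rather than a `[:rows, :cols]` view sidesteps the corruption described above.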