Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,50 @@
else:
_ON_GFX942 = False

# Over-allocate each logits row by this many float32 columns to absorb
# out-of-bounds stores from the AITER preshuffle kernel's unmasked
# buffer_store (up to ~190 elements past context_length). The downstream
# top_k_per_row_decode op takes stride(0)/stride(1) explicitly, so the
# widened stride is transparent.
_PAGED_LOGITS_COL_PADDING = 256


def _get_paged_logits_buffer(
    rows: int, cols: int, device: torch.device
) -> torch.Tensor:
    """Return a fresh, contiguous, column-padded float32 buffer of ``-inf``.

    Args:
        rows: Number of logits rows to allocate.
        cols: Number of logical logits columns per row (before padding).
        device: Device on which the buffer is allocated.

    Returns:
        A newly allocated tensor of shape
        ``(rows, cols + _PAGED_LOGITS_COL_PADDING)`` and dtype
        ``torch.float32`` in which every element is ``-inf``.

    **Why the return shape is wider than ``cols``:**
    Returning a narrower ``[:rows, :cols]`` slice would produce a
    non-contiguous tensor (``stride(0) = cols + _PAGED_LOGITS_COL_PADDING``,
    ``shape[1] = cols``). ``deepgemm_fp8_paged_mqa_logits`` writes into this
    buffer and assumes a contiguous layout; passing a non-contiguous slice
    causes it to compute incorrect row offsets, corrupting logits across
    rows. Returning the full contiguous allocation ensures the kernel sees a
    consistent ``stride(0) = cols + _PAGED_LOGITS_COL_PADDING``, with the
    padding columns absorbing the ~190 out-of-bounds float32 writes the AITER
    preshuffle kernel can emit past context_length.

    The sole consumer of the returned tensor is ``top_k_per_row_decode``,
    which takes ``logits.stride(0)`` and ``logits.stride(1)`` as explicit
    arguments and bounds iteration via ``seq_lens``, so the extra columns
    are never read.

    **Why allocate fresh each call (no caching):**
    vLLM captures a separate HIP graph for every batch-size bucket and
    records the pointer of whatever tensor was live at capture time. A
    shared global that gets reallocated for a larger bucket leaves
    earlier-captured graphs with dangling pointers on replay. Each fresh
    allocation is owned by the graph that captured it for the graph's
    lifetime.
    """
    # Allocate the full padded width and hand it back unsliced so the
    # result stays contiguous (see the docstring for why that matters).
    return torch.full(
        (rows, cols + _PAGED_LOGITS_COL_PADDING),
        float("-inf"),
        device=device,
        dtype=torch.float32,
    )


@triton.jit
def _indexer_k_quant_and_cache_kernel(
Expand Down Expand Up @@ -390,11 +434,8 @@ def rocm_fp8_paged_mqa_logits(
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits
)
batch_size, next_n, heads, _ = q_fp8.shape
out_logits = torch.full(
[batch_size * next_n, max_model_len],
float("-inf"),
device="cuda",
dtype=torch.float32,
out_logits = _get_paged_logits_buffer(
batch_size * next_n, max_model_len, q_fp8.device
)
deepgemm_fp8_paged_mqa_logits(
q_fp8,
Expand Down
Loading