107 changes: 106 additions & 1 deletion vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -17,6 +17,77 @@
from vllm import _custom_ops as ops


# Defense-in-depth padding for the paged-MQA-logits output buffer on
# gfx950 (MI355X) / DeepSeek V3.2 MTP decode.
#
# Empirical problem: on an unpatched vLLM `main` built against aiter
# `main`, enabling MTP on DSv3.2 intermittently hits a HIP Memory
# Access Fault on MI355X during decode. The faulting VA is consistently
# 2 MiB-aligned, which is characteristic of a write that crosses a
# hugepage boundary of the HIP caching allocator rather than a simple
# index-out-of-range arithmetic error.
#
# Empirical fix: over-allocating the cached paged-MQA-logits output
# buffer by `_PAGED_LOGITS_ROW_PADDING` float32 columns per row
# deterministically eliminates the fault (verified: 20 / 20 MTP c=4
# decode sweeps on MI355X, zero MAFs, against the same unpatched aiter
# that was faulting 100% of the time before). The most likely
# mechanism -- not proven yet -- is an allocator-layout shift: padding
# changes the VA where `_cached_paged_logits` lands and moves any
# subsequent overshoot (from this kernel, or from a downstream fused
# Inductor/Triton kernel whose stores lower to unchecked
# `global_store_dword` and write into an adjacent tensor) away from
# the hazardous hugepage boundary. Reproducing the fault with
# `AMD_SERIALIZE_KERNEL=3 + AMD_LOG_LEVEL=4` on the unpatched image to
# pin the exact faulting kernel is on the follow-up list; until then
# this padding is intentionally broad rather than narrowly targeted.
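# (Both are standard ROCm debug knobs; the repro is just those env vars
# on the otherwise-unchanged decode workload, e.g.
#   AMD_SERIALIZE_KERNEL=3 AMD_LOG_LEVEL=4 <decode benchmark command>
# with the command itself deployment-specific and elided here.)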
#
# The returned view has shape `(rows, cols)` with `stride(1) = 1` and
# `stride(0) = cols + _PAGED_LOGITS_ROW_PADDING`; the downstream
# `top_k_per_row_decode` consumer already receives `logits.stride(0)`
# and `logits.stride(1)` as explicit arguments, so this is transparent.
#
# The companion aiter PR ROCm/aiter#2866 adds
# `mask=offset < max_model_len` to unmasked `buffer_store` sites in
# `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle[_varctx]` as kernel
# hygiene; it pairs with this padding as belt-and-suspenders, but the
# padding stands on its own empirical evidence either way.
_PAGED_LOGITS_ROW_PADDING: int = 256
_cached_paged_logits: torch.Tensor | None = None


def _get_paged_logits_buffer(
rows: int, cols: int, device: torch.device
) -> torch.Tensor:
"""Return a (rows, cols) float32 buffer pre-filled with -inf for the
paged MQA-logits kernel to write into.

The underlying storage is over-allocated by `_PAGED_LOGITS_ROW_PADDING`
columns; see the module-level comment above `_PAGED_LOGITS_ROW_PADDING`
for the full rationale (defense-in-depth against an intermittent
2 MiB-aligned HIP memory access fault on MI355X during DSv3.2 MTP
decode). Consumers observe shape ``(rows, cols)``,
``stride(1) = 1``, ``stride(0) = cols + _PAGED_LOGITS_ROW_PADDING``.

The buffer is cached across decode steps and reused when the logical
shape and device match, saving one ``torch.full(-inf)`` allocation per
layer on every decode step (DSv3.2 has 61 layers).
"""
global _cached_paged_logits
padded_cols = cols + _PAGED_LOGITS_ROW_PADDING
if (
_cached_paged_logits is not None
and _cached_paged_logits.shape[0] == rows
and _cached_paged_logits.shape[1] == padded_cols
and _cached_paged_logits.device == device
):
return _cached_paged_logits[:, :cols]
_cached_paged_logits = torch.full(
(rows, padded_cols), float("-inf"), device=device, dtype=torch.float32
)
return _cached_paged_logits[:, :cols]
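# A minimal doctest-style check of the stride contract above (the
# device is illustrative; any torch device works):
#
#     buf = _get_paged_logits_buffer(4, 1024, torch.device("cpu"))
#     assert buf.shape == (4, 1024)
#     assert buf.stride(1) == 1
#     assert buf.stride(0) == 1024 + _PAGED_LOGITS_ROW_PADDING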
Comment on lines +78 to +88
Contributor
Severity: high

The current caching logic for _cached_paged_logits reallocates the buffer whenever the number of rows (batch size * next_n) changes, including when it decreases. In dynamic scheduling scenarios common in vLLM, this leads to frequent VRAM reallocations and torch.full calls, which can degrade performance and increase memory fragmentation.

Additionally, the function returns _cached_paged_logits[:, :cols]. If the reallocation logic is improved to allow reusing a larger buffer (using >= rows), this slice would return more rows than currently requested. Since downstream operations like top_k_per_row_decode rely on logits.shape[0] to determine the number of rows to process, this could cause out-of-bounds reads on other metadata tensors like seq_lens if the buffer is larger than the current batch.

It is recommended to allow reuse when the buffer is large enough and to return a strictly sized view. Note that this assumes the kernel or downstream operations handle masking of stale data within the (rows, cols) area, which seems to be the case given the existing reuse logic.

Suggested change
-if (
-    _cached_paged_logits is not None
-    and _cached_paged_logits.shape[0] == rows
-    and _cached_paged_logits.shape[1] == padded_cols
-    and _cached_paged_logits.device == device
-):
-    return _cached_paged_logits[:, :cols]
-_cached_paged_logits = torch.full(
-    (rows, padded_cols), float("-inf"), device=device, dtype=torch.float32
-)
-return _cached_paged_logits[:, :cols]
+if (
+    _cached_paged_logits is not None
+    and _cached_paged_logits.shape[0] >= rows
+    and _cached_paged_logits.shape[1] == padded_cols
+    and _cached_paged_logits.device == device
+):
+    return _cached_paged_logits[:rows, :cols]
+_cached_paged_logits = torch.full(
+    (rows, padded_cols), float("-inf"), device=device, dtype=torch.float32
+)
+return _cached_paged_logits[:rows, :cols]
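For intuition, a minimal standalone sketch of the hazard described above (toy shapes, not vLLM internals): with >=-rows reuse, an un-sliced view leaks stale rows into any consumer that sizes its loops from logits.shape[0], which is why the suggestion slices rows strictly.

import torch

cache = torch.full((8, 16), float("-inf"))  # cached for an earlier batch of 8 rows
rows, cols = 4, 12                          # current, smaller batch
bad = cache[:, :cols]                       # shape (8, 12): four stale rows still visible
good = cache[:rows, :cols]                  # shape (4, 12): strictly sized view
assert bad.shape[0] != rows and good.shape[0] == rows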



@triton.jit
def _indexer_k_quant_and_cache_kernel(
k_ptr, # [num_tokens, head_dim]
@@ -300,6 +371,7 @@ def rocm_fp8_paged_mqa_logits(
block_tables: torch.Tensor,
schedule_metadata: torch.Tensor,
max_model_len: int,
block_size: int = 1,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.

@@ -317,6 +389,12 @@
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
block_size: KV cache block size. When > 1 and the installed aiter
build exposes the newer ``deepgemm_fp8_paged_mqa_logits`` API,
the fused preshuffle kernel is used (required for DeepSeek V3.2
decode on gfx950). Defaults to 1, preserving the legacy
``deepgemm_fp8_paged_mqa_logits_stage1`` + ``out_qk.sum(dim=0)``
code path.

Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
@@ -329,10 +407,36 @@
aiter_paged_mqa_logits_module = paged_mqa_logits_module()

if aiter_paged_mqa_logits_module is not None:
batch_size, next_n, heads, _ = q_fp8.shape
# Prefer the newer fused `deepgemm_fp8_paged_mqa_logits` API when
# the aiter build exposes it AND `block_size > 1` (required by the
# MFMA-shape preshuffle kernel). Fall back to `_stage1` otherwise.
_deepgemm_fp8_paged_mqa_logits = getattr(
aiter_paged_mqa_logits_module,
"deepgemm_fp8_paged_mqa_logits",
None,
)
if _deepgemm_fp8_paged_mqa_logits is not None and block_size > 1:
out_logits = _get_paged_logits_buffer(
batch_size * next_n, max_model_len, q_fp8.device
)
_deepgemm_fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
out_logits,
context_lens,
block_tables,
max_model_len,
ChunkK=256,
Preshuffle=(block_size == 64),
KVBlockSize=block_size,
WavePerEU=2,
)
return out_logits
deepgemm_fp8_paged_mqa_logits_stage1 = (
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
)
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
@@ -625,6 +729,7 @@ def rocm_aiter_sparse_attn_indexer(
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
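# dim 1 of the paged KV cache is assumed to be the block size
# (layout [num_blocks, block_size, ...])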
block_size=kv_cache.shape[1],
)

num_rows = logits.shape[0]