diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 35a23b7a6057..823b9a1c0aa8 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -17,6 +17,77 @@ from vllm import _custom_ops as ops
+# Defense-in-depth padding for the paged-MQA-logits output buffer on
+# gfx950 (MI355X) / DeepSeek V3.2 MTP decode.
+#
+# Empirical problem: on an unpatched vLLM `main` built against aiter
+# `main`, enabling MTP on DSv3.2 intermittently hits a HIP Memory
+# Access Fault on MI355X during decode. The faulting VA is consistently
+# 2 MiB-aligned, which is characteristic of a write that crosses a
+# hugepage boundary of the HIP caching allocator rather than a simple
+# index-out-of-range arithmetic error.
+#
+# Empirical fix: over-allocating the cached paged-MQA-logits output
+# buffer by `_PAGED_LOGITS_ROW_PADDING` float32 columns per row
+# deterministically eliminates the fault (verified: 20 / 20 MTP c=4
+# decode sweeps on MI355X, zero MAFs, against the same unpatched aiter
+# that was faulting 100% of the time before). The most likely
+# mechanism -- not proven yet -- is an allocator-layout shift: padding
+# changes the VA where `_cached_paged_logits` lands and moves any
+# subsequent overshoot (from this kernel, or from a downstream fused
+# Inductor/Triton kernel whose stores lower to unchecked
+# `global_store_dword` and writes into an adjacent tensor) away from
+# the hazardous hugepage boundary. Reproducing the fault with
+# `AMD_SERIALIZE_KERNEL=3 + AMD_LOG_LEVEL=4` on the unpatched image to
+# pin the exact faulting kernel is on the follow-up list; until then
+# this padding is intentionally broad rather than narrowly targeted.
+#
+# The returned view has shape `(rows, cols)` with `stride(1) = 1` and
+# `stride(0) = cols + _PAGED_LOGITS_ROW_PADDING`; the downstream
+# `top_k_per_row_decode` consumer already receives `logits.stride(0)`
+# and `logits.stride(1)` as explicit arguments, so this is transparent.
+#
+# The companion aiter PR ROCm/aiter#2866 adds
+# `mask=offset < max_model_len` to unmasked `buffer_store` sites in
+# `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle[_varctx]` as kernel
+# hygiene; it pairs well with this padding as belt-and-suspenders, but
+# this padding stands on its own empirical evidence regardless.
+_PAGED_LOGITS_ROW_PADDING: int = 256
+_cached_paged_logits: torch.Tensor | None = None
+
+
+def _get_paged_logits_buffer(
+    rows: int, cols: int, device: torch.device
+) -> torch.Tensor:
+    """Return a (rows, cols) float32 buffer pre-filled with -inf for the
+    paged MQA-logits kernel to write into.
+
+    The underlying storage is over-allocated by `_PAGED_LOGITS_ROW_PADDING`
+    columns; see the module-level comment above `_PAGED_LOGITS_ROW_PADDING`
+    for the full rationale (defense-in-depth against an intermittent
+    2 MiB-aligned HIP memory access fault on MI355X during DSv3.2 MTP
+    decode). Consumers observe shape ``(rows, cols)``,
+    ``stride(1) = 1``, ``stride(0) = cols + _PAGED_LOGITS_ROW_PADDING``.
+
+    The buffer is cached across decode steps and reused when the logical
+    shape and device match, saving a ``torch.full(-inf)`` per layer
+    (DSv3.2 has 61 layers per decode step).
+ """ + global _cached_paged_logits + padded_cols = cols + _PAGED_LOGITS_ROW_PADDING + if ( + _cached_paged_logits is not None + and _cached_paged_logits.shape[0] == rows + and _cached_paged_logits.shape[1] == padded_cols + and _cached_paged_logits.device == device + ): + return _cached_paged_logits[:, :cols] + _cached_paged_logits = torch.full( + (rows, padded_cols), float("-inf"), device=device, dtype=torch.float32 + ) + return _cached_paged_logits[:, :cols] + + @triton.jit def _indexer_k_quant_and_cache_kernel( k_ptr, # [num_tokens, head_dim] @@ -300,6 +371,7 @@ def rocm_fp8_paged_mqa_logits( block_tables: torch.Tensor, schedule_metadata: torch.Tensor, max_model_len: int, + block_size: int = 1, ) -> torch.Tensor: """Compute FP8 MQA logits using paged KV-cache. @@ -317,6 +389,12 @@ def rocm_fp8_paged_mqa_logits( schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; used to distribute work across SMs. max_model_len: Maximum sequence length used to size the logits output. + block_size: KV cache block size. When > 1 and the installed aiter + build exposes the newer ``deepgemm_fp8_paged_mqa_logits`` API, + the fused preshuffle kernel is used (required for DeepSeek V3.2 + decode on gfx950). Defaults to 1, preserving the legacy + ``deepgemm_fp8_paged_mqa_logits_stage1`` + ``out_qk.sum(dim=0)`` + code path. Returns: Logits tensor of shape [B * next_n, max_model_len], dtype @@ -329,10 +407,36 @@ def rocm_fp8_paged_mqa_logits( aiter_paged_mqa_logits_module = paged_mqa_logits_module() if aiter_paged_mqa_logits_module is not None: + batch_size, next_n, heads, _ = q_fp8.shape + # Prefer the newer fused `deepgemm_fp8_paged_mqa_logits` API when + # the aiter build exposes it AND `block_size > 1` (required by the + # MFMA-shape preshuffle kernel). Fall back to `_stage1` otherwise. + _deepgemm_fp8_paged_mqa_logits = getattr( + aiter_paged_mqa_logits_module, + "deepgemm_fp8_paged_mqa_logits", + None, + ) + if _deepgemm_fp8_paged_mqa_logits is not None and block_size > 1: + out_logits = _get_paged_logits_buffer( + batch_size * next_n, max_model_len, q_fp8.device + ) + _deepgemm_fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + out_logits, + context_lens, + block_tables, + max_model_len, + ChunkK=256, + Preshuffle=(block_size == 64), + KVBlockSize=block_size, + WavePerEU=2, + ) + return out_logits deepgemm_fp8_paged_mqa_logits_stage1 = ( aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 ) - batch_size, next_n, heads, _ = q_fp8.shape out_qk = torch.full( (heads, batch_size * next_n, max_model_len), float("-inf"), @@ -625,6 +729,7 @@ def rocm_aiter_sparse_attn_indexer( decode_metadata.block_table, decode_metadata.schedule_metadata, max_model_len=max_model_len, + block_size=kv_cache.shape[1], ) num_rows = logits.shape[0]