2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/indexer.py
@@ -122,7 +122,7 @@ def get_name() -> str:

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1 if current_platform.is_rocm() else 64]
return [64]

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
66 changes: 48 additions & 18 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import (
AttentionBackend,
@@ -86,11 +87,15 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
"auto",
"float16",
"bfloat16",
"fp8_e4m3",
"fp8_e5m2",
"fp8_e4m3fnuz",
"fp8_e5m2fnuz",
]

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
return [64]

@staticmethod
def get_name() -> str:
@@ -146,7 +151,7 @@ class ROCMAiterMLASparseMetadata(AttentionMetadata):
paged_kv_indptr: torch.Tensor
paged_kv_indptr_rest: torch.Tensor

block_size: int = 1
block_size: int = 64
topk_tokens: int = 2048


@@ -197,8 +202,6 @@ def __init__(
max_num_batched_tokens, dtype=torch.int32, device=device
)

# These two needs to be calculated in runtime,
# but we still needs to prepare the buffer
self.paged_kv_indices = torch.zeros(
[max_num_batched_tokens * self.topk_tokens],
dtype=torch.int32,
@@ -314,26 +317,51 @@ def __init__(
assert indexer is not None
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

def _forward_bf16_kv(
def _forward_sparse_mla(
self,
q: torch.Tensor, # [sq, heads, d_qk]
kv_c_and_k_pe_cache: torch.Tensor, # [blocks, heads, d_qk]
topk_indices: torch.Tensor, # [sq, topk]
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: ROCMAiterMLASparseMetadata,
layer: AttentionLayer,
) -> torch.Tensor:
num_tokens = q.shape[0]
attn_out_dtype = q.dtype
mla_kwargs: dict = {}

if self.kv_cache_dtype in ("fp8_e4m3", "fp8_e5m2",
"fp8_e4m3fnuz", "fp8_e5m2fnuz"):
fp8_dtype = current_platform.fp8_dtype()
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(fp8_dtype)
mla_kwargs["k_scale"] = layer._k_scale
mla_kwargs["v_scale"] = layer._v_scale

# mla_decode_fwd uses page_size=1 internally. When block_size > 1,
# flatten [num_pages, block_size, head_size] ->
# [num_pages * block_size, 1, head_size] so flat token indices work.
if kv_c_and_k_pe_cache.dim() >= 2 and kv_c_and_k_pe_cache.shape[1] != 1:
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.reshape(
-1, 1, kv_c_and_k_pe_cache.shape[-1]
)

mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
qo_indptr = attn_metadata.qo_indptr[:num_tokens + 1]
kv_last_page_len = attn_metadata.paged_kv_last_page_len[:num_tokens]

seq_len = (topk_indices != -1).sum(dim=-1)
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])
kv_indptr = attn_metadata.paged_kv_indptr[:num_tokens + 1]

output = torch.empty(
[num_tokens, mla_num_heads, self.kv_lora_rank],
dtype=q.dtype,
dtype=attn_out_dtype,
device=q.device,
)
seq_len = (topk_indices != -1).sum(dim=-1)
torch.cumsum(seq_len, dim=0, out=attn_metadata.paged_kv_indptr[1:])
attn_metadata.paged_kv_indptr_rest.fill_(attn_metadata.paged_kv_indptr[-1])

fetch_id_to_ragged_triton(
topk_indices,
attn_metadata.paged_kv_indptr,
kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.topk_tokens,
)
@@ -343,11 +371,12 @@ def _forward_bf16_kv(
kv_c_and_k_pe_cache,
output,
self.scale,
attn_metadata.qo_indptr,
qo_indptr,
1,
attn_metadata.paged_kv_indptr,
kv_indptr,
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len,
kv_last_page_len,
**mla_kwargs,
)

return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output)
@@ -381,8 +410,9 @@ def forward_mqa(
)

mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
attn_out = self._forward_bf16_kv(
mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata
attn_out = self._forward_sparse_mla(
mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global,
attn_metadata, layer
)

return attn_out, None
139 changes: 113 additions & 26 deletions vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -7,12 +7,55 @@
import torch

from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import LayerNameType
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton

logger = init_logger(__name__)

_AITER_MQA_SMALL_HEADS_WARNED = False

_cached_paged_logits: torch.Tensor | None = None

# Over-allocate each logits row by this many float32 columns to absorb
# out-of-bounds stores from the AITER preshuffle kernel's unmasked
# buffer_store (up to ~190 elements past context_length). The downstream
# top_k_per_row_decode op takes stride(0)/stride(1) explicitly, so the
# widened stride is transparent. Credit: maeehart (vllm-project/vllm#40643).
_PAGED_LOGITS_COL_PADDING = 256


def _get_paged_logits_buffer(
rows: int, cols: int, device: torch.device
) -> torch.Tensor:
"""Return a (rows, cols) float32 view pre-filled with -inf.

Within a decode step every layer sees the same (batch*next_n,
actual_max_seq_len) shape, so the expensive torch.full call only
happens once per step (or when the shape changes). The backing
storage is wider by _PAGED_LOGITS_COL_PADDING columns to guard
against preshuffle kernel OOB writes.
"""
global _cached_paged_logits
padded_cols = cols + _PAGED_LOGITS_COL_PADDING
if (
_cached_paged_logits is not None
and _cached_paged_logits.shape[0] >= rows
and _cached_paged_logits.shape[1] >= padded_cols
and _cached_paged_logits.device == device
):
buf = _cached_paged_logits[:rows, :cols]
buf.fill_(float("-inf"))
return buf
_cached_paged_logits = torch.full(
(rows, padded_cols), float("-inf"), device=device, dtype=torch.float32
)
return _cached_paged_logits[:rows, :cols]


if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops

@@ -111,7 +154,7 @@ def indexer_k_quant_and_cache_triton(
block_size,
num_tokens,
head_dim,
"NHD",
"SHUFFLE",
Contributor
high

The quantization kernel is now hardcoded to use the "SHUFFLE" layout. However, the attention logic in rocm_fp8_paged_mqa_logits only enables Preshuffle=True when block_size == 64 (line 408) and falls back to the stage1 kernel for block_size == 1 (line 413). The stage1 kernel and the gluon kernel with Preshuffle=False expect the standard "NHD" layout. This inconsistency will lead to incorrect results for any block_size other than 64. The layout should be conditional on the block size.

Suggested change
"SHUFFLE",
"SHUFFLE" if block_size == 64 else "NHD",

block_tile_size,
head_tile_size,
IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
@@ -212,7 +255,7 @@ def cp_gather_indexer_k_quant_cache_triton(
block_table_stride,
k_cache_value.stride(0),
k_cache_scale.stride(0),
"NHD",
"SHUFFLE",
Contributor
high

Similar to the indexer quantization kernel, this gather kernel should only use the "SHUFFLE" layout when the block size is 64 to maintain compatibility with the attention kernels and the fallback paths.

Suggested change
"SHUFFLE",
"SHUFFLE" if block_size == 64 else "NHD",

head_dim,
block_tile_size,
head_tile_size,
@@ -300,6 +343,7 @@ def rocm_fp8_paged_mqa_logits(
block_tables: torch.Tensor,
schedule_metadata: torch.Tensor,
max_model_len: int,
block_size: int = 1,
) -> torch.Tensor:
"""Compute FP8 MQA logits using paged KV-cache.

@@ -317,41 +361,80 @@
schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
used to distribute work across SMs.
max_model_len: Maximum sequence length used to size the logits output.
block_size: KV cache block size (default 1).

Returns:
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
global _AITER_MQA_SMALL_HEADS_WARNED
from vllm._aiter_ops import rocm_aiter_ops

batch_size, next_n, heads, _ = q_fp8.shape

aiter_paged_mqa_logits_module = None
if rocm_aiter_ops.is_enabled():
if rocm_aiter_ops.is_enabled() and heads >= 16:
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
elif rocm_aiter_ops.is_enabled() and not _AITER_MQA_SMALL_HEADS_WARNED:
logger.warning(
"AITER paged MQA logits kernel does not support %d heads "
"(requires >= 16). Falling back to PyTorch reference.",
heads,
)
_AITER_MQA_SMALL_HEADS_WARNED = True

if aiter_paged_mqa_logits_module is not None:
deepgemm_fp8_paged_mqa_logits_stage1 = (
aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1
_deepgemm_fp8_paged_mqa_logits = getattr(
aiter_paged_mqa_logits_module,
"deepgemm_fp8_paged_mqa_logits",
None,
)
batch_size, next_n, heads, _ = q_fp8.shape
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
device="cuda",
dtype=torch.float32,
use_new_api = (
_deepgemm_fp8_paged_mqa_logits is not None and block_size > 1
)
# TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
# 2. Remove ChunkQ when AITER PR #2891 merged
deepgemm_fp8_paged_mqa_logits_stage1(
q_fp8,
kv_cache_fp8,
weights,
out_qk,
context_lens,
block_tables,
max_model_len,
ChunkQ=heads,
)
return out_qk.sum(dim=0)
if use_new_api:
out_logits = _get_paged_logits_buffer(
batch_size * next_n, max_model_len, q_fp8.device
)
_deepgemm_fp8_paged_mqa_logits(
q_fp8,
kv_cache_fp8,
weights,
out_logits,
context_lens,
block_tables,
max_model_len,
ChunkK=256,
Preshuffle=block_size == 64,
KVBlockSize=block_size,
WavePerEU=2,
)
return out_logits
else:
_stage1 = (
aiter_paged_mqa_logits_module
.deepgemm_fp8_paged_mqa_logits_stage1
)
out_qk = torch.full(
(heads, batch_size * next_n, max_model_len),
float("-inf"),
device="cuda",
Contributor
high

Hardcoding device="cuda" can cause runtime errors in multi-GPU environments or if the input tensors are on a specific device that is not the current default. It is safer to use the device of the input query tensor.

Suggested change
device="cuda",
device=q_fp8.device,

dtype=torch.float32,
)
# TODO: 1. Replace _stage1 and out_qk.sum with another fused
# variant;
# 2. Remove ChunkQ when AITER PR #2891 merged
_stage1(
q_fp8,
kv_cache_fp8,
weights,
out_qk,
context_lens,
block_tables,
max_model_len,
ChunkQ=heads,
)
return out_qk.sum(dim=0)
else:
return fp8_paged_mqa_logits_torch(
q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
@@ -540,7 +623,7 @@ def rocm_aiter_sparse_attn_indexer(
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]

ops.indexer_k_quant_and_cache(
indexer_k_quant_and_cache_triton(
k,
kv_cache,
slot_mapping,
@@ -598,6 +681,7 @@ def rocm_aiter_sparse_attn_indexer(
if has_decode:
decode_metadata = layer_attn_metadata.decode
assert decode_metadata is not None
kv_block_size = kv_cache.shape[1]
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
@@ -620,14 +704,17 @@
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n

actual_max_seq_len = layer_attn_metadata.max_seq_len

logits = rocm_fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
max_model_len=actual_max_seq_len,
Contributor
critical

Using actual_max_seq_len (which is the dynamic maximum sequence length of the current batch) as the max_model_len parameter is incompatible with CUDA Graphs. Since this value is a Python integer and changes every step, it will be baked into the graph during capture. Subsequent replays with a different actual maximum sequence length will use the stale value, leading to incorrect GEMM dimensions or memory corruption. You should use the constant max_model_len passed as an argument to the function to ensure graph stability.

Suggested change
max_model_len=actual_max_seq_len,
max_model_len=max_model_len,
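The capture hazard can be illustrated with a minimal sketch (assuming standard torch.cuda.graph capture semantics; the function and variable names below are illustrative, not from this PR): a Python int that sizes a buffer used by the captured region is frozen at capture time, so later replays keep that extent regardless of the current batch.

import torch

# Illustrative only: the output buffer is sized from a Python int at capture
# time; replaying the graph reuses this exact allocation and shape.
def capture_step(q: torch.Tensor, max_len: int):
    graph = torch.cuda.CUDAGraph()
    # Shape fixed here with the max_len seen during capture.
    static_logits = torch.empty(q.shape[0], max_len, device=q.device)
    with torch.cuda.graph(graph):
        # Stands in for the MQA logits kernel writing into the buffer.
        static_logits.fill_(float("-inf"))
    return graph, static_logits  # replays never re-read a new max_len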

block_size=kv_block_size,
)

num_rows = logits.shape[0]