47 changes: 47 additions & 0 deletions vllm/_aiter_ops.py
@@ -382,6 +382,12 @@ def _rocm_aiter_mla_decode_fwd_impl(
logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
) -> None:
from aiter.mla import mla_decode_fwd

@@ -395,6 +401,29 @@ def _rocm_aiter_mla_decode_fwd_impl(
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale

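# These six tensors describe a single work partition and are only
# meaningful together, so treat them as all-or-nothing before
# forwarding anything to aiter.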
if work_meta_data is not None:
assert work_indptr is not None, (
"work_indptr must be provided with work_meta_data"
)
assert work_info_set is not None, (
"work_info_set must be provided with work_meta_data"
)
assert reduce_indptr is not None, (
"reduce_indptr must be provided with work_meta_data"
)
assert reduce_final_map is not None, (
"reduce_final_map must be provided with work_meta_data"
)
assert reduce_partial_map is not None, (
"reduce_partial_map must be provided with work_meta_data"
)
kwargs["work_meta_data"] = work_meta_data
kwargs["work_indptr"] = work_indptr
kwargs["work_info_set"] = work_info_set
kwargs["reduce_indptr"] = reduce_indptr
kwargs["reduce_final_map"] = reduce_final_map
kwargs["reduce_partial_map"] = reduce_partial_map

mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
@@ -421,6 +450,12 @@ def _rocm_aiter_mla_decode_fwd_fake(
logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
) -> None:
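# Fake (meta) implementation used for tracing/compilation: the new
# tensors are accepted for signature parity and intentionally ignored.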
pass

@@ -1614,6 +1649,12 @@ def mla_decode_fwd(
logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
):
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q,
@@ -1628,6 +1669,12 @@ def mla_decode_fwd(
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
)

@staticmethod
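Taken together, the change to vllm/_aiter_ops.py is an optional, all-or-nothing extension of the decode entry point: either all six scheduler/reduce tensors are supplied and forwarded to aiter, or none are and the legacy path is unchanged. A minimal standalone sketch of that contract follows; `decode_kernel` is a hypothetical stand-in for `aiter.mla.mla_decode_fwd`, and only the kwargs-forwarding pattern is taken from the diff:

```python
import torch


# Hypothetical stand-in for aiter.mla.mla_decode_fwd; only the
# kwargs-forwarding contract is illustrated here.
def decode_kernel(q: torch.Tensor, **kwargs) -> None:
    print("extra kwargs:", sorted(kwargs))


def decode_fwd(
    q: torch.Tensor,
    work_meta_data: torch.Tensor | None = None,
    work_indptr: torch.Tensor | None = None,
    work_info_set: torch.Tensor | None = None,
    reduce_indptr: torch.Tensor | None = None,
    reduce_final_map: torch.Tensor | None = None,
    reduce_partial_map: torch.Tensor | None = None,
) -> None:
    kwargs = {}
    if work_meta_data is not None:
        companions = {
            "work_indptr": work_indptr,
            "work_info_set": work_info_set,
            "reduce_indptr": reduce_indptr,
            "reduce_final_map": reduce_final_map,
            "reduce_partial_map": reduce_partial_map,
        }
        # A partial set would describe an inconsistent work partition,
        # so fail fast before touching the kernel.
        for name, tensor in companions.items():
            assert tensor is not None, f"{name} must be provided with work_meta_data"
        kwargs["work_meta_data"] = work_meta_data
        kwargs.update(companions)
    decode_kernel(q, **kwargs)


# Legacy call sites keep working: nothing extra is forwarded.
decode_fwd(torch.zeros(1))  # extra kwargs: []
t = torch.zeros(1)
decode_fwd(torch.zeros(1), t, t, t, t, t, t)  # extra kwargs: all six names
```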
109 changes: 107 additions & 2 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -18,7 +18,12 @@
QueryLenSupport,
)
from vllm.triton_utils import tl, triton
-from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
+from vllm.v1.attention.backend import (
+    AttentionCGSupport,
+    AttentionLayer,
+    CommonAttentionMetadata,
+    MultipleOf,
+)
from vllm.v1.kv_cache_interface import AttentionSpec


@@ -70,8 +75,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
max_qo_len: int | None = None


+@dataclass
 class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
-    pass
+    work_meta_data: torch.Tensor | None = None
+    work_indptr: torch.Tensor | None = None
+    work_info_set: torch.Tensor | None = None
+    reduce_indptr: torch.Tensor | None = None
+    reduce_final_map: torch.Tensor | None = None
+    reduce_partial_map: torch.Tensor | None = None


class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
@@ -115,6 +126,55 @@ def __init__(
max_num_pages, dtype=torch.int32, device=device
)

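# The metadata buffer sizes depend only on static quantities (max request
# count, head count, dtypes), so query them once here and reuse the same
# tensors across builds.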
from aiter import dtypes, get_mla_metadata_info_v1

self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
q_dtype = self.decode_attn_out_dtype
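# Collapse the cache dtype into the coarse buckets aiter sizes its
# metadata for: any fp8 variant maps to "fp8", everything else to bf16.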
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
kv_cache_dtype_str = "fp8"
else:
kv_cache_dtype_str = "bf16"
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
(
(work_meta_data_size, work_meta_data_type),
(work_indptr_size, work_indptr_type),
(work_info_set_size, work_info_set_type),
(reduce_indptr_size, reduce_indptr_type),
(reduce_final_map_size, reduce_final_map_type),
(reduce_partial_map_size, reduce_partial_map_type),
) = get_mla_metadata_info_v1(
max_num_reqs,
1,
self._num_attention_heads,
q_dtype,
kv_dtype,
is_sparse=False,
fast_mode=True,
)
self._mla_work_meta_data = torch.empty(
work_meta_data_size, dtype=work_meta_data_type, device=device
)
self._mla_work_indptr = torch.empty(
work_indptr_size, dtype=work_indptr_type, device=device
)
self._mla_work_info_set = torch.empty(
work_info_set_size, dtype=work_info_set_type, device=device
)
self._mla_reduce_indptr = torch.empty(
reduce_indptr_size, dtype=reduce_indptr_type, device=device
)
self._mla_reduce_final_map = torch.empty(
reduce_final_map_size, dtype=reduce_final_map_type, device=device
)
self._mla_reduce_partial_map = torch.empty(
reduce_partial_map_size,
dtype=reduce_partial_map_type,
device=device,
)

if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
@@ -183,6 +243,28 @@ def _build_decode(
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
)

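# Refill the preallocated buffers in place with this batch's work
# partition; only the contents change between builds, never the sizes.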
from aiter import get_mla_metadata_v1

get_mla_metadata_v1(
qo_indptr,
paged_kv_indptr,
paged_kv_last_page_len,
self._num_attention_heads,
1,
True,
self._mla_work_meta_data,
self._mla_work_info_set,
self._mla_work_indptr,
self._mla_reduce_indptr,
self._mla_reduce_final_map,
self._mla_reduce_partial_map,
page_size=1,
kv_granularity=16,
max_seqlen_qo=max_qo_len,
uni_seqlen_qo=max_qo_len,
fast_mode=True,
)

attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
@@ -197,6 +279,23 @@

return attn_metadata

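# The base builder produces the common metadata; this override just
# attaches the shared scheduler/reduce buffers for the decode path.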
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AiterMLAMetadata:
attn_metadata = super().build(
common_prefix_len, common_attn_metadata, fast_build
)
attn_metadata.work_meta_data = self._mla_work_meta_data
attn_metadata.work_indptr = self._mla_work_indptr
attn_metadata.work_info_set = self._mla_work_info_set
attn_metadata.reduce_indptr = self._mla_reduce_indptr
attn_metadata.reduce_final_map = self._mla_reduce_final_map
attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
return attn_metadata


@triton.jit
def _copy_page_indices_kernel(
@@ -337,6 +436,12 @@ def forward_mqa(
attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
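# Scheduler/reduce tensors prepared per batch by the metadata builder.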
work_meta_data=attn_metadata.work_meta_data,
work_indptr=attn_metadata.work_indptr,
work_info_set=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
)

if self._needs_head_repeat:
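The second file follows an allocate-once, refill-per-batch pattern: worst-case sizes come from `get_mla_metadata_info_v1` in `__init__`, `get_mla_metadata_v1` rewrites the contents in `_build_decode`, and `build()` attaches the same tensors to the metadata object. Below is a condensed, aiter-free sketch of that lifecycle; the class name, buffer shapes, and the cumsum refill are illustrative stand-ins, not the real scheduler layout:

```python
import torch


class PersistentMlaBuffers:
    """Allocate-once / refill-per-batch buffers, mirroring the builder's
    handling of aiter's scheduler metadata. Names and shapes are
    illustrative only."""

    def __init__(self, max_num_reqs: int, device: str = "cpu") -> None:
        # Worst-case allocation at init (the real builder asks
        # get_mla_metadata_info_v1 for sizes) so the same storage can be
        # reused -- and CUDA-graph captured -- across steps.
        self.work_indptr = torch.zeros(
            max_num_reqs + 1, dtype=torch.int32, device=device
        )

    def refill(self, seq_lens: torch.Tensor) -> torch.Tensor:
        # In-place per-batch refill (the real builder delegates this to
        # get_mla_metadata_v1, which writes into the preallocated tensors).
        n = seq_lens.numel()
        self.work_indptr[0] = 0
        self.work_indptr[1 : n + 1] = torch.cumsum(seq_lens, dim=0)
        return self.work_indptr[: n + 1]


bufs = PersistentMlaBuffers(max_num_reqs=8)
print(bufs.refill(torch.tensor([3, 1, 4])))  # tensor([0, 3, 4, 8], dtype=torch.int32)
```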