Merged

Commits (27)
f782c66
add AITER MLA implementation in attention backend
vllmellm Mar 28, 2025
42d5c62
remove unused arguments in aiter mla decode fwd kernel
vllmellm Mar 28, 2025
565a3fd
add unittest for AITER MLA backend in attention selector
vllmellm Mar 29, 2025
645f400
add unittest for MLA attention backend selector
vllmellm Apr 1, 2025
22c8726
code cleaning
vllmellm Apr 1, 2025
5dc1348
update AITER version
vllmellm Apr 1, 2025
12f8023
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 1, 2025
da8c69f
add ck flash attn in prefill mla computation
vllmellm Apr 2, 2025
1ea5718
further code cleaning
vllmellm Apr 2, 2025
681d777
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 2, 2025
9ada055
fix mypy typing errors
vllmellm Apr 3, 2025
1ceb3b9
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 3, 2025
20a3f07
fix mypy error on Iterable typing error
vllmellm Apr 3, 2025
194a42a
remove padding for v tensor in AITER MLA which improves performance
vllmellm Apr 15, 2025
a9a02d5
upgrade aiter package version
vllmellm Apr 15, 2025
02a4fb3
only support AITER FA in AITER MLA backend to avoid latency caused by…
vllmellm Apr 15, 2025
95213e2
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 15, 2025
6e48433
add missing data types of arguments in aiter_mla_decode_fwd
vllmellm Apr 16, 2025
8c2ed72
NIT
vllmellm Apr 17, 2025
c95cb02
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 21, 2025
25d88d5
support block-size 1 for ROCM AITER MLA
vllmellm Apr 21, 2025
f38c4a9
fix mypy error
vllmellm Apr 21, 2025
0027497
preserve the lines
vllmellm Apr 21, 2025
78007d0
return back calling the tiron fa function to its original format
vllmellm Apr 21, 2025
cb4e861
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 22, 2025
54817a1
fix fstring in error message
vllmellm Apr 22, 2025
8fd039e
Update MLA attention backend selector for rocm attention selector and…
vllmellm Apr 22, 2025
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b"
ARG AITER_BRANCH="5a77249"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
149 changes: 128 additions & 21 deletions tests/kernels/test_attention_selector.py
@@ -19,45 +19,152 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}

DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER"],
"hip": ["ROCM_FLASH"],
"cpu": ["TORCH_SDPA"],
}

DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
"cpu": [16] # CPU uses fixed block size from test cases
}


def generate_params():
params = []
for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]:
backends = DEVICE_MLA_BACKENDS[
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
for name in backends:
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
16
]
for block_size in block_sizes:
params.append(
pytest.param(
device,
name,
use_mla,
block_size,
id=
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
))
return params


@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
device: str,
name: str,
use_mla: bool,
block_size: int,
use_v1: bool,
device: str,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""

"""Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
block_size, False)
assert backend.get_name() == "TORCH_SDPA"

elif device == "hip":
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED
else:
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16,
torch.float16, 16, False)
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == EXPECTED
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))

if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == expected

elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)

# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()

if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = ("TRITON_MLA_VLLM_V1"
if use_v1 else "TRITON_MLA")
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
21 changes: 18 additions & 3 deletions vllm/attention/backends/mla/common.py
@@ -711,12 +711,24 @@ def advance_step(self,
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

self._ops_advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions)

def _ops_advance_step(self, num_seqs: int, num_queries: int,
block_size: int, input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor) -> None:
# here we use advance_step_flashinfo to update the paged_kv_* tensors
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
input_tokens=input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
input_positions=input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
@@ -727,6 +739,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
BLOCK_TABLE_EXTENDER: list[list[int]] = []

def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
@@ -877,8 +890,10 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
Collaborator:

nit: why relocate these lines? Also, can you please explain why we now need self.__class__.BLOCK_TABLE_EXTENDER?

Contributor Author:

self.__class__.BLOCK_TABLE_EXTENDER is a static class variable. common.py previously hardcoded the extender as "[]" in the line below:
self.block_tables.extend([] * cuda_graph_pad_size)
In AiterMLAMetadataBuilder we need "[[]]" instead of "[]" when capturing the graph, so lifting the hardcoded extender into a class variable lets a subclass supply its own extender value or simply inherit the parent's.
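As a minimal, self-contained sketch of that pattern (the builder classes below are simplified stand-ins for illustration, not the actual vLLM classes):

class CommonBuilderSketch:
    # Base-class behaviour: padding adds nothing per padded slot.
    BLOCK_TABLE_EXTENDER: list[list[int]] = []

    def __init__(self) -> None:
        self.block_tables: list[list[int]] = []

    def pad_for_cudagraph(self, cuda_graph_pad_size: int) -> None:
        # The padding value comes from the class variable, so subclasses
        # can change it without overriding this method.
        self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
                                 cuda_graph_pad_size)


class AiterStyleBuilderSketch(CommonBuilderSketch):
    # AITER-style builder pads with one empty block table per slot.
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]


base, aiter = CommonBuilderSketch(), AiterStyleBuilderSketch()
base.pad_for_cudagraph(2)
aiter.pad_for_cudagraph(2)
print(base.block_tables)   # []        -> [] * 2 adds nothing
print(aiter.block_tables)  # [[], []]  -> one empty entry per padded slot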

To review this file, it is better to open the entire file; the GitHub diff view alone does not show clearly what has changed.

Overall, as explained in the PR description, some refactoring has been made in certain functions to accommodate the AITER MLA implementation, reduce code duplication in the subclasses, and allow more flexibility in subclasses.

Implementation

ROCM_AITER_MLA is introduced as an additional attention backend type for the ROCm platform.
To support this backend, the modules below are implemented in vllm/attention/backends/rocm_aiter_mla.py:

  • AiterMLABackend inherits from MLACommonBackend.
  • AiterMLAMetadata inherits from MLACommonMetadata; note that in this class the advance_step function uses the advance_step_flashinfer function from the vLLM custom ops.
  • AiterMLAMetadataBuilder inherits from MLACommonMetadataBuilder.
  • AiterMLAState inherits from MLACommonState.
  • AiterMLAImpl inherits from MLACommonImpl. Important notes for this class:
    • The flash_attn_varlen_func (FA function) used in this class is the AITER FA implementation (flash_attn_varlen_func from the AITER package).
    • The _forward_decode function in this class uses the mla_decode_fwd kernel from the AITER package.

The MLACommon module has been refactored to reduce code duplication in its subclasses. This was achieved by separating the attention output computation into two dedicated functions, _get_fwd_prefill_attn_output and _get_prefill_ctx_attn_output, which are used in _compute_prefill_context and _forward_prefill respectively.
Another refactoring is in the advance_step function: the preliminary assertion checks are separated from the call to the underlying advance-step op, so advance_step can be overridden without code duplication in its subclasses.
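A condensed outline of that class layout; the base classes below are empty stand-ins so the snippet is self-contained, and only the inheritance relationships and override points named above are meaningful:

# Outline sketch of the backend classes described above; the real definitions
# live in vllm/attention/backends/mla/common.py and rocm_aiter_mla.py.
class MLACommonBackend: ...
class MLACommonMetadata: ...
class MLACommonMetadataBuilder:
    BLOCK_TABLE_EXTENDER: list[list[int]] = []
class MLACommonState: ...
class MLACommonImpl: ...


class AiterMLABackend(MLACommonBackend): ...

class AiterMLAMetadata(MLACommonMetadata):
    # advance_step relies on the advance_step_flashinfer custom op
    # (via the _ops_advance_step hook shown in the common.py diff).
    ...

class AiterMLAMetadataBuilder(MLACommonMetadataBuilder):
    # Pads block tables with [[]] rather than [] when capturing the graph.
    BLOCK_TABLE_EXTENDER: list[list[int]] = [[]]

class AiterMLAState(MLACommonState): ...

class AiterMLAImpl(MLACommonImpl):
    # Prefill attention goes through AITER's flash_attn_varlen_func;
    # _forward_decode goes through AITER's mla_decode_fwd kernel.
    ...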

Contributor Author:

@LucasWilkinson After resolving the merge conflict for this file, the only changes in common.py are as below (see the sketch after this list):

  • ops.advance_step_flashattn is now invoked in a separate function, _ops_advance_step, which is called from advance_step and can be overridden by a subclass.

  • A "static" class variable, BLOCK_TABLE_EXTENDER: list[list[int]] = [], is used to update self.block_tables in graph mode. This removes the hardcoded "[]" in self.block_tables.extend([] * cuda_graph_pad_size) and allows subclasses to override this update via the class variable.
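A compact sketch of the first point; the class names carry a "Sketch" suffix and the print calls stand in for the real vLLM custom ops (ops.advance_step_flashattn / ops.advance_step_flashinfer), which are not invoked here:

class MLACommonMetadataSketch:
    def advance_step(self, *args, **kwargs) -> None:
        # Shared assertion checks and seq_lens bookkeeping stay here;
        # only the final op invocation is delegated to the hook below.
        self._ops_advance_step(*args, **kwargs)

    def _ops_advance_step(self, *args, **kwargs) -> None:
        print("base: would call ops.advance_step_flashattn(...)")


class AiterMLAMetadataSketch(MLACommonMetadataSketch):
    def _ops_advance_step(self, *args, **kwargs) -> None:
        print("aiter: would call ops.advance_step_flashinfer(...)")


AiterMLAMetadataSketch().advance_step()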

cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens

block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else: