Merged

Commits (27 in total; the diff below shows changes from 18 of them):
f782c66
add AITER MLA implementation in attention backend
vllmellm Mar 28, 2025
42d5c62
remove unused arguments in aiter mla decode fwd kernel
vllmellm Mar 28, 2025
565a3fd
add unittest for AITER MLA backend in attention selector
vllmellm Mar 29, 2025
645f400
add unittest for MLA attention backend selector
vllmellm Apr 1, 2025
22c8726
code cleaning
vllmellm Apr 1, 2025
5dc1348
update AITER version
vllmellm Apr 1, 2025
12f8023
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 1, 2025
da8c69f
add ck flash attn in prefill mla computation
vllmellm Apr 2, 2025
1ea5718
further code cleaning
vllmellm Apr 2, 2025
681d777
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 2, 2025
9ada055
fix mypy typing errors
vllmellm Apr 3, 2025
1ceb3b9
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 3, 2025
20a3f07
fix mypy error on Iterable typing error
vllmellm Apr 3, 2025
194a42a
remove padding for v tensor in AITER MLA which improves performance
vllmellm Apr 15, 2025
a9a02d5
upgrade aiter package version
vllmellm Apr 15, 2025
02a4fb3
only support AITER FA in AITER MLA backend to avoid latency caused by…
vllmellm Apr 15, 2025
95213e2
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 15, 2025
6e48433
add missing data types of arguments in aiter_mla_decode_fwd
vllmellm Apr 16, 2025
8c2ed72
NIT
vllmellm Apr 17, 2025
c95cb02
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 21, 2025
25d88d5
support block-size 1 for ROCM AITER MLA
vllmellm Apr 21, 2025
f38c4a9
fix mypy error
vllmellm Apr 21, 2025
0027497
preserve the lines
vllmellm Apr 21, 2025
78007d0
revert calling the triton fa function to its original format
vllmellm Apr 21, 2025
cb4e861
Merge remote-tracking branch 'origin/main' into aiter-mla-integration
vllmellm Apr 22, 2025
54817a1
fix fstring in error message
vllmellm Apr 22, 2025
8fd039e
Update MLA attention backend selector for rocm attention selector and…
vllmellm Apr 22, 2025
docker/Dockerfile.rocm_base (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="8970b25b"
+ARG AITER_BRANCH="5a77249"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base
tests/kernels/test_attention_selector.py (149 changes: 128 additions & 21 deletions)
@@ -19,45 +19,152 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()


-@pytest.mark.parametrize(
-    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
+# Define MLA and non-MLA backends separately
+DEVICE_MLA_BACKENDS = {
+    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
+    "cpu": [],
+}
+
+DEVICE_NON_MLA_BACKENDS = {
    [Review thread on this line]
    Collaborator: nit: lets just call this DEVICE_REGULAR_ATTN_BACKENDS instead of MLA
    vllmellm (Contributor, Author), Apr 17, 2025: @LucasWilkinson This has been addressed. Thanks.
"cuda": ["XFORMERS", "FLASHINFER"],
"hip": ["ROCM_FLASH"],
"cpu": ["TORCH_SDPA"],
}

DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
"cpu": [16] # CPU uses fixed block size from test cases
}


def generate_params():
params = []
for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]:
backends = DEVICE_MLA_BACKENDS[
device] if use_mla else DEVICE_NON_MLA_BACKENDS[device]
for name in backends:
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
16
]
for block_size in block_sizes:
params.append(
pytest.param(
device,
name,
use_mla,
block_size,
id=
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
))
return params


@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
device: str,
name: str,
use_mla: bool,
block_size: int,
use_v1: bool,
device: str,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""

"""Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")

if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
block_size, False)
assert backend.get_name() == "TORCH_SDPA"

elif device == "hip":
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED
else:
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16,
torch.float16, 16, False)
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == EXPECTED
+                if use_mla:
+                    # Validate HIP MLA backend-block_size combinations
+                    valid_combination = (
+                        (name == "TRITON_MLA" and block_size != 1)
+                        or (name == "ROCM_AITER_MLA" and block_size == 1))
+
+                    if valid_combination:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        assert backend.get_name() == name
+                    else:
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16,
+                                             torch.float16,
+                                             torch.float16,
+                                             block_size,
+                                             False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
+                    assert backend.get_name() == expected
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                if use_mla:
+                    if name == "FLASHMLA" and block_size == 64:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+
+                        # only on cuda platforms with specific capability.
+                        is_supported, _ = is_flashmla_supported()
+
+                        if not is_supported:
+                            # if platform is not supported then skip this case.
+                            pytest.skip()
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = ("TRITON_MLA_VLLM_V1"
+                                    if use_v1 else "TRITON_MLA")
+                        assert backend.get_name() == expected
+                else:
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected


 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
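For reference, below is a minimal sketch (not part of the PR) of the hip / ROCM_AITER_MLA / block_size=1 case that test_env exercises above. It reuses the get_attn_backend call shape shown in the diff (head_size, dtype, kv_cache_dtype, block_size, is_attention_free, use_mla=...); the import paths and the "VLLM_ATTENTION_BACKEND" variable name are assumptions based on vLLM's layout at the time, and on a machine without a ROCm/AITER build the backend import may fail, which is why the test patches current_platform rather than relying on real hardware.

import os
from unittest.mock import patch

import torch

from vllm.attention.selector import (_cached_get_attn_backend,
                                     get_attn_backend)
from vllm.platforms.rocm import RocmPlatform

# Force the backend by name, as test_env does via STR_BACKEND_ENV_VAR
# (assumed here to resolve to "VLLM_ATTENTION_BACKEND").
os.environ["VLLM_USE_V1"] = "0"
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"

# The selector caches its result per process, so clear it before re-selecting.
_cached_get_attn_backend.cache_clear()

with patch("vllm.attention.selector.current_platform", RocmPlatform()):
    # Positional args mirror the test: head_size, dtype, kv_cache_dtype,
    # block_size, is_attention_free. ROCM_AITER_MLA only accepts block_size=1.
    backend = get_attn_backend(16, torch.float16, torch.float16, 1, False,
                               use_mla=True)
    assert backend.get_name() == "ROCM_AITER_MLA"

    # Any other block size takes the ValueError path asserted in test_env.
    try:
        get_attn_backend(16, torch.float16, torch.float16, 16, False,
                         use_mla=True)
    except ValueError as exc:
        print(f"rejected as expected: {exc}")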