Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions tests/kernels/attention/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,19 @@
)
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform

# CudaPlatform and RocmPlatform import their respective compiled C extensions
# at module level, raising ModuleNotFoundError on incompatible builds.
try:
from vllm.platforms.cuda import CudaPlatform
except (ImportError, ModuleNotFoundError):
CudaPlatform = None

try:
from vllm.platforms.rocm import RocmPlatform
except (ImportError, ModuleNotFoundError):
RocmPlatform = None

from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend

Expand Down Expand Up @@ -101,6 +112,8 @@ def test_backend_selection(
assert backend.get_name() == "CPU_ATTN"

elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
with patch("vllm.platforms.current_platform", RocmPlatform()):
if use_mla:
# ROCm MLA backend logic:
Expand All @@ -126,6 +139,8 @@ def test_backend_selection(
assert backend.get_name() == expected

elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()):
capability = torch.cuda.get_device_capability()
if use_mla:
Expand Down Expand Up @@ -214,7 +229,7 @@ def test_backend_selection(
assert backend.get_name() == expected


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("device", ["cpu", "cuda", "hip"])
def test_fp32_fallback(device: str):
"""Test attention backend selection with fp32."""
# Use default config (no backend specified)
Expand All @@ -227,10 +242,25 @@ def test_fp32_fallback(device: str):
assert backend.get_name() == "CPU_ATTN"

elif device == "cuda":
if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "FLEX_ATTENTION"

elif device == "hip":
if RocmPlatform is None:
pytest.skip("RocmPlatform not available")
# ROCm backends do not support head_size=16 (minimum is 32).
# No known HuggingFace transformer model uses head_size=16.
# Revisit if a real model with this head size is identified
# and accuracy-tested.
with (
patch("vllm.platforms.current_platform", RocmPlatform()),
pytest.raises(ValueError, match="No valid attention backend"),
):
get_attn_backend(16, torch.float32, None)


def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
"""Test FlashAttn validation."""
Expand Down Expand Up @@ -367,6 +397,8 @@ def test_per_head_quant_scales_backend_selection(
attention_config=attention_config, cache_config=cache_config
)

if CudaPlatform is None:
pytest.skip("CudaPlatform not available")
with (
set_current_vllm_config(vllm_config),
patch("vllm.platforms.current_platform", CudaPlatform()),
Expand Down
Loading