diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
new file mode 100644
index 000000000000..4ec79e9eb6ba
--- /dev/null
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -0,0 +1,337 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for attention backend selectors."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from vllm.attention.backends.registry import AttentionBackendEnum
+from vllm.platforms import current_platform
+
+# ROCm-specific attention backend selection tests
+pytestmark = pytest.mark.skipif(
+    not current_platform.is_rocm(), reason="ROCm-specific tests"
+)
+
+
+@pytest.fixture
+def mock_vllm_config():
+    """Create a mock VllmConfig for testing."""
+    config = MagicMock()
+    config.model_config.dtype = torch.float16
+    config.model_config.hf_config.architectures = ["LlamaForCausalLM"]
+    config.cache_config.block_size = 16
+    return config
+
+
+@pytest.fixture
+def mock_on_gfx9():
+    """Mock the on_gfx9 function to return True."""
+    with patch("vllm.platforms.rocm.on_gfx9", return_value=True):
+        yield
+
+
+@pytest.mark.parametrize(
+    "env_vars, selected_backend, expected_backend_path",
+    [
+        # Test Case 1: Default (no env vars, no explicit backend)
+        (
+            {},
+            None,
+            AttentionBackendEnum.TRITON_ATTN.get_path(),
+        ),
+        # Test Case 2: Explicit TRITON_ATTN backend
+        (
+            {},
+            "TRITON_ATTN",
+            AttentionBackendEnum.TRITON_ATTN.get_path(),
+        ),
+        # Test Case 3: Explicit ROCM_ATTN backend
+        (
+            {},
+            "ROCM_ATTN",
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
+        ),
+        # Test Case 4: Explicit ROCM_AITER_FA backend
+        (
+            {},
+            "ROCM_AITER_FA",
+            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+        ),
+        # Test Case 5: Explicit ROCM_AITER_UNIFIED_ATTN backend
+        (
+            {},
+            "ROCM_AITER_UNIFIED_ATTN",
+            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+        ),
+        # Test Case 6: VLLM_ROCM_USE_AITER=1
+        # (defaults to AITER FA when MHA not explicitly disabled)
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            None,
+            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+        ),
+        # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
+        (
+            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
+            None,
+            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+        ),
+        # Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
+        (
+            {
+                "VLLM_ROCM_USE_AITER": "1",
+                "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
+            },
+            None,
+            AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
+        ),
+        # Test Case 9: VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1
+        (
+            {"VLLM_V1_USE_PREFILL_DECODE_ATTENTION": "1"},
+            None,
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
+        ),
+        # Test Case 10: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            "TRITON_ATTN",
+            AttentionBackendEnum.TRITON_ATTN.get_path(),
+        ),
+        # Test Case 11: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
+        # (explicitly disabled)
+        (
+            {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
+            None,
+            AttentionBackendEnum.TRITON_ATTN.get_path(),
+        ),
+        # Test Case 12: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            "ROCM_ATTN",
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
+        ),
+    ],
+)
+def test_standard_attention_backend_selection(
+    env_vars,
+    selected_backend,
+    expected_backend_path,
+    mock_vllm_config,
+    mock_on_gfx9,
+    monkeypatch,
+):
+    """Test standard attention backend selection with various configurations."""
+    # Set environment variables
+    for key, value in env_vars.items():
+        monkeypatch.setenv(key, value)
+
+    # Import after setting env vars to ensure they're picked up
+    # Reload envs to pick up new environment variables
+    import importlib
+
+    import vllm.envs as envs
+
+    importlib.reload(envs)
+
+    # Convert string backend to enum if provided
+    backend_enum = None
+    if selected_backend:
+        backend_enum = getattr(AttentionBackendEnum, selected_backend)
+
+    # Get the backend class path
+    from vllm.platforms.rocm import RocmPlatform
+
+    backend_path = RocmPlatform.get_attn_backend_cls(
+        selected_backend=backend_enum,
+        head_size=128,
+        dtype=torch.float16,
+        kv_cache_dtype="auto",
+        block_size=16,
+        use_mla=False,
+        has_sink=False,
+        use_sparse=False,
+    )
+    assert backend_path == expected_backend_path
+
+
+@pytest.mark.parametrize(
+    "env_vars, selected_backend, block_size, expected_backend_path, should_raise",
+    [
+        # Test Case 1: TRITON_MLA with block_size != 1
+        (
+            {},
+            "TRITON_MLA",
+            16,
+            AttentionBackendEnum.TRITON_MLA.get_path(),
+            False,
+        ),
+        # Test Case 2: TRITON_MLA with block_size == 1 (should raise)
+        (
+            {},
+            "TRITON_MLA",
+            1,
+            None,
+            True,
+        ),
+        # Test Case 3: ROCM_AITER_MLA with block_size == 1
+        (
+            {},
+            "ROCM_AITER_MLA",
+            1,
+            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+            False,
+        ),
+        # Test Case 4: ROCM_AITER_MLA with block_size != 1 (now supported)
+        (
+            {},
+            "ROCM_AITER_MLA",
+            16,
+            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+            False,
+        ),
+        # Test Case 5: VLLM_ROCM_USE_AITER=1 with block_size == 1
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            None,
+            1,
+            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+            False,
+        ),
+        # Test Case 6: VLLM_ROCM_USE_AITER=1 with block_size == 16
+        # (should use ROCM_AITER_MLA now, as it supports block_size 16)
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            None,
+            16,
+            AttentionBackendEnum.ROCM_AITER_MLA.get_path(),
+            False,
+        ),
+        # Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_MLA
+        (
+            {"VLLM_ROCM_USE_AITER": "1"},
+            "TRITON_MLA",
+            16,
+            AttentionBackendEnum.TRITON_MLA.get_path(),
+            False,
+        ),
+        # Test Case 8: Explicit ROCM_AITER_TRITON_MLA
+        (
+            {},
+            "ROCM_AITER_TRITON_MLA",
+            16,
+            AttentionBackendEnum.ROCM_AITER_TRITON_MLA.get_path(),
+            False,
+        ),
+    ],
+)
+def test_mla_backend_selection(
+    env_vars,
+    selected_backend,
+    block_size,
+    expected_backend_path,
+    should_raise,
+    mock_vllm_config,
+    monkeypatch,
+):
+    """Test MLA backend selection with various configurations."""
+    # Set environment variables
+    for key, value in env_vars.items():
+        monkeypatch.setenv(key, value)
+
+    # Import after setting env vars
+    # Reload envs
+    import importlib
+
+    import vllm.envs as envs
+
+    importlib.reload(envs)
+
+    # Mock rocm_aiter_ops.is_mla_enabled based on the env vars
+    aiter_enabled = env_vars.get("VLLM_ROCM_USE_AITER") == "1"
+
+    mock_rocm_ops = MagicMock()
+    mock_rocm_ops.is_mla_enabled.return_value = aiter_enabled
+    mock_aiter_module = MagicMock()
+    mock_aiter_module.rocm_aiter_ops = mock_rocm_ops
+
+    with patch.dict("sys.modules", {"vllm._aiter_ops": mock_aiter_module}):
+        # Convert string backend to enum if provided
+        backend_enum = None
+        if selected_backend:
+            backend_enum = getattr(AttentionBackendEnum, selected_backend)
+
+        from vllm.platforms.rocm import RocmPlatform
+
+        if should_raise:
+            with pytest.raises(ValueError):
+                RocmPlatform.get_attn_backend_cls(
+                    selected_backend=backend_enum,
+                    head_size=128,
+                    dtype=torch.float16,
+                    kv_cache_dtype="auto",
+                    block_size=block_size,
+                    use_mla=True,
+                    has_sink=False,
+                    use_sparse=False,
+                )
+        else:
+            backend_path = RocmPlatform.get_attn_backend_cls(
+                selected_backend=backend_enum,
+                head_size=128,
+                dtype=torch.float16,
+                kv_cache_dtype="auto",
+                block_size=block_size,
+                use_mla=True,
+                has_sink=False,
+                use_sparse=False,
+            )
+            assert backend_path == expected_backend_path
+
+
+def test_aiter_fa_requires_gfx9(mock_vllm_config):
+    """Test that ROCM_AITER_FA requires gfx9 architecture."""
+    from vllm.platforms.rocm import RocmPlatform
+
+    # Mock on_gfx9 to return False
+    with (
+        patch("vllm.platforms.rocm.on_gfx9", return_value=False),
+        pytest.raises(
+            ValueError,
+            match="only supported on gfx9",
+        ),
+    ):
+        RocmPlatform.get_attn_backend_cls(
+            selected_backend=AttentionBackendEnum.ROCM_AITER_FA,
+            head_size=128,
+            dtype=torch.float16,
+            kv_cache_dtype="auto",
+            block_size=16,
+            use_mla=False,
+            has_sink=False,
+            use_sparse=False,
+        )
+
+
+def test_sparse_not_supported(mock_vllm_config):
+    """Test that sparse attention on ROCm requires block size 1."""
+    from vllm.platforms.rocm import RocmPlatform
+
+    with pytest.raises(
+        AssertionError, match="Sparse MLA backend on ROCm only supports block size 1"
+    ):
+        RocmPlatform.get_attn_backend_cls(
+            selected_backend=None,
+            head_size=128,
+            dtype=torch.float16,
+            kv_cache_dtype="auto",
+            block_size=16,
+            use_mla=False,
+            has_sink=False,
+            use_sparse=True,
+        )
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index f9005fd7d044..f3ec965bd088
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -262,30 +262,64 @@ def get_attn_backend_cls(
                 f"is not MLA type while requested for MLA backend."
             )
-        if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
-            logger.info("Using FlexAttention backend.")
-            return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
-        if (
-            rocm_aiter_ops.is_mha_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
-            logger.info("Using Aiter Flash Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_FA.get_path()
-        if (
-            rocm_aiter_ops.is_triton_unified_attn_enabled()
-        ) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
-            logger.info("Using Aiter Unified Attention backend.")
-            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
-        if (
-            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
-            or selected_backend == AttentionBackendEnum.ROCM_ATTN
-        ):
-            # rocm specific backend, with aiter and/or
-            # triton prefix-prefill
-            logger.info("Using Rocm Attention backend.")
+        if selected_backend == AttentionBackendEnum.TRITON_ATTN:
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        if selected_backend == AttentionBackendEnum.ROCM_ATTN:
+            logger.info("Using Rocm Attention backend on V1 engine.")
             return AttentionBackendEnum.ROCM_ATTN.get_path()
-        # default case, using triton unified attention
-        logger.info("Using Triton Attention backend.")
-        return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
+            if on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+            else:
+                raise ValueError(
+                    f"The selected backend, {selected_backend.name}, "
+                    "is only supported on gfx9 architectures."
+                )
+
+        if selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
+            logger.info("Using Aiter Unified Attention backend on V1 engine.")
+            return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+        # Handle automatic backend selection based on environment variables
+        if selected_backend is None:
+            # Priority 1: Check for AITER Unified Attention (must check before MHA)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
+                logger.info("Using Aiter Unified Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
+
+            # Priority 2: Check for AITER MHA (Flash Attention)
+            # Only use if explicitly enabled (not just VLLM_ROCM_USE_AITER=1)
+            if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+            # Priority 3: Check for ROCM_ATTN (prefill-decode split)
+            if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
+                logger.info("Using Rocm Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_ATTN.get_path()
+
+            # Priority 4: Check for AITER enabled without specific flags
+            # This defaults to AITER FA only if MHA is not explicitly disabled
+            if (
+                envs.VLLM_ROCM_USE_AITER
+                and on_gfx9()
+                and envs.VLLM_ROCM_USE_AITER_MHA is not False
+            ):
+                logger.info("Using Aiter Flash Attention backend on V1 engine.")
+                return AttentionBackendEnum.ROCM_AITER_FA.get_path()
+
+            # Default: Triton Unified Attention
+            logger.info("Using Triton Attention backend on V1 engine.")
+            return AttentionBackendEnum.TRITON_ATTN.get_path()
+
+        raise RuntimeError(
+            "V0 attention backends have been removed. Set VLLM_USE_V1=1 "
+            "to select a supported backend."
+        )
 
     @classmethod
     def set_device(cls, device: torch.device) -> None:
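
As a quick illustration of the selection order introduced above, the resolver can also be exercised by hand outside pytest. This is a minimal sketch, not part of the change itself; it assumes a ROCm build of vLLM and uses the same get_attn_backend_cls signature exercised by the tests:

    # Sketch: reproduce the default case (no env vars, no explicit backend) by hand.
    import torch

    from vllm.platforms.rocm import RocmPlatform

    backend_path = RocmPlatform.get_attn_backend_cls(
        selected_backend=None,  # let the platform pick a backend
        head_size=128,
        dtype=torch.float16,
        kv_cache_dtype="auto",
        block_size=16,
        use_mla=False,
        has_sink=False,
        use_sparse=False,
    )
    # With none of VLLM_ROCM_USE_AITER / VLLM_V1_USE_PREFILL_DECODE_ATTENTION set,
    # this is expected to resolve to the TRITON_ATTN backend path.
    print(backend_path)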