diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index e7170babb6c9..4d6515328011 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -213,4 +213,4 @@ configuration. | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py index a31c053aed21..3badf3ace9a3 100644 --- a/tests/v1/attention/test_rocm_attention_backends_selection.py +++ b/tests/v1/attention/test_rocm_attention_backends_selection.py @@ -29,11 +29,18 @@ def mock_vllm_config(): @pytest.fixture def mock_on_gfx9(): - """Mock the on_gfx9 function to return True.""" + """Mock gfx9 arch detection to return True.""" with patch("vllm.platforms.rocm.on_gfx9", return_value=True): yield +@pytest.fixture +def mock_on_mi3xx(): + """Mock mi3xx arch detection to return True.""" + with patch("vllm.platforms.rocm.on_mi3xx", return_value=True): + yield + + @pytest.mark.parametrize( "env_vars, selected_backend, expected_backend_path", [ @@ -122,6 +129,7 @@ def test_standard_attention_backend_selection( expected_backend_path, mock_vllm_config, mock_on_gfx9, + mock_on_mi3xx, monkeypatch, ): """Test standard attention backend selection with various configurations.""" @@ -313,16 +321,16 @@ def test_mla_backend_selection( assert backend_path == expected_backend_path -def test_aiter_fa_requires_gfx9(mock_vllm_config): - """Test that ROCM_AITER_FA requires gfx9 architecture.""" +def test_aiter_fa_requires_mi3xx(mock_vllm_config): + """Test that ROCM_AITER_FA requires mi3xx architecture.""" from vllm.platforms.rocm import RocmPlatform - # Mock on_gfx9 to return False + # Mock on_mi3xx to return False (used by supports_compute_capability) with ( - patch("vllm.platforms.rocm.on_gfx9", return_value=False), + patch("vllm.platforms.rocm.on_mi3xx", return_value=False), pytest.raises( ValueError, - match="only supported on gfx9", + match="compute capability not supported", ), ): attn_selector_config = AttentionSelectorConfig( @@ -342,11 +350,12 @@ def test_aiter_fa_requires_gfx9(mock_vllm_config): def test_sparse_not_supported(mock_vllm_config): - """Test that sparse attention is not supported on ROCm.""" + """Test that sparse MLA without use_mla flag raises an error.""" from vllm.platforms.rocm import RocmPlatform with pytest.raises( - AssertionError, match="Sparse MLA backend on ROCm only supports block size 1" + ValueError, + match="No valid attention backend found", ): attn_selector_config = AttentionSelectorConfig( head_size=128, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 6dbdd7dcbd51..7b465db446ab 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -31,6 +31,10 @@ class AiterMLABackend(MLACommonBackend): "fp8_e5m2", ] + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + @staticmethod def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [1] diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index f6c1790f60c8..2da2bbd6bb5a 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -19,6 +19,7 @@ from vllm.v1.attention.backend import ( AttentionLayer, AttentionType, + MultipleOf, is_quantized_kv_cache, ) from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd @@ -33,6 +34,20 @@ class TritonMLABackend(MLACommonBackend): "bfloat16", ] + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + @staticmethod + def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: + return [MultipleOf(16)] + + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + if block_size is None: + return True + return block_size % 16 == 0 + @staticmethod def get_name() -> str: return "TRITON_MLA" diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index dbfb924a87b8..bba7e7b97087 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [MultipleOf(16)] + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + if block_size is None: + return True + return block_size % 16 == 0 + @classmethod def supports_head_size(cls, head_size: int) -> bool: return head_size >= 32 diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index e8d34822eb56..96c4033d8adc 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -188,6 +188,12 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: # uses our optimized kernel logic. return [16, 32, 544] + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + if block_size is None: + return True + return block_size in (16, 32, 544) + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 80, 96, 128, 160, 192, 224, 256] diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 953d7b3c45dd..e3734b3a2644 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend): def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: return [MultipleOf(16)] + @classmethod + def supports_block_size(cls, block_size: int | None) -> bool: + if block_size is None: + return True + return block_size % 16 == 0 + forward_includes_kv_cache_update: bool = False @staticmethod