diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py
index a31c053aed21..7baeb70e18cd 100644
--- a/tests/v1/attention/test_rocm_attention_backends_selection.py
+++ b/tests/v1/attention/test_rocm_attention_backends_selection.py
@@ -47,7 +47,7 @@ def mock_on_gfx9():
         (
             {},
             None,
-            AttentionBackendEnum.TRITON_ATTN.get_path(),
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
         ),
         # Test Case 2: Explicit TRITON_ATTN backend
         (
@@ -74,11 +74,10 @@ def mock_on_gfx9():
             AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
         ),
         # Test Case 6: VLLM_ROCM_USE_AITER=1
-        # (defaults to AITER FA when MHA not explicitly disabled)
         (
             {"VLLM_ROCM_USE_AITER": "1"},
             None,
-            AttentionBackendEnum.ROCM_AITER_FA.get_path(),
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
         ),
         # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
         (
@@ -102,11 +101,10 @@ def mock_on_gfx9():
             AttentionBackendEnum.TRITON_ATTN.get_path(),
         ),
         # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
-        # (explicitly disabled)
         (
             {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
             None,
-            AttentionBackendEnum.TRITON_ATTN.get_path(),
+            AttentionBackendEnum.ROCM_ATTN.get_path(),
         ),
         # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
         (
diff --git a/vllm/envs.py b/vllm/envs.py
index ad220a979d44..e44ab2967df2 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -112,7 +112,6 @@
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_DISABLE_PYNCCL: bool = False
     VLLM_ROCM_USE_AITER: bool = False
-    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -281,6 +280,14 @@ def maybe_convert_bool(value: str | None) -> bool | None:
     return bool(int(value))


+def use_aiter() -> bool:
+    from vllm._aiter_ops import is_aiter_found_and_supported
+
+    return is_aiter_found_and_supported() and os.getenv(
+        "VLLM_ROCM_USE_AITER", "True"
+    ).lower() in ("true", "1")
+
+
 def disable_compile_cache() -> bool:
     return bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")))

@@ -951,14 +958,7 @@ def _get_or_set_default() -> str:
     ),
-    # Disable aiter ops unless specifically enabled.
-    # Acts as a parent switch to enable the rest of the other operations.
-    "VLLM_ROCM_USE_AITER": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
-    ),
-    # Whether to use aiter paged attention.
-    # By default is disabled.
-    "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")
-    ),
+    # Enable aiter ops by default when aiter is found and supported.
+    # Acts as a parent switch for the rest of the aiter operations.
+    "VLLM_ROCM_USE_AITER": use_aiter,
     # use aiter linear op if aiter ops are enabled
     # The following list of related ops
     # - scaled_mm (per-tensor / rowwise)
@@ -980,9 +980,9 @@ def _get_or_set_default() -> str:
         os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1")
     ),
     # Whether to use aiter mha ops.
-    # By default is enabled.
+    # By default is disabled.
     "VLLM_ROCM_USE_AITER_MHA": lambda: (
-        os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")
+        os.getenv("VLLM_ROCM_USE_AITER_MHA", "False").lower() in ("true", "1")
     ),
     # Whether to use aiter fp4 gemm asm.
     # By default is disabled.
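The net effect of the `vllm/envs.py` change above: `VLLM_ROCM_USE_AITER` turns from a plain opt-in flag into an opt-out that defaults to on whenever AITER is actually usable. Below is a minimal runnable sketch of that resolution logic, assuming a hypothetical `aiter_available()` in place of the real `is_aiter_found_and_supported()` check:

```python
import os


def aiter_available() -> bool:
    # Hypothetical stand-in for vllm._aiter_ops.is_aiter_found_and_supported():
    # True only when the aiter package imports and the GPU arch is supported.
    return True


def use_aiter() -> bool:
    # Mirrors the new resolution: AITER is on by default when available,
    # and VLLM_ROCM_USE_AITER now acts as an opt-out rather than an opt-in.
    return aiter_available() and os.getenv(
        "VLLM_ROCM_USE_AITER", "True"
    ).lower() in ("true", "1")


# With aiter installed and supported:
#   (unset)               -> True  (new default)
#   VLLM_ROCM_USE_AITER=0 -> False (explicit opt-out)
# Without aiter, use_aiter() is False even if the variable is set to 1.
```

Resolving the default through a callable keeps the env var authoritative for users while guaranteeing the flag never reads True on machines where AITER is missing or unsupported.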
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 49225fc2e19e..efba40a6a3be 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -321,6 +321,11 @@ def get_attn_backend_cls(
                 vllm_config is not None
                 and vllm_config.attention_config.use_prefill_decode_attention
             ):
+                logger.warning_once(
+                    "use_prefill_decode_attention is deprecated and will be "
+                    "removed in a future release. "
+                    "Use --attention_config.backend to select the desired backend."
+                )
                 logger.info("Using Rocm Attention backend.")
                 return AttentionBackendEnum.ROCM_ATTN.get_path()

@@ -334,9 +339,9 @@ def get_attn_backend_cls(
                 logger.info("Using Aiter Flash Attention backend.")
                 return AttentionBackendEnum.ROCM_AITER_FA.get_path()

-            # Default: Triton Unified Attention
-            logger.info("Using Triton Attention backend.")
-            return AttentionBackendEnum.TRITON_ATTN.get_path()
+            # Default: ROCm split Attention
+            logger.info("Using ROCm Attention backend.")
+            return AttentionBackendEnum.ROCM_ATTN.get_path()

         raise RuntimeError(
             f"Attention backend {selected_backend.name} is not supported on "
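Combined with the `VLLM_ROCM_USE_AITER_MHA` default flip in `vllm/envs.py`, the implicit-selection path on gfx9 now bottoms out at `ROCM_ATTN`. A condensed, hypothetical sketch of the new fallthrough (the boolean parameters stand in for the real env and config checks in `RocmPlatform.get_attn_backend_cls`; Test Case 7's expectation is inferred from the AITER FA branch this hunk keeps):

```python
from enum import Enum


class Backend(Enum):
    # Trimmed stand-in for vllm's AttentionBackendEnum; members as needed here.
    ROCM_ATTN = "rocm_attn"
    ROCM_AITER_FA = "rocm_aiter_fa"
    ROCM_AITER_UNIFIED_ATTN = "rocm_aiter_unified_attn"


def default_backend(
    use_aiter: bool, use_aiter_unified: bool, use_aiter_mha: bool
) -> Backend:
    # Condensed fallthrough when no backend is selected explicitly:
    # AITER backends now require an explicit opt-in (MHA defaults to off),
    # and the terminal default is ROCM_ATTN instead of TRITON_ATTN.
    if use_aiter and use_aiter_unified:
        return Backend.ROCM_AITER_UNIFIED_ATTN
    if use_aiter and use_aiter_mha:
        return Backend.ROCM_AITER_FA
    return Backend.ROCM_ATTN


assert default_backend(False, False, False) is Backend.ROCM_ATTN    # Test Case 1
assert default_backend(True, False, False) is Backend.ROCM_ATTN     # Test Case 6
assert default_backend(True, False, True) is Backend.ROCM_AITER_FA  # Test Case 7
```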