From ecbb5fd888608f18b0683793fe0f0c1d46d7c93e Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Wed, 28 Jan 2026 16:03:42 +0000 Subject: [PATCH 1/4] Change default settings for ROCm. Enable AITER where it is supported. Disable AITER MHA. Switch the default attention backend to ROCM_ATTN Signed-off-by: Gregory Shtrasberg --- vllm/envs.py | 14 ++++++++++---- vllm/platforms/rocm.py | 6 +++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index ad220a979d44..cdeb200ecfc4 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -281,6 +281,14 @@ def maybe_convert_bool(value: str | None) -> bool | None: return bool(int(value)) +def use_aiter() -> bool: + from vllm._aiter_ops import is_aiter_found_and_supported + + return is_aiter_found_and_supported() and os.getenv( + "VLLM_ROCM_USE_AITER", "True" + ).lower() in ("true", "1") + + def disable_compile_cache() -> bool: return bool(int(os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0"))) @@ -951,9 +959,7 @@ def _get_or_set_default() -> str: ), # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. - "VLLM_ROCM_USE_AITER": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1") - ), + "VLLM_ROCM_USE_AITER": use_aiter, # Whether to use aiter paged attention. # By default is disabled. "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: ( @@ -982,7 +988,7 @@ def _get_or_set_default() -> str: # Whether to use aiter mha ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MHA": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_MHA", "False").lower() in ("true", "1") ), # Whether to use aiter fp4 gemm asm. # By default is disabled. diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 49225fc2e19e..6a8eea435de6 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -334,9 +334,9 @@ def get_attn_backend_cls( logger.info("Using Aiter Flash Attention backend.") return AttentionBackendEnum.ROCM_AITER_FA.get_path() - # Default: Triton Unified Attention - logger.info("Using Triton Attention backend.") - return AttentionBackendEnum.TRITON_ATTN.get_path() + # Default: ROCm split Attention + logger.info("Using ROCm Attention backend.") + return AttentionBackendEnum.ROCM_ATTN.get_path() raise RuntimeError( f"Attention backend {selected_backend.name} is not supported on " From 61b4581629184224994c955e6394c73b7fd3f453 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 29 Jan 2026 21:52:02 +0000 Subject: [PATCH 2/4] Cleanup unused fields and variables Signed-off-by: Gregory Shtrasberg --- tests/engine/test_arg_utils.py | 5 ----- tests/v1/attention/utils.py | 1 - vllm/config/attention.py | 4 ---- vllm/envs.py | 8 +------- vllm/platforms/rocm.py | 13 +------------ 5 files changed, 2 insertions(+), 29 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 2acb38bc9a18..dd6f8499ceac 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -330,8 +330,6 @@ def test_attention_config(): "FLASH_ATTN", "--attention-config.flash_attn_version", "3", - "--attention-config.use_prefill_decode_attention", - "true", "--attention-config.flash_attn_max_num_splits_for_cuda_graph", "16", "--attention-config.use_cudnn_prefill", @@ -351,7 +349,6 @@ def test_attention_config(): assert engine_args.attention_config.backend is not None assert engine_args.attention_config.backend.name == "FLASH_ATTN" assert engine_args.attention_config.flash_attn_version == 3 - assert engine_args.attention_config.use_prefill_decode_attention is True assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16 assert engine_args.attention_config.use_cudnn_prefill is True assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True @@ -364,7 +361,6 @@ def test_attention_config(): [ "--attention-config=" '{"backend": "FLASHINFER", "flash_attn_version": 2, ' - '"use_prefill_decode_attention": false, ' '"flash_attn_max_num_splits_for_cuda_graph": 8, ' '"use_cudnn_prefill": false, ' '"use_trtllm_ragged_deepseek_prefill": false, ' @@ -378,7 +374,6 @@ def test_attention_config(): assert engine_args.attention_config.backend is not None assert engine_args.attention_config.backend.name == "FLASHINFER" assert engine_args.attention_config.flash_attn_version == 2 - assert engine_args.attention_config.use_prefill_decode_attention is False assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8 assert engine_args.attention_config.use_cudnn_prefill is False assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 3cff52929146..9ba04d2bb308 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -352,7 +352,6 @@ class BackendConfig: name="RocmAttn", attention_config={ "backend": "ROCM_ATTN", - "use_prefill_decode_attention": True, }, comp_config={ "cudagraph_mode": "FULL", diff --git a/vllm/config/attention.py b/vllm/config/attention.py index ee072fb1c86d..3a1ea4873e09 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -22,10 +22,6 @@ class AttentionConfig: """Force vllm to use a specific flash-attention version (2 or 3). Only valid when using the flash-attention backend.""" - use_prefill_decode_attention: bool = False - """Use separate prefill and decode kernels for attention instead of - the unified triton kernel.""" - flash_attn_max_num_splits_for_cuda_graph: int = 32 """Flash Attention max number splits for cuda graph decode.""" diff --git a/vllm/envs.py b/vllm/envs.py index cdeb200ecfc4..e44ab2967df2 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -112,7 +112,6 @@ VLLM_DISABLED_KERNELS: list[str] = [] VLLM_DISABLE_PYNCCL: bool = False VLLM_ROCM_USE_AITER: bool = False - VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True @@ -960,11 +959,6 @@ def _get_or_set_default() -> str: # Disable aiter ops unless specifically enabled. # Acts as a parent switch to enable the rest of the other operations. "VLLM_ROCM_USE_AITER": use_aiter, - # Whether to use aiter paged attention. - # By default is disabled. - "VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1") - ), # use aiter linear op if aiter ops are enabled # The following list of related ops # - scaled_mm (per-tensor / rowwise) @@ -986,7 +980,7 @@ def _get_or_set_default() -> str: os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1") ), # Whether to use aiter mha ops. - # By default is enabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_MHA": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_MHA", "False").lower() in ("true", "1") ), diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 6a8eea435de6..1e12fc7ea6fd 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -313,18 +313,7 @@ def get_attn_backend_cls( logger.info("Using Aiter Flash Attention backend.") return AttentionBackendEnum.ROCM_AITER_FA.get_path() - # Priority 3: Check for ROCM_ATTN (prefill-decode split) - from vllm.config import get_current_vllm_config_or_none - - vllm_config = get_current_vllm_config_or_none() - if ( - vllm_config is not None - and vllm_config.attention_config.use_prefill_decode_attention - ): - logger.info("Using Rocm Attention backend.") - return AttentionBackendEnum.ROCM_ATTN.get_path() - - # Priority 4: Check for AITER enabled without specific flags + # Priority 3: Check for AITER enabled without specific flags # This defaults to AITER FA only if MHA is not explicitly disabled if ( envs.VLLM_ROCM_USE_AITER From 65f84dcb33c06a1f74171255b33b6f049bc48b52 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Fri, 30 Jan 2026 19:29:18 +0000 Subject: [PATCH 3/4] Return use_prefill_decode_attention and add a deprecation warning for it Signed-off-by: Gregory Shtrasberg --- tests/engine/test_arg_utils.py | 5 +++++ tests/v1/attention/utils.py | 1 + vllm/config/attention.py | 4 ++++ vllm/platforms/rocm.py | 18 +++++++++++++++++- 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index dd6f8499ceac..2acb38bc9a18 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -330,6 +330,8 @@ def test_attention_config(): "FLASH_ATTN", "--attention-config.flash_attn_version", "3", + "--attention-config.use_prefill_decode_attention", + "true", "--attention-config.flash_attn_max_num_splits_for_cuda_graph", "16", "--attention-config.use_cudnn_prefill", @@ -349,6 +351,7 @@ def test_attention_config(): assert engine_args.attention_config.backend is not None assert engine_args.attention_config.backend.name == "FLASH_ATTN" assert engine_args.attention_config.flash_attn_version == 3 + assert engine_args.attention_config.use_prefill_decode_attention is True assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 16 assert engine_args.attention_config.use_cudnn_prefill is True assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is True @@ -361,6 +364,7 @@ def test_attention_config(): [ "--attention-config=" '{"backend": "FLASHINFER", "flash_attn_version": 2, ' + '"use_prefill_decode_attention": false, ' '"flash_attn_max_num_splits_for_cuda_graph": 8, ' '"use_cudnn_prefill": false, ' '"use_trtllm_ragged_deepseek_prefill": false, ' @@ -374,6 +378,7 @@ def test_attention_config(): assert engine_args.attention_config.backend is not None assert engine_args.attention_config.backend.name == "FLASHINFER" assert engine_args.attention_config.flash_attn_version == 2 + assert engine_args.attention_config.use_prefill_decode_attention is False assert engine_args.attention_config.flash_attn_max_num_splits_for_cuda_graph == 8 assert engine_args.attention_config.use_cudnn_prefill is False assert engine_args.attention_config.use_trtllm_ragged_deepseek_prefill is False diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 9ba04d2bb308..3cff52929146 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -352,6 +352,7 @@ class BackendConfig: name="RocmAttn", attention_config={ "backend": "ROCM_ATTN", + "use_prefill_decode_attention": True, }, comp_config={ "cudagraph_mode": "FULL", diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 3a1ea4873e09..ee072fb1c86d 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -22,6 +22,10 @@ class AttentionConfig: """Force vllm to use a specific flash-attention version (2 or 3). Only valid when using the flash-attention backend.""" + use_prefill_decode_attention: bool = False + """Use separate prefill and decode kernels for attention instead of + the unified triton kernel.""" + flash_attn_max_num_splits_for_cuda_graph: int = 32 """Flash Attention max number splits for cuda graph decode.""" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 1e12fc7ea6fd..efba40a6a3be 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -313,7 +313,23 @@ def get_attn_backend_cls( logger.info("Using Aiter Flash Attention backend.") return AttentionBackendEnum.ROCM_AITER_FA.get_path() - # Priority 3: Check for AITER enabled without specific flags + # Priority 3: Check for ROCM_ATTN (prefill-decode split) + from vllm.config import get_current_vllm_config_or_none + + vllm_config = get_current_vllm_config_or_none() + if ( + vllm_config is not None + and vllm_config.attention_config.use_prefill_decode_attention + ): + logger.warning_once( + "use_prefill_decode_attention is deprecated and will be removed in " + "future releases. " + "Use --attention_config.backend to select the desired backend" + ) + logger.info("Using Rocm Attention backend.") + return AttentionBackendEnum.ROCM_ATTN.get_path() + + # Priority 4: Check for AITER enabled without specific flags # This defaults to AITER FA only if MHA is not explicitly disabled if ( envs.VLLM_ROCM_USE_AITER From 9fa4d8fa86a58bd17d76188f8e40f47893d4911c Mon Sep 17 00:00:00 2001 From: Micah Williamson Date: Thu, 26 Feb 2026 20:21:06 +0000 Subject: [PATCH 4/4] update attention backend selection test in accordance with new defaults Signed-off-by: Micah Williamson --- .../attention/test_rocm_attention_backends_selection.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/v1/attention/test_rocm_attention_backends_selection.py b/tests/v1/attention/test_rocm_attention_backends_selection.py index a31c053aed21..7baeb70e18cd 100644 --- a/tests/v1/attention/test_rocm_attention_backends_selection.py +++ b/tests/v1/attention/test_rocm_attention_backends_selection.py @@ -47,7 +47,7 @@ def mock_on_gfx9(): ( {}, None, - AttentionBackendEnum.TRITON_ATTN.get_path(), + AttentionBackendEnum.ROCM_ATTN.get_path(), ), # Test Case 2: Explicit TRITON_ATTN backend ( @@ -74,11 +74,10 @@ def mock_on_gfx9(): AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(), ), # Test Case 6: VLLM_ROCM_USE_AITER=1 - # (defaults to AITER FA when MHA not explicitly disabled) ( {"VLLM_ROCM_USE_AITER": "1"}, None, - AttentionBackendEnum.ROCM_AITER_FA.get_path(), + AttentionBackendEnum.ROCM_ATTN.get_path(), ), # Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1 ( @@ -102,11 +101,10 @@ def mock_on_gfx9(): AttentionBackendEnum.TRITON_ATTN.get_path(), ), # Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0 - # (explicitly disabled) ( {"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"}, None, - AttentionBackendEnum.TRITON_ATTN.get_path(), + AttentionBackendEnum.ROCM_ATTN.get_path(), ), # Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN (