diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh
index 407e3c5a632b..65a77a307009 100755
--- a/.buildkite/scripts/hardware_ci/run-amd-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh
@@ -282,7 +282,7 @@ apply_rocm_test_overrides() {
 
   # --- LoRA: disable custom paged attention ---
   if [[ $cmds == *"pytest -v -s lora"* ]]; then
-    cmds=${cmds//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
+    cmds=${cmds//"pytest -v -s lora"/"pytest -v -s lora"}
   fi
 
   # --- Kernel ignores ---
diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md
index 7c60a136f790..7f41b07d5eef 100644
--- a/docs/design/attention_backends.md
+++ b/docs/design/attention_backends.md
@@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first).
 | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
 | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
 | `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
-| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
+| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | All | N/A |
 | `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
 | `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index c4ba8053cc58..407f74fda1bc 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool:
     VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings
     for auto-discovery.
""" if current_platform.is_rocm() and IS_AITER_FOUND: - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_mi3xx - return on_gfx9() + return on_mi3xx() return False diff --git a/vllm/envs.py b/vllm/envs.py index d6240df36051..37a2c92da648 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -115,7 +115,6 @@ VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True - VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 @@ -994,10 +993,6 @@ def _get_or_set_default() -> str: "VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))), # Pad the weights for the moe kernel "VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))), - # custom paged attention kernel for MI3* cards - "VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: ( - os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") - ), # Whether to use the shuffled kv cache layout "VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: ( os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 0af98d562c12..1944d2d2fb0e 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -264,7 +264,6 @@ def use_rocm_custom_paged_attention( and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 128 * 1024 - and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and sinks is None ) @@ -279,7 +278,6 @@ def use_rocm_custom_paged_attention( and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" - and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN and sinks is None ) @@ -310,7 +308,7 @@ def _get_backend_priorities( use_mla: bool, use_sparse: bool, ) -> list[AttentionBackendEnum]: - from vllm._aiter_ops import rocm_aiter_ops + from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops if use_sparse: return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE] @@ -327,28 +325,15 @@ def _get_backend_priorities( AttentionBackendEnum.TRITON_MLA, ] - backends = [] - - # Priority 1: Check for AITER Unified Attention (must check before MHA) - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: - backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN) - - # Priority 2: Check for AITER MHA (Flash Attention) - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA: + backends = [ + AttentionBackendEnum.ROCM_ATTN, + ] + if rocm_aiter_ops.is_mha_enabled(): backends.append(AttentionBackendEnum.ROCM_AITER_FA) - - # Priority 3: Check for ROCM_ATTN (prefill-decode split) - from vllm.config import get_current_vllm_config_or_none - - vllm_config = get_current_vllm_config_or_none() - if ( - vllm_config is not None - and vllm_config.attention_config.use_prefill_decode_attention - ): - backends.append(AttentionBackendEnum.ROCM_ATTN) - - # Default: Triton Unified Attention + if is_aiter_found_and_supported(): + backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN) backends.append(AttentionBackendEnum.TRITON_ATTN) + return backends diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 2b801d63fbdf..1e874d32bdd1 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -195,7 +195,10 @@ def supports_mm_prefix(cls) -> bool: @classmethod def supports_sink(cls) -> 
bool: - return True + # ROCM custom attention kernel does not support sinks. + # Callink this backend with sinks will cause it to fall back to the Triton + # kernel, which is less efficient than the proper triton backends. + return False forward_includes_kv_cache_update: bool = False diff --git a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py index 2dbd8755bf4d..000fd4d43b93 100644 --- a/vllm/v1/attention/ops/chunked_prefill_paged_decode.py +++ b/vllm/v1/attention/ops/chunked_prefill_paged_decode.py @@ -10,11 +10,14 @@ import torch from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from .prefix_prefill import context_attention_fwd +logger = init_logger(__name__) + float8_info = torch.finfo(current_platform.fp8_dtype()) @@ -392,6 +395,10 @@ def chunked_prefill_paged_decode( fp8_out_scale=output_scale, ) else: + logger.warning_once( + "Cannot use ROCm custom paged attention kernel," + " falling back to Triton implementation." + ) real_block_size = value_cache.shape[3] # The standard model directly uses the original block_size. # Non-standard 544 uses 32 to accommodate integer division logic.
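
For reviewers, a minimal self-contained sketch of the reworked non-MLA priority ordering in `_get_backend_priorities` above. This is illustrative only: `Backend`, `aiter_mha_enabled`, and `aiter_supported` are stand-ins for vLLM's `AttentionBackendEnum`, `rocm_aiter_ops.is_mha_enabled()`, and `is_aiter_found_and_supported()`, not the real module.

```python
from enum import Enum, auto


class Backend(Enum):
    # Stand-ins for the relevant AttentionBackendEnum members.
    ROCM_ATTN = auto()
    ROCM_AITER_FA = auto()
    ROCM_AITER_UNIFIED_ATTN = auto()
    TRITON_ATTN = auto()


def backend_priorities(aiter_mha_enabled: bool, aiter_supported: bool) -> list[Backend]:
    """Mirror the new non-MLA ordering: ROCM_ATTN first, Triton unified attention last."""
    backends = [Backend.ROCM_ATTN]
    if aiter_mha_enabled:
        backends.append(Backend.ROCM_AITER_FA)
    if aiter_supported:
        backends.append(Backend.ROCM_AITER_UNIFIED_ATTN)
    backends.append(Backend.TRITON_ATTN)
    return backends


if __name__ == "__main__":
    # With AITER available and MHA enabled, all four backends are candidates, in priority order.
    print(backend_priorities(aiter_mha_enabled=True, aiter_supported=True))
    # Without AITER, only the ROCm custom backend and the Triton fallback remain.
    print(backend_priorities(aiter_mha_enabled=False, aiter_supported=False))
```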