Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/scripts/hardware_ci/run-amd-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ apply_rocm_test_overrides() {

# --- LoRA: disable custom paged attention ---
if [[ $cmds == *"pytest -v -s lora"* ]]; then
cmds=${cmds//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
cmds=${cmds//"pytest -v -s lora"/"pytest -v -s lora"}
fi

# --- Kernel ignores ---
Expand Down
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ Priority is **1 = highest** (tried first).
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | | ✅ | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | | ✅ | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

Expand Down
4 changes: 2 additions & 2 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool:
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
"""
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
from vllm.platforms.rocm import on_mi3xx

return on_gfx9()
return on_mi3xx()
return False


Expand Down
5 changes: 0 additions & 5 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
Expand Down Expand Up @@ -994,10 +993,6 @@ def _get_or_set_default() -> str:
"VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
# Pad the weights for the moe kernel
"VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
# custom paged attention kernel for MI3* cards
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
),
# Whether to use the shuffled kv cache layout
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
Expand Down
31 changes: 8 additions & 23 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def use_rocm_custom_paged_attention(
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and sinks is None
)

Expand All @@ -279,7 +278,6 @@ def use_rocm_custom_paged_attention(
and max_seq_len <= 128 * 1024
and alibi_slopes is None
and kv_cache_dtype == "auto"
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
and sinks is None
)

Expand Down Expand Up @@ -310,7 +308,7 @@ def _get_backend_priorities(
use_mla: bool,
use_sparse: bool,
) -> list[AttentionBackendEnum]:
from vllm._aiter_ops import rocm_aiter_ops
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops

if use_sparse:
return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
Expand All @@ -327,28 +325,15 @@ def _get_backend_priorities(
AttentionBackendEnum.TRITON_MLA,
]

backends = []

# Priority 1: Check for AITER Unified Attention (must check before MHA)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)

# Priority 2: Check for AITER MHA (Flash Attention)
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
backends = [
AttentionBackendEnum.ROCM_ATTN,
]
if rocm_aiter_ops.is_mha_enabled():
backends.append(AttentionBackendEnum.ROCM_AITER_FA)

# Priority 3: Check for ROCM_ATTN (prefill-decode split)
from vllm.config import get_current_vllm_config_or_none

vllm_config = get_current_vllm_config_or_none()
if (
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
backends.append(AttentionBackendEnum.ROCM_ATTN)

# Default: Triton Unified Attention
if is_aiter_found_and_supported():
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
backends.append(AttentionBackendEnum.TRITON_ATTN)

return backends


Expand Down
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/rocm_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,10 @@ def supports_mm_prefix(cls) -> bool:

@classmethod
def supports_sink(cls) -> bool:
return True
# The ROCm custom attention kernel does not support sinks.
# Calling this backend with sinks would cause it to fall back to the Triton
# kernel, which is less efficient than the proper Triton backends.
return False

forward_includes_kv_cache_update: bool = False

Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/attention/ops/chunked_prefill_paged_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
import torch

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

from .prefix_prefill import context_attention_fwd

logger = init_logger(__name__)

float8_info = torch.finfo(current_platform.fp8_dtype())


Expand Down Expand Up @@ -392,6 +395,10 @@ def chunked_prefill_paged_decode(
fp8_out_scale=output_scale,
)
else:
logger.warning_once(
"Cannot use ROCm custom paged attention kernel,"
" falling back to Triton implementation."
)
real_block_size = value_cache.shape[3]
# The standard model directly uses the original block_size.
# Non-standard 544 uses 32 to accommodate integer division logic.
Expand Down
Loading