Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def mock_on_gfx9():
(
{},
None,
AttentionBackendEnum.TRITON_ATTN.get_path(),
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 2: Explicit TRITON_ATTN backend
(
Expand All @@ -74,11 +74,10 @@ def mock_on_gfx9():
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
),
# Test Case 6: VLLM_ROCM_USE_AITER=1
# (defaults to AITER FA when MHA not explicitly disabled)
(
{"VLLM_ROCM_USE_AITER": "1"},
None,
AttentionBackendEnum.ROCM_AITER_FA.get_path(),
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
(
Expand All @@ -102,11 +101,10 @@ def mock_on_gfx9():
AttentionBackendEnum.TRITON_ATTN.get_path(),
),
# Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
# (explicitly disabled)
(
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
None,
AttentionBackendEnum.TRITON_ATTN.get_path(),
AttentionBackendEnum.ROCM_ATTN.get_path(),
),
# Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
(
Expand Down
22 changes: 11 additions & 11 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_DISABLE_PYNCCL: bool = False
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
Expand Down Expand Up @@ -281,6 +280,14 @@ def maybe_convert_bool(value: str | None) -> bool | None:
return bool(int(value))


def use_aiter() -> bool:
    """Return whether aiter ops should be used.

    aiter is used only when both conditions hold:

    * ``VLLM_ROCM_USE_AITER`` is unset (the default is ``"True"``) or is set
      to ``"true"``/``"1"`` (case-insensitive), and
    * ``is_aiter_found_and_supported()`` from ``vllm._aiter_ops`` reports
      true.

    Returns:
        ``True`` when aiter is both requested and available, else ``False``.
    """
    # Check the cheap environment flag first: when the user has turned aiter
    # off there is no reason to import ``vllm._aiter_ops`` at all.
    requested = os.getenv("VLLM_ROCM_USE_AITER", "True").lower() in ("true", "1")
    if not requested:
        return False

    # Imported lazily, inside the function, as in the original code.
    from vllm._aiter_ops import is_aiter_found_and_supported

    return bool(is_aiter_found_and_supported())


def disable_compile_cache() -> bool:
    """Return whether ``VLLM_DISABLE_COMPILE_CACHE`` holds a non-zero integer.

    An unset variable is treated as ``"0"``. The value is parsed with
    ``int``, so non-integer text raises ``ValueError``.
    """
    raw_setting = os.getenv("VLLM_DISABLE_COMPILE_CACHE", "0")
    return int(raw_setting) != 0

Expand Down Expand Up @@ -951,14 +958,7 @@ def _get_or_set_default() -> str:
),
# Use aiter ops by default when aiter is found and supported; setting
# VLLM_ROCM_USE_AITER to anything other than "true"/"1" disables them.
# Acts as a parent switch to enable the rest of the other operations.
"VLLM_ROCM_USE_AITER": lambda: (
os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
),
# Whether to use aiter paged attention.
# By default is disabled.
"VLLM_ROCM_USE_AITER_PAGED_ATTN": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in ("true", "1")
),
"VLLM_ROCM_USE_AITER": use_aiter,
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
Expand All @@ -980,9 +980,9 @@ def _get_or_set_default() -> str:
os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in ("true", "1")
),
# Whether to use aiter mha ops.
# By default is enabled.
# By default is disabled.
"VLLM_ROCM_USE_AITER_MHA": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in ("true", "1")
os.getenv("VLLM_ROCM_USE_AITER_MHA", "False").lower() in ("true", "1")
),
# Whether to use aiter fp4 gemm asm.
# By default is disabled.
Expand Down
11 changes: 8 additions & 3 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ def get_attn_backend_cls(
vllm_config is not None
and vllm_config.attention_config.use_prefill_decode_attention
):
logger.warning_once(
"use_prefill_decode_attention is deprecated and will be removed in "
"future releases. "
"Use --attention_config.backend to select the desired backend"
)
logger.info("Using Rocm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()

Expand All @@ -334,9 +339,9 @@ def get_attn_backend_cls(
logger.info("Using Aiter Flash Attention backend.")
return AttentionBackendEnum.ROCM_AITER_FA.get_path()

# Default: Triton Unified Attention
logger.info("Using Triton Attention backend.")
return AttentionBackendEnum.TRITON_ATTN.get_path()
# Default: ROCm split Attention
logger.info("Using ROCm Attention backend.")
return AttentionBackendEnum.ROCM_ATTN.get_path()

raise RuntimeError(
f"Attention backend {selected_backend.name} is not supported on "
Expand Down