Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/kernels/attention/test_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def clear_cache():

DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"],
"hip": ["ROCM_FLASH"],
"hip": ["ROCM_ATTN"],
"cpu": ["TORCH_SDPA"],
}

Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/attention/test_rocm_attention_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def clear_cache():
@pytest.mark.skip(reason="Skipped for now. Should be revisited.")
def test_selector(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_ATTN")

# Set the current platform to ROCm using monkeypatch
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/attention/test_attention_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
Expand Down Expand Up @@ -214,7 +214,7 @@ def run_attention_backend(
actual_backend = _Backend.FLEX_ATTENTION
use_direct_block_mask = False

builder_cls, impl_cls = get_attention_backend(actual_backend)
builder_cls, impl_cls = try_get_attention_backend(actual_backend)

# Mock flashinfer's get_per_layer_parameters if needed
if actual_backend == _Backend.FLASHINFER:
Expand Down
6 changes: 3 additions & 3 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm import _custom_ops as ops
from vllm.attention.backends.registry import _Backend
Expand Down Expand Up @@ -239,7 +239,7 @@ def run_attention_backend(
) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""

builder_cls, impl_cls = get_attention_backend(backend)
builder_cls, impl_cls = try_get_attention_backend(backend)

# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
Expand Down Expand Up @@ -400,7 +400,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
# Determine if this is decode or prefill
is_decode = []
for i, backend in enumerate(BACKENDS_TO_TEST):
builder_cls, _ = get_attention_backend(backend)
builder_cls, _ = try_get_attention_backend(backend)
is_decode.append(q_len <= builder_cls.reorder_batch_threshold)

# Split q into nope and rope components
Expand Down
52 changes: 14 additions & 38 deletions tests/v1/attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import pytest
import torch

from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.abstract import AttentionImpl
from vllm.attention.backends.registry import _Backend, backend_to_class_str
from vllm.config import (
CacheConfig,
CompilationConfig,
Expand All @@ -20,9 +21,11 @@
SchedulerConfig,
VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec


Expand Down Expand Up @@ -117,44 +120,17 @@ def create_common_attn_metadata(
)


def get_attention_backend(backend_name: _Backend):
"""Set up attention backend classes for testing.

Args:
backend_name: Name of the backend ("flash_attn", "flashinfer", etc.)
vllm_config: VllmConfig instance

Returns:
Tuple of (backend_builder_class, backend_impl_class)
"""
backend_map = {
_Backend.FLASH_ATTN: (
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if current_platform.is_cuda()
else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
),
_Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
_Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", # noqa: E501
_Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501
_Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
_Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend", # noqa: E501
_Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", # noqa: E501
_Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", # noqa: E501
_Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend", # noqa: E501
_Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", # noqa: E501
}

if backend_name not in backend_map:
raise ValueError(f"Unknown backend: {backend_name}")

backend_class_name = backend_map[backend_name]

def try_get_attention_backend(
backend: _Backend,
) -> tuple[type[AttentionMetadataBuilder], type[AttentionImpl]]:
"""Try to get the attention backend class, skipping test if not found."""
backend_class_str = backend_to_class_str(backend)
try:
backend_class = resolve_obj_by_qualname(backend_class_name)
backend_class = resolve_obj_by_qualname(backend_class_str)
return backend_class.get_builder_cls(), backend_class.get_impl_cls()
except ImportError as e:
pytest.skip(f"{backend_name} not available: {e}")
pytest.skip(f"{backend_class_str} not available: {e}")
raise AssertionError("unreachable") from None


def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
Expand Down
10 changes: 5 additions & 5 deletions tests/v1/spec_decode/test_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (
Expand Down Expand Up @@ -535,11 +535,11 @@ def create_deterministic_logits(token_ids):
sampling_metadata = mock.MagicMock()

if attn_backend == "FLASH_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)
elif attn_backend == "TRITON_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TRITON_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TRITON_ATTN)
elif attn_backend == "TREE_ATTN":
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
else:
raise ValueError(f"Unsupported attention backend: {attn_backend}")

Expand Down Expand Up @@ -674,7 +674,7 @@ def create_deterministic_logits(token_ids, k: int):
proposer.attn_layer_names = ["layer.0"]

# Get the tree attention metadata builder.
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.TREE_ATTN)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/spec_decode/test_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import (
Expand Down Expand Up @@ -177,7 +177,7 @@ def create_deterministic_logits(batch_size, vocab_size, token_offset):
sampling_metadata = mock.MagicMock()

# Setup attention metadata
attn_metadata_builder_cls, _ = get_attention_backend(_Backend.FLASH_ATTN)
attn_metadata_builder_cls, _ = try_get_attention_backend(_Backend.FLASH_ATTN)

attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/spec_decode/test_tree_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
try_get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ParallelConfig, SpeculativeConfig
Expand Down Expand Up @@ -63,7 +63,7 @@ def forward_attention(

# Build common metadata.
model_name = "meta-llama/Meta-Llama-3-8B"
builder_cls, impl_cls = get_attention_backend(backend)
builder_cls, impl_cls = try_get_attention_backend(backend)
vllm_config = create_vllm_config(model_name=model_name, max_model_len=max(seq_lens))
if spec_token_tree is not None:
# Create speculative config if token tree is specified.
Expand Down
66 changes: 64 additions & 2 deletions vllm/attention/backends/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
"""Attention backend registry"""

import enum
from typing import Optional

from vllm.utils import resolve_obj_by_qualname


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
TRITON_ATTN = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_MLA = enum.auto()
ROCM_AITER_FA = enum.auto() # used for ViT attn backend
TORCH_SDPA = enum.auto()
Expand All @@ -24,5 +27,64 @@ class _Backend(enum.Enum):
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()


BACKEND_MAP = {}


def register_attn_backend(backend: _Backend, class_path: Optional[str] = None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this work for OOT backends, wouldn't they have to extend the _Backend enum?

Copy link
Contributor Author

@MatthewBonanni MatthewBonanni Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plan is for them to override an existing backend (see #24794 (comment)). We could extend this in the future (and find an alternative to the enum) to make it cleaner though.

"""
Decorator: register a custom attention backend into BACKEND_MAP.
- If class_path is provided, use it.
- Otherwise, auto-generate from the class object.
Validation: only checks if 'backend' is a valid _Backend enum member.
Overwriting existing mappings is allowed.
"""
if not isinstance(backend, _Backend):
raise ValueError(f"{backend} is not a valid _Backend enum value.")

def decorator(cls):
path = class_path or f"{cls.__module__}.{cls.__qualname__}"
BACKEND_MAP[backend] = path
return cls

return decorator


def backend_to_class_str(backend: _Backend) -> str:
    """Get the fully qualified class path string for a backend.

    Args:
        backend: The backend enum value

    Returns:
        The backend class string (a dotted qualname suitable for
        ``resolve_obj_by_qualname``)

    Raises:
        KeyError: If no class has been registered for ``backend``
            (e.g. its module was never imported, so its
            ``@register_attn_backend`` decorator never ran).
    """
    try:
        return BACKEND_MAP[backend]
    except KeyError:
        # A bare BACKEND_MAP[backend] failure only shows the enum repr;
        # include the registered backends to make the error actionable.
        raise KeyError(
            f"No attention backend class registered for {backend}. "
            f"Registered backends: {sorted(b.name for b in BACKEND_MAP)}"
        ) from None


def backend_to_class(backend: _Backend) -> type:
    """Resolve a backend enum value to its backend class object.

    Args:
        backend: The backend enum value

    Returns:
        The backend class
    """
    # Look up the registered qualname, then import/resolve it.
    qualname = backend_to_class_str(backend)
    return resolve_obj_by_qualname(qualname)


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
    """
    Convert a string backend name to a _Backend enum value.

    Returns:
        _Backend: enum value if backend_name is a valid in-tree type
        None: otherwise it's an invalid in-tree type or an out-of-tree platform
        is loaded.
    """
    assert backend_name is not None
    # __members__ maps member names to members; .get returns None on a miss.
    return _Backend.__members__.get(backend_name)
4 changes: 2 additions & 2 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed.kv_transfer import (
Expand Down
15 changes: 1 addition & 14 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,14 @@

import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname

logger = init_logger(__name__)


def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
"""
Convert a string backend name to a _Backend enum value.

Returns:
* _Backend: enum value if backend_name is a valid in-tree type
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else None


def get_env_variable_attn_backend() -> Optional[_Backend]:
"""
Get the backend override specified by the vLLM attention
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import zmq

from vllm import envs
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AttentionType,
is_quantized_kv_cache,
)
from vllm.attention.backends.registry import _Backend, register_attn_backend
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
Expand All @@ -38,6 +39,7 @@
logger = init_logger(__name__)


@register_attn_backend(_Backend.TORCH_SDPA)
class TorchSDPABackend(AttentionBackend):
accept_output_buffer: bool = False

Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AttentionType,
is_quantized_kv_cache,
)
from vllm.attention.backends.registry import _Backend, register_attn_backend
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (
Expand Down Expand Up @@ -45,6 +46,7 @@
logger = init_logger(__name__)


@register_attn_backend(_Backend.FLASH_ATTN)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supports_quant_query_input: bool = True
Expand Down
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AttentionImpl,
AttentionType,
)
from vllm.attention.backends.registry import _Backend, register_attn_backend
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -154,6 +155,7 @@ def trtllm_prefill_attn_kvfp8_dequant(
return mock_kv_cache, mock_block_table


@register_attn_backend(_Backend.FLASHINFER)
class FlashInferBackend(AttentionBackend):
accept_output_buffer: bool = True

Expand Down
Loading