From 0001743e6dae31120ea31f800d6fc7d6ed5dceb2 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 29 Sep 2025 10:58:25 -0400 Subject: [PATCH 1/6] add registry Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 91 +++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 vllm/attention/backends/registry.py diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py new file mode 100644 index 000000000000..b3162f92cfb0 --- /dev/null +++ b/vllm/attention/backends/registry.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention backend registry""" + +import enum +from typing import Optional, Type + +from vllm.utils import resolve_obj_by_qualname + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + TRITON_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_FA = enum.auto() # used for ViT attn backend + TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() + FLASHINFER_MLA = enum.auto() + TRITON_MLA = enum.auto() # Supported by V1 + CUTLASS_MLA = enum.auto() + FLASHMLA = enum.auto() # Supported by V1 + FLASH_ATTN_MLA = enum.auto() # Supported by V1 + PALLAS = enum.auto() + IPEX = enum.auto() + DUAL_CHUNK_FLASH_ATTN = enum.auto() + DIFFERENTIAL_FLASH_ATTN = enum.auto() + NO_ATTENTION = enum.auto() + FLEX_ATTENTION = enum.auto() + TREE_ATTN = enum.auto() + ROCM_ATTN = enum.auto() + + +BACKEND_MAPPING = {} + + +def register_attn_backend(backend: _Backend, class_path: str): + """ + Decorator: register a custom attention backend into BACKEND_MAPPING. + Validation: only checks if 'backend' is a valid _Backend enum member. + Overwriting existing mappings is allowed. + """ + if not isinstance(backend, _Backend): + raise ValueError(f"{backend} is not a valid _Backend enum value.") + + def decorator(cls): + BACKEND_MAPPING[backend] = class_path + return cls + + return decorator + + +def backend_to_class_str(backend: _Backend) -> str: + """Get the backend class string + + Args: + backend: The backend enum value + + Returns: + The backend class string + """ + return BACKEND_MAPPING[backend] + + +def backend_to_class(backend: _Backend) -> Type: + """Get the backend class. + + Args: + backend: The backend enum value + + Returns: + The backend class + """ + backend_class_name = backend_to_class_str(backend) + return resolve_obj_by_qualname(backend_class_name) + + +def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + _Backend: enum value if backend_name is a valid in-tree type + None: otherwise it's an invalid in-tree type or an out-of-tree platform + is loaded. + """ + assert backend_name is not None + backend_name = backend_name.removesuffix("_VLLM_V1") + return _Backend[backend_name] if backend_name in _Backend.__members__ else \ + None From 7c31e39b9afe8807809458690ba839fed16caa56 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 29 Sep 2025 11:15:11 -0400 Subject: [PATCH 2/6] point to new registry Signed-off-by: Matthew Bonanni --- tests/compile/test_full_graph.py | 2 +- tests/compile/test_fusion_attn.py | 4 +- tests/kernels/attention/test_mha_attn.py | 3 +- tests/kernels/utils.py | 2 +- tests/v1/attention/test_attention_backends.py | 4 +- tests/v1/attention/test_mla_backends.py | 4 +- tests/v1/attention/utils.py | 3 +- tests/v1/spec_decode/test_eagle.py | 4 +- tests/v1/spec_decode/test_mtp.py | 4 +- tests/v1/spec_decode/test_tree_attention.py | 3 +- vllm/attention/backends/registry.py | 62 ------------------- vllm/attention/layer.py | 3 +- vllm/attention/selector.py | 3 +- .../kv_connector/v1/nixl_connector.py | 3 +- vllm/envs.py | 5 +- vllm/model_executor/models/dots_ocr.py | 2 +- vllm/model_executor/models/ernie45_vl.py | 3 +- vllm/model_executor/models/glm4_1v.py | 2 +- vllm/model_executor/models/keye.py | 2 +- vllm/model_executor/models/qwen2_5_vl.py | 2 +- vllm/model_executor/models/qwen2_vl.py | 3 +- vllm/model_executor/models/qwen3_vl.py | 2 +- vllm/model_executor/models/siglip2navit.py | 2 +- vllm/model_executor/models/vision.py | 3 +- vllm/platforms/__init__.py | 1 - vllm/platforms/cpu.py | 3 +- vllm/platforms/cuda.py | 3 +- vllm/platforms/interface.py | 25 +------- vllm/platforms/rocm.py | 3 +- vllm/platforms/tpu.py | 3 +- vllm/platforms/xpu.py | 3 +- 31 files changed, 50 insertions(+), 121 deletions(-) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index f9f146810924..3ecda1a8ec33 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -11,8 +11,8 @@ import torch from tests.quantization.utils import is_quant_method_supported -from tests.v1.attention.utils import _Backend from vllm import LLM, SamplingParams +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, PassConfig) diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index eb8c49135428..077cf11d048a 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -8,11 +8,11 @@ from tests.compile.backend import LazyInitPass, TestBackend from tests.models.utils import check_outputs_equal -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata) +from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm import LLM, SamplingParams from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import global_force_attn_backend_context_manager from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index d37b968ed979..cea08e19f52d 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -10,8 +10,9 @@ import pytest import torch +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import MultiHeadAttention -from vllm.attention.selector import _Backend, _cached_get_attn_backend +from vllm.attention.selector import _cached_get_attn_backend from vllm.platforms import current_platform from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cuda import CudaPlatform diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 0fdaa600aefa..db6f29c28c95 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -15,10 +15,10 @@ from tests.kernels.quant_utils import native_w8a8_block_matmul from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType +from vllm.attention.backends.registry import _Backend from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) -from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 6c17be759ab6..24cdd8afbb3b 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -8,11 +8,11 @@ import torch from torch.nn.attention.flex_attention import create_block_mask, flex_attention -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) +from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 228551573ba8..f2d0a5b2407a 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -6,12 +6,12 @@ import pytest import torch -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) from vllm import _custom_ops as ops +from vllm.attention.backends.registry import _Backend from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index adfe2b2db040..2bea45210ff3 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -8,10 +8,11 @@ import pytest import torch +from vllm.attention.backends.registry import _Backend from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, LoadConfig, ModelConfig, ModelDType, ParallelConfig, SchedulerConfig, VllmConfig) -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.utils import resolve_obj_by_qualname from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import FullAttentionSpec diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 690732eb1232..1853fa13118f 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -8,10 +8,10 @@ import torch from tests.utils import get_attn_backend_list_based_on_platform -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, get_attention_backend) +from vllm.attention.backends.registry import _Backend from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index e4881859ece1..d6e5d02cf8d7 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -6,10 +6,10 @@ import pytest import torch -from tests.v1.attention.utils import (BatchSpec, _Backend, - create_common_attn_metadata, +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, create_standard_kv_cache_spec, get_attention_backend) +from vllm.attention.backends.registry import _Backend from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) diff --git a/tests/v1/spec_decode/test_tree_attention.py b/tests/v1/spec_decode/test_tree_attention.py index 51a737496dff..ebb9a3d97861 100644 --- a/tests/v1/spec_decode/test_tree_attention.py +++ b/tests/v1/spec_decode/test_tree_attention.py @@ -6,9 +6,10 @@ import torch -from tests.v1.attention.utils import (_Backend, create_standard_kv_cache_spec, +from tests.v1.attention.utils import (create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) +from vllm.attention.backends.registry import _Backend from vllm.config import ParallelConfig, SpeculativeConfig from vllm.v1.attention.backends.utils import CommonAttentionMetadata diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index b3162f92cfb0..93e87a132121 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -3,9 +3,6 @@ """Attention backend registry""" import enum -from typing import Optional, Type - -from vllm.utils import resolve_obj_by_qualname class _Backend(enum.Enum): @@ -30,62 +27,3 @@ class _Backend(enum.Enum): FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() ROCM_ATTN = enum.auto() - - -BACKEND_MAPPING = {} - - -def register_attn_backend(backend: _Backend, class_path: str): - """ - Decorator: register a custom attention backend into BACKEND_MAPPING. - Validation: only checks if 'backend' is a valid _Backend enum member. - Overwriting existing mappings is allowed. - """ - if not isinstance(backend, _Backend): - raise ValueError(f"{backend} is not a valid _Backend enum value.") - - def decorator(cls): - BACKEND_MAPPING[backend] = class_path - return cls - - return decorator - - -def backend_to_class_str(backend: _Backend) -> str: - """Get the backend class string - - Args: - backend: The backend enum value - - Returns: - The backend class string - """ - return BACKEND_MAPPING[backend] - - -def backend_to_class(backend: _Backend) -> Type: - """Get the backend class. - - Args: - backend: The backend enum value - - Returns: - The backend class - """ - backend_class_name = backend_to_class_str(backend) - return resolve_obj_by_qualname(backend_class_name) - - -def backend_name_to_enum(backend_name: str) -> Optional[_Backend]: - """ - Convert a string backend name to a _Backend enum value. - - Returns: - _Backend: enum value if backend_name is a valid in-tree type - None: otherwise it's an invalid in-tree type or an out-of-tree platform - is loaded. - """ - assert backend_name is not None - backend_name = backend_name.removesuffix("_VLLM_V1") - return _Backend[backend_name] if backend_name in _Backend.__members__ else \ - None diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 4ce6a864d7ad..113602645e89 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -26,7 +27,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.models.vision import get_vit_attn_backend -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.utils import GiB_bytes, direct_register_custom_op logger = init_logger(__name__) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 6f048e589f7f..d3214fecfa70 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -11,8 +11,9 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname logger = init_logger(__name__) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 55d87ea994b5..4706c5130899 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -20,6 +20,7 @@ import zmq from vllm import envs +from vllm.attention.backends.registry import _Backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( @@ -32,7 +33,7 @@ from vllm.distributed.utils import divide from vllm.forward_context import ForwardContext from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput diff --git a/vllm/envs.py b/vllm/envs.py index ffa7ed5c3aa5..7083c45682c1 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -617,8 +617,9 @@ def get_vllm_port() -> Optional[int]: # All possible options loaded dynamically from _Backend enum "VLLM_ATTENTION_BACKEND": env_with_choices("VLLM_ATTENTION_BACKEND", None, - lambda: list(__import__('vllm.platforms.interface', \ - fromlist=['_Backend'])._Backend.__members__.keys())), + lambda: list(__import__( + 'vllm.attention.backends.registry', + fromlist=['_Backend'])._Backend.__members__.keys())), # If set, vllm will use flashinfer sampler "VLLM_USE_FLASHINFER_SAMPLER": diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 4845f19bcbc4..8b007f6bab80 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,6 +9,7 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import utils as dist_utils @@ -38,7 +39,6 @@ from vllm.model_executor.models.vision import get_vit_attn_backend from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalDataDict -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig, DotsVisionConfig) diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index a73ec4f88ffe..a5332658cb75 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -34,6 +34,7 @@ from einops import rearrange, repeat from transformers import BatchFeature +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state @@ -54,7 +55,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 722f1e428be7..315a057e6a7d 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -46,6 +46,7 @@ Glm4vVideoProcessor) from transformers.video_utils import VideoMetadata +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, @@ -69,7 +70,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 10b5c45169f4..90de0582b94a 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -17,6 +17,7 @@ BaseModelOutputWithPooling) from transformers.utils import torch_int +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size @@ -39,7 +40,6 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index da3889d31a7d..a70df3b72be4 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -38,6 +38,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state @@ -62,7 +63,6 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import PromptReplacement, PromptUpdate -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index f83a411459cc..0a7e21f32338 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -41,6 +41,7 @@ from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( Qwen2VLVideoProcessor) +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather @@ -63,7 +64,7 @@ BaseProcessingInfo, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils.tensor_schema import TensorSchema, TensorShape diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index ce92557d6424..012130ed528f 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -43,6 +43,7 @@ smart_resize as video_smart_resize) from transformers.video_utils import VideoMetadata +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -66,7 +67,6 @@ PromptReplacement, PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index 18de4b576c49..d111a10809e7 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -13,6 +13,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig +from vllm.attention.backends.registry import _Backend from vllm.attention.layer import check_upstream_fa_availability from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -22,7 +23,6 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.platforms import _Backend from .vision import get_vit_attn_backend diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index e077691fcec2..50596f331d10 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -11,11 +11,12 @@ from transformers import PretrainedConfig from typing_extensions import assert_never +from vllm.attention.backends.registry import _Backend from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) from vllm.logger import init_logger -from vllm.platforms import _Backend, current_platform +from vllm.platforms import current_platform logger = init_logger(__name__) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 9b64817da648..7549de480ee6 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -9,7 +9,6 @@ from vllm.plugins import load_plugins_by_group from vllm.utils import resolve_obj_by_qualname, supports_xccl -from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum logger = logging.getLogger(__name__) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 0b26446a87d8..2bd1afec53b0 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -12,10 +12,11 @@ import torch +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend +from .interface import CpuArchEnum, Platform, PlatformEnum logger = init_logger(__name__) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a9a8d9ea2625..3cc90e5fac09 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -17,10 +17,11 @@ # import custom ops, trigger op registration import vllm._C # noqa import vllm.envs as envs +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless, import_pynvml -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 1691ad62650b..d19d8f5c4577 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -13,6 +13,7 @@ import torch from torch.distributed import PrefixStore, ProcessGroup +from vllm.attention.backends.registry import _Backend from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger @@ -38,30 +39,6 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() -class _Backend(enum.Enum): - FLASH_ATTN = enum.auto() - TRITON_ATTN = enum.auto() - XFORMERS = enum.auto() - ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() # Supported by V1 - ROCM_AITER_FA = enum.auto() # used for ViT attn backend - TORCH_SDPA = enum.auto() - FLASHINFER = enum.auto() - FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() # Supported by V1 - CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 - FLASH_ATTN_MLA = enum.auto() # Supported by V1 - PALLAS = enum.auto() - IPEX = enum.auto() - DUAL_CHUNK_FLASH_ATTN = enum.auto() - DIFFERENTIAL_FLASH_ATTN = enum.auto() - NO_ATTENTION = enum.auto() - FLEX_ATTENTION = enum.auto() - TREE_ATTN = enum.auto() - ROCM_ATTN = enum.auto() - - class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 14762f1b7094..5ad4dc5edde8 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -11,10 +11,11 @@ from torch.distributed.distributed_c10d import is_nccl_available import vllm.envs as envs +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4a4931f7f009..bef501a2e1e9 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,12 +6,13 @@ import torch from tpu_info import device +from vllm.attention.backends.registry import _Backend from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import Platform, PlatformEnum, _Backend +from .interface import Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import BlockSize, ModelConfig, VllmConfig diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 12d6a2a2d1ba..a34c6a0e57ff 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -7,10 +7,11 @@ import torch import vllm.envs as envs +from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS -from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig From a0e49d5af0c084170135dac082ae6d737ae60c5d Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 30 Sep 2025 16:43:21 -0400 Subject: [PATCH 3/6] fix circular import (fix docs build) Signed-off-by: Matthew Bonanni --- vllm/platforms/interface.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d19d8f5c4577..df1395fa842a 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -13,17 +13,18 @@ import torch from torch.distributed import PrefixStore, ProcessGroup -from vllm.attention.backends.registry import _Backend from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils import FlexibleArgumentParser else: + _Backend = None ModelConfig = None VllmConfig = None LoRARequest = None @@ -164,11 +165,12 @@ def device_id_to_physical_device_id(cls, device_id: int): @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> _Backend: + dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend return _Backend.TORCH_SDPA @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink: bool, use_sparse: bool) -> str: From 82aa912a83a9ad2fa0ec5a710fa7e06030379268 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Wed, 1 Oct 2025 11:26:30 -0400 Subject: [PATCH 4/6] fix circular imports Signed-off-by: Matthew Bonanni --- vllm/platforms/cpu.py | 6 ++++-- vllm/platforms/cuda.py | 8 ++++++-- vllm/platforms/rocm.py | 8 ++++++-- vllm/platforms/tpu.py | 6 ++++-- vllm/platforms/xpu.py | 6 ++++-- 5 files changed, 24 insertions(+), 10 deletions(-) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 2bd1afec53b0..436e295e58e6 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -12,7 +12,6 @@ import torch -from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS @@ -21,8 +20,10 @@ logger = init_logger(__name__) if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import VllmConfig else: + _Backend = None VllmConfig = None @@ -91,10 +92,11 @@ def get_device_name(cls, device_id: int = 0) -> str: return "cpu" @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink: bool, use_sparse: bool) -> str: + from vllm.attention.backends.registry import _Backend if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 3cc90e5fac09..b7baa614957e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -17,14 +17,16 @@ # import custom ops, trigger op registration import vllm._C # noqa import vllm.envs as envs -from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless, import_pynvml from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig +else: + _Backend = None logger = init_logger(__name__) @@ -203,7 +205,8 @@ def get_current_memory_usage(cls, @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> _Backend: + dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 @@ -231,6 +234,7 @@ def get_vit_attn_backend(cls, head_size: int, def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink, use_sparse) -> str: + from vllm.attention.backends.registry import _Backend if use_mla: if not use_v1: raise RuntimeError( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5ad4dc5edde8..e12967ad2587 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -11,14 +11,16 @@ from torch.distributed.distributed_c10d import is_nccl_available import vllm.envs as envs -from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import cuda_device_count_stateless from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig +else: + _Backend = None logger = init_logger(__name__) @@ -183,7 +185,8 @@ class RocmPlatform(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, - dtype: torch.dtype) -> _Backend: + dtype: torch.dtype) -> "_Backend": + from vllm.attention.backends.registry import _Backend if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()): # Note: AITER FA is only supported for Qwen-VL models. @@ -197,6 +200,7 @@ def get_vit_attn_backend(cls, head_size: int, def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, has_sink, use_sparse) -> str: + from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError( "Sparse Attention is not supported on ROCm.") diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index bef501a2e1e9..91a01a4f4ee9 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,7 +6,6 @@ import torch from tpu_info import device -from vllm.attention.backends.registry import _Backend from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -15,6 +14,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.pooling_params import PoolingParams else: @@ -22,6 +22,7 @@ ModelConfig = None VllmConfig = None PoolingParams = None + _Backend = None logger = init_logger(__name__) @@ -47,10 +48,11 @@ class TpuPlatform(Platform): ] @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink, use_sparse) -> str: + from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError( "Sparse Attention is not supported on TPU.") diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index a34c6a0e57ff..3ccbae58726f 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -7,17 +7,18 @@ import torch import vllm.envs as envs -from vllm.attention.backends.registry import _Backend from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import ModelConfig, VllmConfig else: ModelConfig = None VllmConfig = None + _Backend = None logger = init_logger(__name__) @@ -34,10 +35,11 @@ class XPUPlatform(Platform): device_control_env_var: str = "ZE_AFFINITY_MASK" @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink: bool, use_sparse) -> str: + from vllm.attention.backends.registry import _Backend if use_sparse: raise NotImplementedError( "Sparse Attention is not supported on XPU.") From 45eb2f51587b8b4f0d4e5a94e3d54d9eb9475b7d Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 2 Oct 2025 12:56:09 -0400 Subject: [PATCH 5/6] remove old comments Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 93e87a132121..af1cf57425ba 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -10,15 +10,15 @@ class _Backend(enum.Enum): TRITON_ATTN = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_MLA = enum.auto() ROCM_AITER_FA = enum.auto() # used for ViT attn backend TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() FLASHINFER_MLA = enum.auto() - TRITON_MLA = enum.auto() # Supported by V1 + TRITON_MLA = enum.auto() CUTLASS_MLA = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 - FLASH_ATTN_MLA = enum.auto() # Supported by V1 + FLASHMLA = enum.auto() + FLASH_ATTN_MLA = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() From bc1d271a19ed823aabd1713c30f67f5fa17c2448 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Thu, 2 Oct 2025 12:57:54 -0400 Subject: [PATCH 6/6] remove deprecated backends Signed-off-by: Matthew Bonanni --- vllm/attention/backends/registry.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index af1cf57425ba..6377e8619b3c 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -21,8 +21,6 @@ class _Backend(enum.Enum): FLASH_ATTN_MLA = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() - DUAL_CHUNK_FLASH_ATTN = enum.auto() - DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto()