diff --git a/tests/models/multimodal/generation/test_dots_ocr.py b/tests/models/multimodal/generation/test_dots_ocr.py new file mode 100644 index 000000000000..3f826c3a6530 --- /dev/null +++ b/tests/models/multimodal/generation/test_dots_ocr.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform + +MODEL_NAME = "rednote-hilab/dots.ocr" + +# Exact prompt from dots.ocr +# https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/utils/prompts.py#L3 +# ruff: noqa: E501 +PROMPT = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. + +1. Bbox format: [x1, y1, x2, y2] + +2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. + +3. Text Extraction & Formatting Rules: + - Picture: For the 'Picture' category, the text field should be omitted. + - Formula: Format its text as LaTeX. + - Table: Format its text as HTML. + - All Others (Text, Title, etc.): Format their text as Markdown. + +4. Constraints: + - The output text must be the original text from the image, with no translation. + - All layout elements must be sorted according to human reading order. + +5. Final Output: The entire output must be a single JSON object. 
+""" + + +@pytest.mark.core_model +@pytest.mark.parametrize("prompt", [PROMPT]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_dots_ocr_vit_attn_backend_functionality( + image_assets, + prompt: str, + mm_encoder_attn_backend: _MHA_Backend | None, +): + # images = [asset.pil_image for asset in image_assets] + # Use the stop_sign image which has clear text + stop_sign_image = [ + asset.pil_image for asset in image_assets if asset.name == "stop_sign" + ][0] + + image_urls = [f"data:image/jpeg;base64,{encode_image_base64(stop_sign_image)}"] + + engine_args = EngineArgs( + model=MODEL_NAME, + trust_remote_code=True, + max_model_len=32768, + max_num_seqs=1, + limit_mm_per_prompt={"image": 1}, + mm_encoder_attn_backend=mm_encoder_attn_backend, + ) + + # From the demo example of dots.ocr + # https://github.com/rednote-hilab/dots.ocr/blob/d72d1d8c5bdd0362eb264f714cdbd1e5daa7cdff/dots_ocr/model/inference.py#L22 + + placeholders = [ + {"type": "image_url", "image_url": {"url": image_url}} + for image_url in image_urls + ] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"}, + ], + }, + ] + + engine_args = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_args) + + sampling_params = SamplingParams( + temperature=0.1, + max_tokens=16384, + stop_token_ids=None, + top_p=0.9, + ) + + outputs = llm.chat( + messages=messages, + sampling_params=sampling_params, + ) + + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + assert len(generated_text) > 10, ( + f"Generated text is too short: {generated_text}" + ) + assert "stop" in generated_text.lower(), ( + f"Generated text does not contain 'stop': {generated_text}" + ) + print("-" * 50) diff --git a/tests/models/multimodal/generation/test_ernie45_vl.py b/tests/models/multimodal/generation/test_ernie45_vl.py new file 
mode 100644 index 000000000000..eec064f85bc6 --- /dev/null +++ b/tests/models/multimodal/generation/test_ernie45_vl.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest +from transformers import AutoProcessor + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform + +from ....utils import large_gpu_test + +MODEL_NAME = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + +QUESTION = "What is the content of each image?" + + +@large_gpu_test(min_gb=80) +@pytest.mark.parametrize("question", [QUESTION]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_ernie45_vl_vit_attn_backend_functionality( + image_assets, + question: str, + mm_encoder_attn_backend: _MHA_Backend | None, +): + images = [asset.pil_image for asset in image_assets] + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images + ] + + engine_args = EngineArgs( + model=MODEL_NAME, + trust_remote_code=True, + max_model_len=16384, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_encoder_attn_backend=mm_encoder_attn_backend, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + engine_args = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_args) + + sampling_params = SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=None + ) + + outputs = llm.generate( + { + 
"prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + assert len(generated_text) > 10, ( + f"Generated text is too short: {generated_text}" + ) + print("-" * 50) diff --git a/tests/models/multimodal/generation/test_glm4_1v.py b/tests/models/multimodal/generation/test_glm4_1v.py new file mode 100644 index 000000000000..018e3d6b5aba --- /dev/null +++ b/tests/models/multimodal/generation/test_glm4_1v.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest +from transformers import AutoProcessor + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform + +MODEL_NAME = "zai-org/GLM-4.1V-9B-Thinking" + +QUESTION = "What is the content of each image?" 
+ + +@pytest.mark.parametrize("question", [QUESTION]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_glm4_1v_vit_attn_backend_vit_attn_backend_functionality( + image_assets, + question: str, + mm_encoder_attn_backend: _MHA_Backend, +): + images = [asset.pil_image for asset in image_assets] + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images + ] + + engine_args = EngineArgs( + model=MODEL_NAME, + trust_remote_code=True, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_encoder_attn_backend=mm_encoder_attn_backend, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + engine_args = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_args) + + sampling_params = SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=None + ) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + assert len(generated_text) > 10, ( + f"Generated text is too short: {generated_text}" + ) + print("-" * 50) diff --git a/tests/models/multimodal/generation/test_keye.py b/tests/models/multimodal/generation/test_keye.py index 6f98bde1d91e..77987cdbaf42 100644 --- a/tests/models/multimodal/generation/test_keye.py +++ b/tests/models/multimodal/generation/test_keye.py @@ -1,35 +1,38 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from 
dataclasses import asdict -from typing import NamedTuple import pytest -from PIL.Image import Image from transformers import AutoProcessor from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview" QUESTION = "What is the content of each image?" -class ModelRequestData(NamedTuple): - engine_args: EngineArgs - prompt: str - image_data: list[Image] - stop_token_ids: list[int] | None = None - chat_template: str | None = None - sampling_params: SamplingParams | None = None - - -@pytest.mark.core_model @pytest.mark.parametrize("question", [QUESTION]) -def test_keye_vl( +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_keye_vl_vit_attn_backend_functionality( image_assets, question: str, + mm_encoder_attn_backend: _MHA_Backend | None, ): + if mm_encoder_attn_backend is not None and mm_encoder_attn_backend not in { + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.XFORMERS, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + }: + pytest.skip(f"Keye-VL does not support {mm_encoder_attn_backend} backend now.") + images = [asset.pil_image for asset in image_assets] image_urls = [ @@ -42,6 +45,7 @@ def test_keye_vl( max_model_len=8192, max_num_seqs=5, limit_mm_per_prompt={"image": len(image_urls)}, + mm_encoder_attn_backend=mm_encoder_attn_backend, ) placeholders = [{"type": "image", "image": url} for url in image_urls] diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py index 6fc2efa418dd..c5afa3df2e4b 100644 --- a/tests/models/multimodal/generation/test_maverick.py +++ b/tests/models/multimodal/generation/test_maverick.py @@ -21,6 +21,8 @@ from transformers import AutoConfig, AutoProcessor, AutoTokenizer, GenerationConfig from vllm 
import LLM, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.platforms import current_platform from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, FullAttentionSpec @@ -600,6 +602,10 @@ def run_reduced_model(llm: LLM, should_profile: bool = False) -> None: ) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("tp,ep", [(2, True)]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dummy_maverick( monkeypatch, @@ -610,6 +616,7 @@ def test_dummy_maverick( enforce_eager: bool, tp: int, ep: bool, + mm_encoder_attn_backend: _MHA_Backend | None, output_dir: str = "/tmp/reduced_maverick", force_recreate: bool = True, profile: bool = False, @@ -638,6 +645,7 @@ def test_dummy_maverick( enforce_eager=enforce_eager, tensor_parallel_size=tp, enable_expert_parallel=ep, + mm_encoder_attn_backend=mm_encoder_attn_backend, ) check_attention_spec_interleaved_rope( diff --git a/tests/models/multimodal/generation/test_ovis2_5.py b/tests/models/multimodal/generation/test_ovis2_5.py new file mode 100644 index 000000000000..bfdaa45f7904 --- /dev/null +++ b/tests/models/multimodal/generation/test_ovis2_5.py @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest + +from vllm import LLM, EngineArgs, SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform + +# This model uses the ViT Class from +# vllm/model_executor/models/siglip2navit.py +MODEL_NAME = "AIDC-AI/Ovis2.5-2B" + +QUESTION = "What is the content of each image?" 
+ + +@pytest.mark.parametrize("question", [QUESTION]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_ovis2_5_vit_attn_backend_functionality( + image_assets, + question: str, + mm_encoder_attn_backend: _MHA_Backend | None, +): + images = [asset.pil_image for asset in image_assets] + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images + ] + + engine_args = EngineArgs( + model=MODEL_NAME, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"image": len(image_urls)}, + mm_encoder_attn_backend=mm_encoder_attn_backend, + ) + + placeholders = "\n".join( + f"Image-{i}: \n" for i, _ in enumerate(image_urls, start=1) + ) + prompt = ( + f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + engine_args = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_args) + + sampling_params = SamplingParams( + temperature=0.0, max_tokens=256, stop_token_ids=None + ) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + assert len(generated_text) > 10, ( + f"Generated text is too short: {generated_text}" + ) + print("-" * 50) diff --git a/tests/models/multimodal/generation/test_qwen2_5_vl.py b/tests/models/multimodal/generation/test_qwen2_5_vl.py index 1a7d854352ae..b76c0b90b3c5 100644 --- a/tests/models/multimodal/generation/test_qwen2_5_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_5_vl.py @@ -3,7 +3,9 @@ import pytest +from vllm.attention.backends.registry import _MHA_Backend from vllm.multimodal.video import sample_frames_from_video +from vllm.platforms import current_platform from ....conftest import VIDEO_ASSETS @@ -34,6 +36,9 @@ def qwen2_5_vl_chat_template(*query): 
@pytest.mark.parametrize("num_frames", [16]) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize( + "encoder_attn_backend", [None] + current_platform.get_supported_vit_attn_backends() +) def test_qwen2_5_vl_evs_functionality( vllm_runner, video_assets, @@ -42,6 +47,7 @@ def test_qwen2_5_vl_evs_functionality( num_frames: int, dtype: str, max_tokens: int, + encoder_attn_backend: _MHA_Backend | None, ) -> None: """Test EVS (Efficient Video Sampling) functionality with different pruning rates. @@ -66,6 +72,7 @@ def test_qwen2_5_vl_evs_functionality( limit_mm_per_prompt={"video": 1}, tensor_parallel_size=1, video_pruning_rate=video_pruning_rate, + mm_encoder_attn_backend=encoder_attn_backend, ) as vllm_model: # Generate output - this should not crash outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) diff --git a/tests/models/multimodal/generation/test_qwen2_vl.py b/tests/models/multimodal/generation/test_qwen2_vl.py index e10b8e1e77af..e1d9ebf940a1 100644 --- a/tests/models/multimodal/generation/test_qwen2_vl.py +++ b/tests/models/multimodal/generation/test_qwen2_vl.py @@ -8,8 +8,10 @@ import torch from PIL import Image +from vllm.attention.backends.registry import _MHA_Backend from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import rescale_video_size, sample_frames_from_video +from vllm.platforms import current_platform from ....conftest import ( IMAGE_ASSETS, @@ -273,6 +275,7 @@ def run_embedding_input_test( mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: str | None = None, + mm_encoder_attn_backend: _MHA_Backend | None = None, ): """Inference result should be the same between original image/video input and image/video embeddings input. 
@@ -293,6 +296,7 @@ def run_embedding_input_test( distributed_executor_backend=distributed_executor_backend, default_torch_num_threads=1, enable_mm_embeds=True, + mm_encoder_attn_backend=mm_encoder_attn_backend, ) as vllm_model: outputs_per_case_for_original_input = [ vllm_model.generate_greedy_logprobs( @@ -377,6 +381,7 @@ def test_qwen2_vl_image_embeddings_input( num_logprobs=num_logprobs, mm_limit=1, tensor_parallel_size=1, + mm_encoder_attn_backend=_MHA_Backend.FLASH_ATTN, ) @@ -428,6 +433,7 @@ def test_qwen2_vl_multiple_image_embeddings_input( num_logprobs=num_logprobs, mm_limit=2, tensor_parallel_size=1, + mm_encoder_attn_backend=_MHA_Backend.FLASH_ATTN, ) @@ -447,6 +453,10 @@ def test_qwen2_vl_multiple_image_embeddings_input( @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) def test_qwen2_vl_video_embeddings_input( vllm_runner, video_assets, @@ -455,6 +465,7 @@ def test_qwen2_vl_video_embeddings_input( dtype: str, max_tokens: int, num_logprobs: int, + mm_encoder_attn_backend: _MHA_Backend | None, ) -> None: num_frames = 4 sampled_vids = [ @@ -480,4 +491,5 @@ def test_qwen2_vl_video_embeddings_input( num_logprobs=num_logprobs, mm_limit=1, tensor_parallel_size=1, + mm_encoder_attn_backend=mm_encoder_attn_backend, ) diff --git a/tests/models/multimodal/generation/test_qwen3_omni_moe_thinker.py b/tests/models/multimodal/generation/test_qwen3_omni_moe_thinker.py new file mode 100644 index 000000000000..d42536515092 --- /dev/null +++ b/tests/models/multimodal/generation/test_qwen3_omni_moe_thinker.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import asdict + +import pytest +from transformers import AutoProcessor + +from vllm import LLM, EngineArgs, 
SamplingParams +from vllm.attention.backends.registry import _MHA_Backend +from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform + +from ....utils import large_gpu_test + +MODEL_NAME = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + +QUESTION = "What is the content of each image?" + + +@large_gpu_test(min_gb=80) +@pytest.mark.parametrize("question", [QUESTION]) +@pytest.mark.parametrize( + "mm_encoder_attn_backend", + [None] + current_platform.get_supported_vit_attn_backends(), +) +def test_qwen3_omni_moe_thinker_vit_attn_backend_functionality( + image_assets, + question: str, + mm_encoder_attn_backend: _MHA_Backend | None, +): + images = [asset.pil_image for asset in image_assets] + + image_urls = [ + f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images + ] + + engine_args = EngineArgs( + model=MODEL_NAME, + trust_remote_code=True, + max_model_len=32768, + max_num_seqs=2, + limit_mm_per_prompt={"image": 3, "video": 3, "audio": 3}, + mm_encoder_attn_backend=mm_encoder_attn_backend, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + engine_args = asdict(engine_args) | {"seed": 42} + llm = LLM(**engine_args) + + sampling_params = SamplingParams( + temperature=0.6, + top_p=0.95, + top_k=20, + max_tokens=16384, + ) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": images}, + }, + sampling_params=sampling_params, + ) + + print("-" * 50) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + assert len(generated_text) > 10, ( + f"Generated text is too short: {generated_text}" + ) + print("-" * 50) diff --git 
a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 05d0159d0861..cceae7624560 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -30,6 +30,15 @@ class _Backend(enum.Enum): ROCM_AITER_UNIFIED_ATTN = enum.auto() +class _MHA_Backend(enum.Enum): + VLLM_FLASH_ATTN = enum.auto() # CUDA-only + FLASH_ATTN = enum.auto() # CUDA/ROCm/XPU + XFORMERS = enum.auto() # CUDA + ROCM_AITER_FA = enum.auto() # ROCM-only + TORCH_SDPA = enum.auto() # CUDA/ROCm/TPU/XPU/CPU + PALLAS = enum.auto() # TPU only + + BACKEND_MAP = { _Backend.FLASH_ATTN: "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", # noqa: E501 _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", # noqa: E501 @@ -108,3 +117,18 @@ def backend_name_to_enum(backend_name: str) -> _Backend | None: """ assert backend_name is not None return _Backend[backend_name] if backend_name in _Backend.__members__ else None + + +def mha_backend_name_to_enum(backend_name: str) -> _MHA_Backend | None: + """ + Convert a string backend name to a _MHA_Backend enum value. + + Returns: + _MHA_Backend: enum value if backend_name is a valid in-tree type + None: otherwise it's an invalid in-tree type or an out-of-tree platform + is loaded. 
+ """ + assert backend_name is not None + return ( + _MHA_Backend[backend_name] if backend_name in _MHA_Backend.__members__ else None + ) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..bebfa851444d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -12,7 +12,10 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.backends.abstract import AttentionBackend, MLAAttentionImpl -from vllm.attention.backends.registry import _Backend, backend_name_to_enum +from vllm.attention.backends.registry import ( + _MHA_Backend, + backend_name_to_enum, +) from vllm.attention.selector import get_attn_backend from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config @@ -99,50 +102,28 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, - use_upstream_fa: bool, - attn_backend_override: _Backend | None = None, -) -> tuple[_Backend, Callable | None]: - if current_platform.is_rocm(): - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - attn_backend = _Backend.ROCM_AITER_FA - - elif ( - check_upstream_fa_availability(torch.get_default_dtype()) - and on_gfx9() - and attn_backend_override is None - ): - attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True - else: - return _Backend.TORCH_SDPA, None - - elif current_platform.is_cuda(): - if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() - ): - attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True - elif current_platform.is_xpu(): - assert attn_backend == _Backend.FLASH_ATTN, ( - "XPU platform only supports FLASH_ATTN as vision attention backend." 
- ) - use_upstream_fa = False - else: - return _Backend.TORCH_SDPA, None - - if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: - if attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.attention.utils.fa_utils import flash_attn_varlen_func + attn_backend: _MHA_Backend | None, +) -> Callable | None: + # At this point, + # we already have the attn_backend, + # overriding logic is done in the platform-specific implementation. + # so we don't need to override backend here. + # Just return the attn_backend and flash_attn_varlen_func. + + if attn_backend == _MHA_Backend.FLASH_ATTN and current_platform.is_cuda_alike(): + from flash_attn import flash_attn_varlen_func + elif attn_backend == _MHA_Backend.FLASH_ATTN and current_platform.is_xpu(): + from vllm.attention.utils.fa_utils import flash_attn_varlen_func + elif attn_backend == _MHA_Backend.VLLM_FLASH_ATTN: + from vllm.vllm_flash_attn import flash_attn_varlen_func + elif attn_backend == _MHA_Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func else: flash_attn_varlen_func = None - return attn_backend, flash_attn_varlen_func + # if attn_backend is TORCH_SDPA, + # it will reach here and the flash_attn_varlen_func will be None. + return flash_attn_varlen_func def _init_kv_cache_quant( @@ -521,49 +502,37 @@ def __init__( attn_backend_override=attn_backend_override, ) - # Some auto-selected backends can be upgraded - # to upstream flash attention if available. - # If vllm native fa is selected, we use it directly. 
- use_upstream_fa = False - self.attn_backend = ( backend if backend in { - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.PALLAS, - _Backend.ROCM_AITER_FA, - _Backend.FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.PALLAS, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, } - else _Backend.TORCH_SDPA + else _MHA_Backend.TORCH_SDPA ) - self.attn_backend, self._flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - use_upstream_fa, - attn_backend_override=attn_backend_override, - ) + self._flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) - if self.attn_backend == _Backend.XFORMERS and not check_xformers_availability(): - self.attn_backend = _Backend.TORCH_SDPA + if ( + self.attn_backend == _MHA_Backend.XFORMERS + and not check_xformers_availability() + ): + self.attn_backend = _MHA_Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } - # this condition is just to make sure that the - # use_upstream_fa in the log is correct - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: - use_upstream_fa = True - - logger.info_once( - f"MultiHeadAttention attn_backend: {self.attn_backend}, " - f"use_upstream_fa: {use_upstream_fa}" - ) + logger.info_once(f"MultiHeadAttention attn_backend: {self.attn_backend}, ") def forward( self, @@ -606,17 +575,17 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops out = xops.memory_efficient_attention_forward( query, key, value, scale=self.scale ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: query, key, value = (x.transpose(1, 
2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.PALLAS: + elif self.attn_backend == _MHA_Backend.PALLAS: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) from torch_xla.experimental.custom_kernel import flash_attention diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index ef73720efe09..360ad862bbf6 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -11,9 +11,9 @@ from vllm.config.utils import config if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _MHA_Backend else: - _Backend = Any + _MHA_Backend = Any @dataclass @@ -125,10 +125,10 @@ class MultiModalConfig: DP (which is controlled by `--data-parallel-size`). This is only supported on a per-model basis and falls back to `"weights"` if the encoder does not support DP.""" - mm_encoder_attn_backend: _Backend | None = None + mm_encoder_attn_backend: _MHA_Backend | None = None """Optional override for the multi-modal encoder attention backend when using vision transformers. Accepts any value from - `vllm.attention.backends.registry._Backend` (e.g. `FLASH_ATTN`).""" + `vllm.attention.backends.registry._MHA_Backend` (e.g. 
`FLASH_ATTN`).""" interleave_mm_strings: bool = False """Enable fully interleaved support for multimodal prompts, while using --chat-template-content-format=string.""" @@ -167,19 +167,19 @@ def _validate_limit_per_prompt( @field_validator("mm_encoder_attn_backend", mode="before") @classmethod - def _validate_mm_encoder_attn_backend(cls, value: object) -> _Backend | None: + def _validate_mm_encoder_attn_backend(cls, value: object) -> _MHA_Backend | None: from vllm.attention.backends.registry import ( - _Backend as BackendEnum, + _MHA_Backend as BackendEnum, ) from vllm.attention.backends.registry import ( - backend_name_to_enum, + mha_backend_name_to_enum, ) if value is None or isinstance(value, BackendEnum): return value if isinstance(value, str): - candidate = backend_name_to_enum(value.upper()) + candidate = mha_backend_name_to_enum(value.upper()) if candidate is not None: return candidate diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 6d462ad8ae62..f90baf888875 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -9,7 +9,7 @@ from torch.nn import LayerNorm from transformers.models.qwen2_vl import Qwen2VLProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -256,7 +256,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() @@ -293,27 +293,24 @@ def __init__( torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - self.use_upstream_fa, - 
attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Unsupported vision attention backend: {self.attn_backend}" ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } def forward( @@ -361,7 +358,7 @@ def forward( self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: outputs = [] for i in range(1, len(cu_seqlens)): s = int(cu_seqlens[i - 1]) @@ -373,7 +370,7 @@ def forward( out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -514,7 +511,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() @@ -567,7 +564,7 @@ def __init__( require_post_norm: bool | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() self.config = config @@ -582,10 +579,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if 
self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != _MHA_Backend.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = _MHA_Backend.FLASH_ATTN self.out_hidden_size = config.hidden_size # Keep blocks for compatibility with other vision towers num_layers = ( @@ -666,11 +664,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == _MHA_Backend.FLASH_ATTN + or self.attn_backend == _MHA_Backend.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 86536b21c33f..0821e217afd2 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -36,7 +36,7 @@ from einops import rearrange, repeat from transformers import BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -164,7 +164,7 @@ def __init__( projection_size: int, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
@@ -200,28 +200,24 @@ def __init__( attn_backend_override=attn_backend_override, ) - self.use_upstream_fa = False - - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - self.use_upstream_fa, - attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -291,7 +287,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
outputs = [] for i in range(1, len(cu_seqlens)): @@ -310,7 +306,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -370,7 +366,7 @@ def __init__( norm_layer: Callable[[int], nn.Module] | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() @@ -463,7 +459,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() patch_size = vision_config.patch_size @@ -515,10 +511,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != _MHA_Backend.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = _Backend.FLASH_ATTN + self.attn_backend = _MHA_Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -565,11 +562,11 @@ def compute_attn_mask_seqlen( ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == _MHA_Backend.FLASH_ATTN + or self.attn_backend == _MHA_Backend.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git 
a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 121e84469c52..a19eb27c8e8f 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -46,7 +46,7 @@ from transformers.models.glm4v.video_processing_glm4v import Glm4vVideoProcessor from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import ( check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, @@ -252,7 +252,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. @@ -295,29 +295,26 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - self.use_upstream_fa, - attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -370,14 +367,14 @@ def forward( cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, - dropout_p=0.0, + dropout_p=0.0, # AITER FA only supports float dtype causal=False, ) context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. outputs = [] for i in range(1, len(cu_seqlens)): @@ -396,7 +393,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -425,7 +422,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -703,7 +700,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() @@ -772,10 +769,11 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() + if ( + self.attn_backend != _MHA_Backend.FLASH_ATTN + and check_upstream_fa_availability(torch.get_default_dtype()) ): - self.attn_backend = 
_Backend.FLASH_ATTN + self.attn_backend = _MHA_Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -824,8 +822,8 @@ def compute_attn_mask_seqlen( max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == _MHA_Backend.FLASH_ATTN + or self.attn_backend == _MHA_Backend.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 5f8659a3064e..184266e54268 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -16,7 +16,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from transformers.utils import torch_int -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import ( maybe_get_vit_flash_attn_backend, ) @@ -360,7 +360,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -405,26 +405,24 @@ def __init__( attn_backend_override=attn_backend_override, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - use_upstream_fa=False, - attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.XFORMERS, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Keye-VL does not support {self.attn_backend} backend now." 
) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, } def forward( @@ -489,7 +487,7 @@ def forward( softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -536,7 +534,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -590,7 +588,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -685,7 +683,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -768,7 +766,7 @@ def __init__( config: PretrainedConfig, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py index f6461ae9a412..338f6bae4cca 100644 --- a/vllm/model_executor/models/ovis2_5.py +++ b/vllm/model_executor/models/ovis2_5.py @@ -10,7 +10,7 @@ import torch.nn as nn from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig -from vllm.attention.backends.registry import _Backend 
+from vllm.attention.backends.registry import _MHA_Backend from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions from vllm.model_executor.layers.linear import ReplicatedLinear @@ -106,7 +106,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -135,7 +135,7 @@ def _init_backbone( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): model_type = config.model_type if model_type == "siglip2_navit": diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index d337f1606943..f574be39435c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -42,8 +42,7 @@ Qwen2_5_VLVisionConfig, ) -from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import maybe_get_vit_flash_attn_backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, vit_torch_sdpa_wrapper, @@ -315,9 +314,8 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, - use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend: _MHA_Backend = _MHA_Backend.TORCH_SDPA, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
@@ -353,24 +351,11 @@ def __init__( disable_tp=use_data_parallel, ) self.attn_backend = attn_backend - self.use_upstream_fa = use_upstream_fa - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - self.use_upstream_fa, - attn_backend_override=attn_backend_override, - ) - ) - # On ROCm with FLASH_ATTN backend, upstream flash_attn is used - from vllm.platforms import current_platform - if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN: - self.use_upstream_fa = True - if current_platform.is_xpu(): - self.use_upstream_fa = False self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -431,10 +416,10 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, - self.use_upstream_fa, + self.attn_backend == _MHA_Backend.ROCM_AITER_FA, + self.attn_backend == _MHA_Backend.FLASH_ATTN, ) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
from vllm.platforms import current_platform @@ -450,7 +435,7 @@ def forward( v, cu_seqlens, ) - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) @@ -478,9 +463,8 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, - use_upstream_fa: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend: _MHA_Backend = _MHA_Backend.TORCH_SDPA, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -495,7 +479,6 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa, attn_backend_override=attn_backend_override, ) self.mlp = Qwen2_5_VisionMLP( @@ -656,7 +639,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() @@ -692,26 +675,18 @@ def __init__( head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) - use_upstream_fa = False self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - use_upstream_fa, - attn_backend_override=attn_backend_override, - ) - ) - if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.ROCM_AITER_FA, }: raise 
RuntimeError( f"Qwen2.5-VL does not support {self.attn_backend} backend now." @@ -730,7 +705,6 @@ def __init__( prefix=f"{prefix}.blocks.{layer_idx}", use_data_parallel=use_data_parallel, attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa, attn_backend_override=attn_backend_override, ) for layer_idx in range(depth) @@ -850,9 +824,9 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in {_MHA_Backend.FLASH_ATTN, _MHA_Backend.ROCM_AITER_FA}: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9206ac8f9d03..e15be89586fa 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -43,9 +43,8 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import ( - check_upstream_fa_availability, maybe_get_vit_flash_attn_backend, ) from vllm.config import VllmConfig @@ -329,7 +328,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() # Per attention head and per partition values. 
@@ -367,29 +366,22 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - self.attn_backend, - self.use_upstream_fa, - attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) - if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, - }: + from vllm.platforms import current_platform + + if self.attn_backend not in current_platform.get_supported_vit_attn_backends(): raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now." ) self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, } def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: @@ -460,7 +452,7 @@ def forward( context_layer = rearrange( output, "(b s) h d -> s b (h d)", b=batch_size ).contiguous() - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
from vllm.platforms import current_platform @@ -485,7 +477,7 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask @@ -515,7 +507,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() if norm_layer is None: @@ -679,7 +671,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() @@ -739,10 +731,6 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() - ): - self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -789,9 +777,9 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[int | None, list[int] | None]: max_seqlen, seqlens = None, None - if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: + if self.attn_backend in {_MHA_Backend.FLASH_ATTN, _MHA_Backend.ROCM_AITER_FA}: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index f20e67902721..d5c47e4b5a24 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ 
b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -47,8 +47,7 @@ ) from transformers.models.whisper import WhisperFeatureExtractor -from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.backends.registry import _MHA_Backend from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.distributed import get_pp_group @@ -301,7 +300,7 @@ def __init__( norm_eps: float = 1e-6, quant_config: QuantizationConfig | None = None, prefix: str = "", - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -377,10 +376,6 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - torch.get_default_dtype() - ): - self.attn_backend = _Backend.FLASH_ATTN @property def dtype(self) -> torch.dtype: @@ -490,9 +485,9 @@ def compute_attn_mask_seqlen( ) -> tuple[torch.Tensor, torch.Tensor]: max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.attn_backend == _MHA_Backend.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 2d8f431bb8fa..b8d6cbe65912 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -49,8 +49,7 @@ ) from transformers.video_utils import VideoMetadata -from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import 
check_upstream_fa_availability +from vllm.attention.backends.registry import _MHA_Backend from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions @@ -198,8 +197,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend: _Backend = _Backend.TORCH_SDPA, - use_upstream_fa: bool = False, + attn_backend: _MHA_Backend = _MHA_Backend.TORCH_SDPA, ) -> None: super().__init__() if norm_layer is None: @@ -214,7 +212,6 @@ def __init__( prefix=f"{prefix}.attn", use_data_parallel=use_data_parallel, attn_backend=attn_backend, - use_upstream_fa=use_upstream_fa, ) self.mlp = Qwen3_VisionMLP( dim, @@ -306,7 +303,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ) -> None: super().__init__() self.hidden_size = vision_config.hidden_size @@ -370,20 +367,13 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - use_upstream_fa = False - if ( - self.attn_backend != _Backend.FLASH_ATTN - and self.attn_backend != _Backend.ROCM_AITER_FA - and check_upstream_fa_availability(torch.get_default_dtype()) - ): - self.attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.XFORMERS, - _Backend.ROCM_AITER_FA, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Qwen3-VL does not support {self.attn_backend} backend now." 
@@ -400,7 +390,6 @@ def __init__( prefix=f"{prefix}.blocks.{layer_idx}", use_data_parallel=use_data_parallel, attn_backend=self.attn_backend, - use_upstream_fa=use_upstream_fa, ) for layer_idx in range(vision_config.depth) ] @@ -510,11 +499,11 @@ def compute_attn_mask_seqlen( max_seqlen = torch.zeros([], device=cu_seqlens.device) seqlens = torch.zeros(1, device=cu_seqlens.device) if ( - self.attn_backend == _Backend.FLASH_ATTN - or self.attn_backend == _Backend.ROCM_AITER_FA + self.attn_backend == _MHA_Backend.FLASH_ATTN + or self.attn_backend == _MHA_Backend.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == _Backend.XFORMERS: + elif self.attn_backend == _MHA_Backend.XFORMERS: seqlens = cu_seqlens[1:] - cu_seqlens[:-1] return max_seqlen, seqlens diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index bab5c1d82ded..fcf2c71d8c78 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -12,7 +12,7 @@ from transformers import Siglip2VisionConfig from transformers.configuration_utils import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn @@ -208,7 +208,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -253,25 +253,22 @@ def __init__( dtype=torch.get_default_dtype(), attn_backend_override=attn_backend_override, ) - self.use_upstream_fa = False - self.attn_backend, self.flash_attn_varlen_func = ( - maybe_get_vit_flash_attn_backend( - 
self.attn_backend, - self.use_upstream_fa, - attn_backend_override=attn_backend_override, - ) + self.flash_attn_varlen_func = maybe_get_vit_flash_attn_backend( + self.attn_backend, ) if self.attn_backend not in { - _Backend.FLASH_ATTN, - _Backend.TORCH_SDPA, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.ROCM_AITER_FA, }: - self.attn_backend = _Backend.TORCH_SDPA + self.attn_backend = _MHA_Backend.TORCH_SDPA self.is_flash_attn_backend = self.attn_backend in { - _Backend.FLASH_ATTN, - _Backend.ROCM_AITER_FA, + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, } def forward( @@ -308,8 +305,14 @@ def forward( attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen ).reshape(seq_length, -1) - elif self.attn_backend == _Backend.TORCH_SDPA: + elif self.attn_backend == _MHA_Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. 
+ from vllm.platforms import current_platform + + if current_platform.is_rocm(): + queries = queries.contiguous() + keys = keys.contiguous() + values = values.contiguous() batch_size = cu_seqlens.shape[0] - 1 outputs = [] cu = cu_seqlens.tolist() @@ -376,7 +379,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -440,7 +443,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -626,7 +629,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() self.config = config @@ -667,7 +670,7 @@ def __init__( quant_config: QuantizationConfig | None = None, prefix: str = "", use_data_parallel: bool = False, - attn_backend_override: _Backend | None = None, + attn_backend_override: _MHA_Backend | None = None, ): super().__init__() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 9f94387c700d..687fd883014a 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -10,7 +10,7 @@ import torch from transformers import PretrainedConfig -from vllm.attention.backends.registry import _Backend +from vllm.attention.backends.registry import _MHA_Backend from vllm.config import VllmConfig from vllm.distributed import ( get_tensor_model_parallel_rank, @@ -83,22 +83,15 @@ def get_vit_attn_backend( head_size: int, dtype: torch.dtype, *, - attn_backend_override: _Backend | None = None, -) -> 
_Backend: + attn_backend_override: _MHA_Backend | None = None, +) -> _MHA_Backend: """ Get the available attention backend for Vision Transformer. """ - if attn_backend_override is not None: - return attn_backend_override - # Lazy import to avoid circular dependency - from vllm.attention.selector import get_env_variable_attn_backend - - selected_backend: _Backend | None = get_env_variable_attn_backend() - if selected_backend is not None: - return selected_backend - - return current_platform.get_vit_attn_backend(head_size, dtype) + return current_platform.get_vit_attn_backend( + head_size, dtype, backend=attn_backend_override + ) def should_torch_compile_mm_vit(vllm_config: VllmConfig) -> bool: diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 32734c3aba5e..129a720bb901 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -22,10 +22,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, _MHA_Backend from vllm.config import VllmConfig else: _Backend = None + _MHA_Backend = None logger = init_logger(__name__) @@ -216,16 +217,51 @@ def get_current_memory_usage( return torch.cuda.max_memory_allocated(device) @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from vllm.attention.backends.registry import _Backend + def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]: + from vllm.attention.backends.registry import _MHA_Backend + + return [ + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.XFORMERS, + _MHA_Backend.VLLM_FLASH_ATTN, + _MHA_Backend.FLASH_ATTN, + ] + + @classmethod + def get_vit_attn_backend( + cls, head_size: int, dtype: torch.dtype, backend: "_MHA_Backend" | None = None + ) -> "_MHA_Backend": + # ViT Attention should be checked and override + # in the platform-specific implementation. 
+ # we should not override this in any other places, + # like the model_executor/models/.py + + # So the steps are: + # 1. Check if the backend is None or not: + # a. If not, check if the backend is supported by the platform. + # b. If None, continue to the default selection logic. + + from vllm.attention.backends.registry import _MHA_Backend + + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend # For Blackwell GPUs, force TORCH_SDPA for now. # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 if cls.has_device_capability(100): - return _Backend.TORCH_SDPA + logger.info_once( + f"Using backend {_MHA_Backend.TORCH_SDPA} for vit attention" + ) + return _MHA_Backend.TORCH_SDPA if dtype not in (torch.float16, torch.bfloat16): - return _Backend.XFORMERS + logger.info_once(f"Using backend {_MHA_Backend.XFORMERS} for vit attention") + return _MHA_Backend.XFORMERS if cls.has_device_capability(80): FLASH_ATTN_V1 = ( @@ -236,14 +272,29 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": is_default_fa_supported = is_attn_backend_supported( FLASH_ATTN_V1, head_size, dtype, allow_import_error=False ) - if is_default_fa_supported: - return _Backend.FLASH_ATTN + # lazy import to avoid circular import + from vllm.attention.layer import check_upstream_fa_availability + + if is_default_fa_supported and check_upstream_fa_availability(dtype=dtype): + logger.info_once( + f"Using backend {_MHA_Backend.FLASH_ATTN} for vit attention" + ) + return _MHA_Backend.FLASH_ATTN + elif is_default_fa_supported: + logger.info_once( + f"Using backend {_MHA_Backend.VLLM_FLASH_ATTN} for vit attention" + ) + return _MHA_Backend.VLLM_FLASH_ATTN else: # Fallback to XFORMERS - 
return _Backend.XFORMERS + logger.info_once( + f"Using backend {_MHA_Backend.XFORMERS} for vit attention" + ) + return _MHA_Backend.XFORMERS else: # Fallback for Volta/Turing GPUs or FA not supported - return _Backend.XFORMERS + logger.info_once(f"Using backend {_MHA_Backend.XFORMERS} for vit attention") + return _MHA_Backend.XFORMERS @classmethod def get_attn_backend_cls( diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 15e3b3a22bde..503d5678cec7 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -7,7 +7,7 @@ import random import sys from datetime import timedelta -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple, Optional import numpy as np import torch @@ -17,7 +17,7 @@ if TYPE_CHECKING: from torch.distributed import PrefixStore, ProcessGroup - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, _MHA_Backend from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType from vllm.pooling_params import PoolingParams @@ -173,11 +173,45 @@ def import_kernels(cls) -> None: import vllm._moe_C # noqa: F401 @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]: + from vllm.attention.backends.registry import _MHA_Backend + + return [ + _MHA_Backend.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["_MHA_Backend"] = None, + ) -> "_MHA_Backend": + # ViT Attention should be checked and override + # in the platform-specific implementation. + # we should not override this in any other places, + # like the model_executor/models/.py + + # So the steps are: + # 1. Check if the backend is None or not: + # a. If not, check if the backend is supported by the platform. + # b. 
If None, continue to the default selection logic. + # Import _Backend here to avoid circular import. - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _MHA_Backend + + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend - return _Backend.TORCH_SDPA + logger.info_once( + f"Using default backend {_MHA_Backend.TORCH_SDPA} for vit attention" + ) + return _MHA_Backend.TORCH_SDPA @classmethod def get_attn_backend_cls( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 9745e4b08cf0..039499162776 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -3,7 +3,7 @@ import os from functools import cache, lru_cache, wraps -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -14,10 +14,11 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, _MHA_Backend from vllm.config import VllmConfig else: _Backend = None + _MHA_Backend = None logger = init_logger(__name__) @@ -201,18 +202,59 @@ class RocmPlatform(Platform): ] @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": - from importlib.util import find_spec - from vllm.attention.backends.registry import _Backend + def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]: + from vllm.attention.backends.registry import _MHA_Backend + return [ + _MHA_Backend.FLASH_ATTN, + _MHA_Backend.ROCM_AITER_FA, + _MHA_Backend.TORCH_SDPA, + ] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + dtype: torch.dtype, + backend: Optional["_MHA_Backend"] = None, + ) -> 
"_MHA_Backend": + # ViT Attention should be checked and override + # in the platform-specific implementation. + # we should not override this in any other places, + # like the model_executor/models/.py + + # So the steps are: + # 1. Check if the backend is None or not: + # a. If not, check if the backend is supported by the platform. + # b. If None, continue to the default selection logic. + + from vllm.attention.backends.registry import _MHA_Backend + + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - return _Backend.ROCM_AITER_FA + logger.info_once( + f"Using backend {_MHA_Backend.ROCM_AITER_FA} for vit attention" + ) + return _MHA_Backend.ROCM_AITER_FA - if on_gfx9() and find_spec("flash_attn") is not None: - return _Backend.FLASH_ATTN + # lazy import to avoid circular import + from vllm.attention.layer import check_upstream_fa_availability + + if on_gfx9() and check_upstream_fa_availability(dtype=dtype): + logger.info_once( + f"Using backend {_MHA_Backend.FLASH_ATTN} for vit attention" + ) + return _MHA_Backend.FLASH_ATTN - return _Backend.TORCH_SDPA + logger.info_once(f"Using backend {_MHA_Backend.TORCH_SDPA} for vit attention") + return _MHA_Backend.TORCH_SDPA @classmethod def get_attn_backend_cls( diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 1a4b67a1762f..a8eac3761235 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -15,7 +15,7 @@ from .interface import Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, _MHA_Backend from vllm.config import ModelConfig, VllmConfig from vllm.config.cache import 
BlockSize from vllm.pooling_params import PoolingParams @@ -25,6 +25,7 @@ VllmConfig = None PoolingParams = None _Backend = None + _MHA_Backend = None logger = init_logger(__name__) @@ -112,6 +113,15 @@ def get_lora_vocab_padding_size(cls) -> int: def inference_mode(cls): return torch.no_grad() + @classmethod + def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]: + from vllm.attention.backends.registry import _MHA_Backend + + return [ + _MHA_Backend.TORCH_SDPA, + _MHA_Backend.PALLAS, # currently only is used in LLama4 VL model + ] + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.config import CompilationMode, CUDAGraphMode diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index e4ecd0c807da..d4ed9dbbacc7 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -3,7 +3,7 @@ import contextlib import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -14,12 +14,13 @@ from .interface import DeviceCapability, Platform, PlatformEnum if TYPE_CHECKING: - from vllm.attention.backends.registry import _Backend + from vllm.attention.backends.registry import _Backend, _MHA_Backend from vllm.config import ModelConfig, VllmConfig else: ModelConfig = None VllmConfig = None _Backend = None + _MHA_Backend = None logger = init_logger(__name__) @@ -113,10 +114,42 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: return device_props.total_memory @classmethod - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: - from vllm.attention.backends.registry import _Backend + def get_supported_vit_attn_backends(cls) -> list["_MHA_Backend"]: + from vllm.attention.backends.registry import _MHA_Backend + + # as mentioned in this PR: https://github.com/vllm-project/vllm/pull/27525 + # XPU only supports FLASH_ATTN for vit attention + return [_MHA_Backend.FLASH_ATTN] + + @classmethod + def get_vit_attn_backend( + cls, + head_size: int, + 
dtype: torch.dtype, + backend: Optional["_MHA_Backend"] = None, + ) -> "_MHA_Backend": + # ViT Attention should be checked and override + # in the platform-specific implementation. + # we should not override this in any other places, + # like the model_executor/models/.py + + # So the steps are: + # 1. Check if the backend is None or not: + # a. If not, check if the backend is supported by the platform. + # b. If None, continue to the default selection logic. + + from vllm.attention.backends.registry import _MHA_Backend + + if backend is not None: + assert backend in cls.get_supported_vit_attn_backends(), ( + f"Backend {backend} is not supported for vit attention. " + f"Supported backends are: {cls.get_supported_vit_attn_backends()}" + ) + logger.info_once(f"Using backend {backend} for vit attention") + return backend - return _Backend.FLASH_ATTN + logger.info_once(f"Using backend {_MHA_Backend.FLASH_ATTN} for vit attention") + return _MHA_Backend.FLASH_ATTN @classmethod def inference_mode(cls):