diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index b88b9c499220..d663c82c3885 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -76,34 +76,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/common.txt -# must put before installing xformers, so it can install the correct version of xfomrers. -ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0' -ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} - -# Build xformers with cuda and torch nightly -# following official xformers guidance: https://github.com/facebookresearch/xformers#build -# todo(elainewy): cache xformers build result for faster build -ARG max_jobs=16 -ENV MAX_JOBS=${max_jobs} -ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c - -ENV CCACHE_DIR=/root/.cache/ccache -RUN --mount=type=cache,target=/root/.cache/ccache \ - --mount=type=cache,target=/root/.cache/uv \ - echo 'git clone xformers...' \ - && git clone https://github.com/facebookresearch/xformers.git --recursive \ - && cd xformers \ - && git checkout ${XFORMERS_COMMIT} \ - && git submodule update --init --recursive \ - && echo 'finish git clone xformers...' \ - && rm -rf build \ - && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \ - && cd .. \ - && rm -rf xformers - -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system xformers-dist/*.whl --verbose - # build can take a long time, and the torch nightly version fetched from url can be different in next docker stage. # track the nightly torch version used in the build, when we set up runtime environment we can make sure the version is the same RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt @@ -233,11 +205,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/vllm --mount=type=cache,target=/root/.cache/uv \ uv pip install --system vllm-dist/*.whl --verbose -# install xformers again for the new environment -RUN --mount=type=bind,from=base,src=/workspace/xformers-dist,target=/vllm-workspace/xformers-dist \ - --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system /vllm-workspace/xformers-dist/*.whl --verbose - ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0' # install package for build flashinfer @@ -307,7 +274,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/nightly_torch_test.txt # Logging to confirm the torch versions -RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer' +RUN pip freeze | grep -E 'torch|vllm|flashinfer' # Logging to confirm all the packages are installed RUN pip freeze diff --git a/docs/contributing/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md index 09fd85a466ee..735bb2e20533 100644 --- a/docs/contributing/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -98,21 +98,6 @@ to warm it up so that future builds are faster. Buildkite new build popup

-## Update dependencies - -Several vLLM dependencies like xFormers depend on PyTorch and need -to be updated accordingly. Rather than waiting for all of them to publish new -releases (which would take too much time), they can be built from -source to unblock the update process. - -### xFormers - -```bash -export TORCH_CUDA_ARCH_LIST='7.5 8.0+PTX 9.0a' -MAX_JOBS=16 uv pip install --system \ - --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.32.post2" -``` - ## Update all the different vLLM platforms Rather than attempting to update all vLLM platforms in a single pull request, it's more manageable diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 9e86f785b10c..94920dc5306b 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -283,7 +283,7 @@ Currently, vLLM supports multiple backends for efficient Attention computation a If desired, you can also manually set the backend of your choice by configuring the environment variable `VLLM_ATTENTION_BACKEND` to one of the following options: -- On NVIDIA CUDA: `FLASH_ATTN`, `FLASHINFER` or `XFORMERS`. +- On NVIDIA CUDA: `FLASH_ATTN` or `FLASHINFER`. - On AMD ROCm: `TRITON_ATTN`, `ROCM_ATTN`, `ROCM_AITER_FA` or `ROCM_AITER_UNIFIED_ATTN`. For AMD ROCm, you can further control the specific Attention implementation using the following variables: diff --git a/examples/online_serving/openai_embedding_long_text/service.sh b/examples/online_serving/openai_embedding_long_text/service.sh index 1577de85f7ff..b5c92749466b 100644 --- a/examples/online_serving/openai_embedding_long_text/service.sh +++ b/examples/online_serving/openai_embedding_long_text/service.sh @@ -22,7 +22,6 @@ API_KEY=${API_KEY:-"your-api-key"} POOLING_TYPE=${POOLING_TYPE:-"auto"} # auto, MEAN, CLS, LAST export VLLM_ENABLE_CHUNKED_PROCESSING=true export CUDA_VISIBLE_DEVICES=2,3,4,5 -# export VLLM_ATTENTION_BACKEND=XFORMERS echo "🚀 Starting vLLM Embedding Server with Enhanced Chunked Processing" echo "==================================================================" diff --git a/requirements/cuda.txt b/requirements/cuda.txt index d63fe9e1e77c..15e8aadc56f4 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -9,6 +9,5 @@ torch==2.9.0 torchaudio==2.9.0 # These must be updated alongside torch torchvision==0.24.0 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers==0.0.33.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.9 # FlashInfer should be updated together with the Dockerfile flashinfer-python==0.5.2 diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 0cf1e85d4e8e..521d6c33dd39 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -74,9 +74,6 @@ def test_models( model_executor: str, enable_prompt_embeds: bool, ) -> None: - if backend == "XFORMERS" and model == "google/gemma-2-2b-it": - pytest.skip(f"{backend} does not support gemma2 with full context length.") - with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", backend) diff --git a/tests/kernels/attention/test_attention.py b/tests/kernels/attention/test_attention.py index 9662e73321eb..1a7d5ce0ddc1 100644 --- a/tests/kernels/attention/test_attention.py +++ b/tests/kernels/attention/test_attention.py @@ -13,12 +13,6 @@ from vllm.platforms import current_platform from vllm.utils.mem_utils import get_max_shared_memory_bytes -if not current_platform.is_rocm(): - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask - - from tests.kernels.utils import make_alibi_bias - FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer @@ -448,129 +442,6 @@ def ref_multi_query_kv_attention( return torch.cat(ref_outputs, dim=0) -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif( - current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." -) -@torch.inference_mode() -def test_multi_query_kv_attention( - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - seed: int, - device: str, - use_alibi: bool = False, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. - max_len = min(MAX_SEQ_LEN, 4096) - seq_lens = random.sample(range(1, max_len), num_seqs) - num_tokens = sum(seq_lens) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - qkv = torch.empty( - num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype - ) - qkv.uniform_(-scale, scale) - query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1) - - num_queries_per_kv = num_query_heads // num_kv_heads - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - alibi_bias = None - if use_alibi: - alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens) - output = torch.empty_like(query) - start = 0 - # Dynamic sequence length not supported with custom attn_bias. 
- for i, seq_len in enumerate(seq_lens): - end = start + seq_len - out = xops.memory_efficient_attention_forward( - query[None, start:end], - key[None, start:end], - value[None, start:end], - attn_bias=attn_bias[i], - p=0.0, - scale=scale, - ) - output[start:end].copy_(out.view_as(query[start:end])) - start += seq_len - # xformers.AttentionBias to Tensor for use in reference impl. - alibi_bias = [ - b.materialize((1, num_query_heads, i, i), device=device).squeeze() - for b, i in zip(attn_bias, seq_lens) - ] - else: - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) - - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - ref_output = ref_multi_query_kv_attention( - cu_seq_lens, - query, - key, - value, - scale, - alibi_bias, - dtype, - ) - atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 - rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", [64]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.skipif( - current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm." -) -@torch.inference_mode() -def test_multi_query_kv_attention_with_alibi( - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - dtype: torch.dtype, - seed: int, - device: str, -) -> None: - return test_multi_query_kv_attention( - num_seqs, - num_heads, - head_size, - dtype, - seed, - device, - use_alibi=True, - ) - - @pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: head_size = 64 diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 9be56a33f76c..cd34b520ea71 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -34,7 +34,7 @@ def clear_cache(): } DEVICE_REGULAR_ATTN_BACKENDS = { - "cuda": ["XFORMERS", "FLASHINFER", "FLASH_ATTN"], + "cuda": ["FLASHINFER", "FLASH_ATTN"], "hip": ["ROCM_ATTN"], "cpu": ["CPU_ATTN"], } @@ -207,12 +207,6 @@ def test_env( ) expected = "FLASHINFER" assert backend.get_name() == expected - elif name == "XFORMERS": - backend = get_attn_backend( - 32, torch.float16, None, block_size, use_mla=use_mla - ) - expected = "XFORMERS" - assert backend.get_name() == expected elif name == "FLASH_ATTN": backend = get_attn_backend( 32, torch.float16, None, block_size, use_mla=use_mla diff --git a/tests/kernels/attention/test_mha_attn.py b/tests/kernels/attention/test_mha_attn.py index a878ac6396ce..ae3c63cc62d6 100644 --- a/tests/kernels/attention/test_mha_attn.py +++ b/tests/kernels/attention/test_mha_attn.py @@ -24,10 +24,6 @@ def clear_cache(): """Clear lru cache to ensure each test case runs without caching.""" _cached_get_attn_backend.cache_clear() - # Clear xformers availability cache - import vllm.attention.layer as layer_module - - layer_module.USE_XFORMERS_OPS = None @pytest.mark.parametrize("device", ["cpu", "hip", "cuda"]) diff --git 
a/tests/kernels/utils.py b/tests/kernels/utils.py index 5d5a26fbfc2c..9307ef7814a8 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs: ) -def make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_lens: list[int], -) -> list[Any]: - """Create ALiBi biases compatible with xFormers attention tests.""" - from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias - - if alibi_slopes is None: - return [None for _ in seq_lens] - - attn_biases: list[Any] = [] - num_heads = alibi_slopes.shape[0] - assert num_heads >= num_kv_heads, ( - "ALiBi slopes expect at least as many heads as KV heads" - ) - - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - bias_tensor = torch.empty( - 1, - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias_tensor.mul_(alibi_slopes[:, None, None]) - attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor)) - - return attn_biases - - def _make_metadata_tensors( seq_lens: list[int] | None, context_lens: list[int] | None, @@ -649,23 +612,12 @@ def make_kv_cache( Returns: - * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size) - * for backend 'XFORMERS' * kv_cache: 2 x num_blocks x block_size x num_heads x head_size * for backend 'FLASH_ATTN' """ - if backend == "XFORMERS": - kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to( - device - ) - elif backend == "FLASH_ATTN": - kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to( - device - ) - else: - raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." - ) + if backend != "FLASH_ATTN": + raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.") + kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device) if default_val is not None: kv_cache[:, :, :] = default_val return kv_cache @@ -843,22 +795,14 @@ def assert_actual_matches_ideal( * output_under_test: actually observed output value """ ideal_output = test_params.packed_qkvo.ideal_output - if backend == "XFORMERS": - torch.testing.assert_close( - ideal_output, output_under_test.view_as(ideal_output) - ) - - elif backend == "FLASH_ATTN": - # For FlashAttention override the accuracy thresholds to non default - # values since we notice a higher difference between the ideal and - # actual output. - torch.testing.assert_close( - ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016 - ) - else: - raise ValueError( - f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'." - ) + if backend != "FLASH_ATTN": + raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.") + # For FlashAttention override the accuracy thresholds to non default + # values since we notice a higher difference between the ideal and + # actual output. 
+ torch.testing.assert_close( + ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016 + ) # Copied/modified from torch._refs.__init__.py diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 1cf8ed602b6a..e430826461a1 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -57,10 +57,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: return generated_texts -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm", -) def test_minicpmv_lora(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, @@ -84,10 +80,6 @@ def test_minicpmv_lora(minicpmv_lora_files): @pytest.mark.skipif( current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" ) -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm", -) @multi_gpu_test(num_gpus=4) def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( @@ -108,10 +100,6 @@ def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): @pytest.mark.skipif( current_platform.is_cuda_alike(), reason="Skipping to avoid redundant model tests" ) -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm", -) @multi_gpu_test(num_gpus=4) def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( diff --git a/tests/lora/test_qwen2vl.py b/tests/lora/test_qwen2vl.py index 1800ca107a42..7d8c940100ca 100644 --- a/tests/lora/test_qwen2vl.py +++ b/tests/lora/test_qwen2vl.py @@ -2,12 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -import pytest - import vllm from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest -from vllm.platforms import current_platform from vllm.sampling_params import BeamSearchParams @@ -142,10 +139,6 @@ def run_beam_search_test( QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct" -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm", -) def test_qwen2vl_lora(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA""" config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) @@ -156,10 +149,6 @@ def test_qwen2vl_lora(qwen2vl_lora_files): tester.run_test(TEST_IMAGES, expected_outputs=EXPECTED_OUTPUTS, lora_id=lora_id) -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="Qwen2-VL dependency xformers incompatible with ROCm", -) def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): """Test Qwen 2.0 VL model with LoRA through beam search.""" config = TestConfig(model_path=QWEN2VL_MODEL_PATH, lora_path=qwen2vl_lora_files) @@ -178,10 +167,6 @@ def test_qwen2vl_lora_beam_search(qwen2vl_lora_files): ) -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="Qwen2.5-VL dependency xformers incompatible with ROCm", -) def test_qwen25vl_lora(qwen25vl_lora_files): """Test Qwen 2.5 VL model with LoRA""" config = TestConfig(model_path=QWEN25VL_MODEL_PATH, lora_path=qwen25vl_lora_files) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 6747cf7743b1..125e4e382774 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -43,7 +43,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): FLASH_ATTN = 
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" - XFORMERS = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" ROCM_AITER_TRITON_MLA = ( diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a8e796a1eab6..f1d57ac50fb9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -51,31 +51,6 @@ FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) -USE_XFORMERS_OPS = None - - -def check_xformers_availability(): - global USE_XFORMERS_OPS - if USE_XFORMERS_OPS is not None: - return USE_XFORMERS_OPS - - if current_platform.is_cuda() and current_platform.has_device_capability(100): - # Xformers FA is not compatible with B200 - USE_XFORMERS_OPS = False - else: - try: - from importlib.util import find_spec - - find_spec("xformers.ops") - USE_XFORMERS_OPS = True - except ImportError: - USE_XFORMERS_OPS = False - - # the warning only needs to be shown once - if not USE_XFORMERS_OPS: - logger.warning("Xformers is not available, falling back.") - - return USE_XFORMERS_OPS def check_upstream_fa_availability(dtype: torch.dtype): @@ -533,7 +508,6 @@ def __init__( if backend in { AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.PALLAS, AttentionBackendEnum.ROCM_AITER_FA, AttentionBackendEnum.FLASH_ATTN, @@ -549,12 +523,6 @@ def __init__( ) ) - if ( - self.attn_backend == AttentionBackendEnum.XFORMERS - and not check_xformers_availability() - ): - self.attn_backend = AttentionBackendEnum.TORCH_SDPA - self.is_flash_attn_backend = self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, @@ -614,12 +582,6 @@ def forward( max_seqlen_k=kv_len, softmax_scale=self.scale, ) - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - - out = xops.memory_efficient_attention_forward( - query, key, value, scale=self.scale - ) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: query, key, value = (x.transpose(1, 2) for x in (query, key, value)) out = F.scaled_dot_product_attention(query, key, value, scale=self.scale) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 06a9f7cd8226..46f8f5117f7a 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -3,7 +3,7 @@ """ This file contains ops for ViT attention to be compatible with torch.compile as there are operations here not supported by torch.compile (for instance, -`to_list` in xformers attn, or `.item()` in flash attention) +`.item()` in flash attention) Using these ops and wrapping vision blocks with `torch.compile` can speed up throughput in vision models by ~5% relative on H100, and improve token @@ -19,42 +19,6 @@ from vllm.utils.torch_utils import direct_register_custom_op -def xformers_attn_seqlens_wrapper( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor -) -> torch.Tensor: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens.tolist(), kv_seqlen=None, device=q.device - ) - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = 
einops.rearrange(context_layer, "b s h d -> s b (h d)").contiguous() - return context_layer - - -def xformers_attn_seqlens_wrapper_fake( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor -) -> torch.Tensor: - b, s, h, d = q.shape - return torch.empty((s, b, h * d), dtype=q.dtype, device=q.device) - - -direct_register_custom_op( - op_name="xformers_attn_seqlens_wrapper", - op_func=xformers_attn_seqlens_wrapper, - fake_impl=xformers_attn_seqlens_wrapper_fake, -) - - -def vit_xformers_attn_wrapper( - q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqlens: torch.Tensor -) -> torch.Tensor: - return torch.ops.vllm.xformers_attn_seqlens_wrapper(q, k, v, seqlens) - - def flash_attn_maxseqlen_wrapper( q: torch.Tensor, k: torch.Tensor, diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index e9af08b2316d..ad19b58aa155 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -36,7 +36,14 @@ def get_env_variable_attn_backend() -> AttentionBackendEnum | None: * None otherwise """ backend_name = os.environ.get(STR_BACKEND_ENV_VAR) - return None if backend_name is None else AttentionBackendEnum[backend_name] + if backend_name is None: + return None + if backend_name == "XFORMERS": + raise ValueError( + "Attention backend 'XFORMERS' has been removed (See PR #29262 for " + "details). Please select a supported attention backend." + ) + return AttentionBackendEnum[backend_name] # Global state allows a particular choice of backend diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 9f62b35ed515..00a81a319bf7 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -173,6 +173,12 @@ def _validate_mm_encoder_attn_backend( # We need to import the real type here (deferred to avoid circular import). from vllm.attention.backends.registry import AttentionBackendEnum + if isinstance(value, str) and value.upper() == "XFORMERS": + raise ValueError( + "Attention backend 'XFORMERS' has been removed (See PR #29262 for " + "details). Please select a supported attention backend." 
+ ) + if value is None or isinstance(value, AttentionBackendEnum): return value diff --git a/vllm/envs.py b/vllm/envs.py index 9b1ed1fc680b..56558548d398 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -640,7 +640,6 @@ def get_vllm_port() -> int | None: # Example options: # - "TORCH_SDPA": use torch.nn.MultiheadAttention # - "FLASH_ATTN": use FlashAttention - # - "XFORMERS": use XFormers # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA # - "FLASH_ATTN_MLA": use FlashAttention for MLA diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2d2251e83b5b..5460018d0d67 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -306,7 +306,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -324,7 +323,6 @@ def forward( rotary_pos_emb: torch.Tensor | None = None, *, max_seqlen: int | None = None, - seqlens: list[int] | None = None, ) -> torch.Tensor: # [S, C] -> [S, B=1, C] x = hidden_states.unsqueeze(1) @@ -374,16 +372,6 @@ def forward( out_i = out_i.permute(0, 2, 1, 3) outputs.append(out_i) context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) else: raise RuntimeError("Unsupported attention backend") @@ -545,14 +533,12 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, max_seqlen: int | None = None, - seqlens: list[int] | None = None, ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -663,18 +649,14 @@ def rot_pos_emb(self, grid_thw: list[list[int]]) -> torch.Tensor: rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor - ) -> tuple[int | None, list[int] | None]: - max_seqlen, seqlens = None, None + def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: + max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return max_seqlen, seqlens + return max_seqlen def forward( self, hidden_states: torch.Tensor, grid_thw: list[list[int]] @@ -694,14 +676,13 @@ def forward( ) cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens]) - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) for blk in self.blocks: hidden_states = blk( hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) if self.post_trunk_norm is not None: diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 
daa5bf03ea4a..07b34fbc8add 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -214,7 +214,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -259,7 +258,6 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -311,20 +309,6 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() output, _ = self.proj(context_layer) return output @@ -404,14 +388,12 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: hidden_states = hidden_states + self.attn( self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) return hidden_states @@ -562,18 +544,14 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) return rotary_pos_emb - def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor - ) -> tuple[int | None, list[int] | None]: - max_seqlen, seqlens = None, None + def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: + max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return max_seqlen, seqlens + return max_seqlen def forward( self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, num_pad=0 @@ -598,8 +576,8 @@ def forward( if hidden_states.ndim == 2: hidden_states = hidden_states.unsqueeze(dim=1) - # pre-compute seqlens for attn mask to reduce cuMemcpy operations - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) for i, blk in enumerate(self.blocks): hidden_states = blk( @@ -607,7 +585,6 @@ def forward( cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) final_output = self.ln(hidden_states) diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index d141e9549806..7e0370886884 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -309,7 +309,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - 
AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -345,7 +344,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -400,20 +398,6 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() output, _ = self.proj(context_layer) return output @@ -461,7 +445,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -469,7 +452,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) @@ -803,15 +785,14 @@ def rot_pos_emb( def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[int | None, list[int] | None]: - max_seqlen, seqlens = None, None - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + ) -> int | None: + max_seqlen = None if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - return max_seqlen, seqlens + return max_seqlen def forward( self, @@ -836,8 +817,9 @@ def forward( ).cumsum(dim=0, dtype=torch.int32) cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) - # pre-compute seqlens for attn mask to reduce cuMemcpy operations - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + # pre-compute max_seqlen for attn mask to reduce cuMemcpy operations + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() x = self.embeddings( x, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1] ) @@ -851,7 +833,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) # adapter diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 8fc3db296aa7..302260b95299 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -9,6 +9,7 @@ import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange from transformers import PretrainedConfig from transformers.activations import GELUActivation @@ -424,7 +425,7 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.XFORMERS, + AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -451,7 +452,6 @@ def forward( ) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - seqlens = (cu_seqlens[1:] - 
cu_seqlens[:-1]).tolist() batch_size = q.shape[0] if rope_emb is None: @@ -498,17 +498,21 @@ def forward( softmax_scale=self.scale, ) context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) + elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: + outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = ( + rearrange(x, "b s h d -> b h s d") for x in (q_i, k_i, v_i) + ) + output_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) if outputs else q[:, :0] context_layer = rearrange(context_layer, "b s h d -> b s (h d)").contiguous() diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py index dee0c16ab0f6..74bb868492da 100644 --- a/vllm/model_executor/models/paddleocr_vl.py +++ b/vllm/model_executor/models/paddleocr_vl.py @@ -38,7 +38,6 @@ ) from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, - vit_xformers_attn_wrapper, ) from vllm.config import VllmConfig from vllm.config.multimodal import BaseDummyOptions @@ -657,7 +656,6 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor | None, max_seqlen: torch.Tensor | None, - seqlens: torch.Tensor | None, ) -> torch.Tensor: batch_size, _, _ = hidden_states.shape @@ -703,10 +701,6 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - if seqlens is None: - raise ValueError("xFormers attention backend requires seqlens tensor.") - context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) else: raise RuntimeError( f"PaddleOCR-VL does not support {self.attn_backend} backend now." 
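
Note: the TORCH_SDPA branch added to keye.py above is the pattern this PR leans on wherever the removed xformers BlockDiagonalMask path used to handle variable-length (packed) ViT sequences: each slice delimited by cu_seqlens is attended independently with torch's scaled_dot_product_attention, which reproduces block-diagonal masking without xformers. A minimal sketch of that pattern under assumed shapes; the helper name and the usage at the bottom are illustrative and not part of the diff:

```python
import torch
import torch.nn.functional as F


def sdpa_over_cu_seqlens(
    q: torch.Tensor,  # [1, total_tokens, num_heads, head_dim]
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,  # [num_seqs + 1] cumulative sequence lengths
) -> torch.Tensor:
    """Run SDPA independently on each packed sequence.

    Processing each slice separately means tokens only attend within their
    own sequence, matching what BlockDiagonalMask.from_seqlens provided.
    """
    outputs = []
    for i in range(1, len(cu_seqlens)):
        start, end = cu_seqlens[i - 1], cu_seqlens[i]
        # [1, s_i, h, d] -> [1, h, s_i, d] as expected by SDPA
        q_i, k_i, v_i = (x[:, start:end].transpose(1, 2) for x in (q, k, v))
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i, dropout_p=0.0)
        outputs.append(out_i.transpose(1, 2))  # back to [1, s_i, h, d]
    return torch.cat(outputs, dim=1) if outputs else q[:, :0]


if __name__ == "__main__":
    cu = torch.tensor([0, 3, 7])
    q = torch.randn(1, 7, 4, 16)
    out = sdpa_over_cu_seqlens(q, q.clone(), q.clone(), cu)
    print(out.shape)  # torch.Size([1, 7, 4, 16])
```
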
@@ -818,7 +812,6 @@ def forward( cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor | None, max_seqlen: torch.Tensor | None, - seqlens: torch.Tensor | None, ) -> torch.Tensor: residual = hidden_states @@ -828,7 +821,6 @@ def forward( cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) hidden_states = residual + hidden_states @@ -870,7 +862,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -943,14 +934,11 @@ def forward( cu_seqlens = cu_seqlens.to(device=device) max_seqlen = None - seqlens = None if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] hidden_states = inputs_embeds for encoder_layer in self.layers: @@ -959,7 +947,6 @@ def forward( cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb, max_seqlen=max_seqlen, - seqlens=seqlens, ) return hidden_states diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 8a034fd72b02..6011d93a795d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -74,6 +74,7 @@ ) try: + # Note: vLLM does not install xformers by default. from xformers import ops as xops if current_platform.is_cuda() and current_platform.has_device_capability(100): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 1500a437613c..8c707c2561af 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -46,7 +46,6 @@ from vllm.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, vit_torch_sdpa_wrapper, - vit_xformers_attn_wrapper, ) from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig @@ -375,7 +374,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention - seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, head * 3 * head_dim] x, _ = self.qkv(x) @@ -435,8 +433,6 @@ def forward( v, cu_seqlens, ) - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens) output, _ = self.proj(context_layer) return output @@ -448,9 +444,7 @@ def forward( "cu_seqlens": 0, "rotary_pos_emb_cos": 0, "rotary_pos_emb_sin": 0, - "seqlens": 0, }, - mark_unbacked_dims={"seqlens": 0}, enable_if=should_torch_compile_mm_vit, ) class Qwen2_5_VisionBlock(nn.Module): @@ -501,7 +495,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention - seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x_attn = self.attn( self.norm1(x), @@ -509,7 +502,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) x_fused_norm, residual = self.norm2(x, residual=x_attn) x = residual + self.mlp(x_fused_norm) @@ -670,7 +662,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ 
-822,17 +813,14 @@ def get_rope_by_thw(self, t, h, w): def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: max_seqlen = torch.zeros([], device=cu_seqlens.device) - seqlens = torch.zeros(1, device=cu_seqlens.device) if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - return max_seqlen, seqlens + return max_seqlen @staticmethod def invert_permutation(perm: torch.Tensor) -> torch.Tensor: @@ -897,10 +885,8 @@ def forward( # transformers # pre-compute seqlens for window/full attn to reduce cuMemcpy operations - max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens) - max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen( - cu_window_seqlens - ) + max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens) cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True) cu_window_seqlens = cu_window_seqlens.to(device=self.device, non_blocking=True) @@ -927,11 +913,9 @@ def forward( if layer_num in self.fullatt_block_indexes: cu_seqlens_now = cu_seqlens max_seqlen_now = max_seqlen_full - seqlens_now = seqlens_full else: cu_seqlens_now = cu_window_seqlens max_seqlen_now = max_seqlen_window - seqlens_now = seqlens_window hidden_states = blk( hidden_states, @@ -939,7 +923,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen_now, - seqlens=seqlens_now, ) # For Qwen2.5-VL-3B, float16 will overflow at last block diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 479a7871e364..9d1d023aed17 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -348,7 +348,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -384,7 +383,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -445,20 +443,6 @@ def forward( context_layer = rearrange( context_layer, "b s h d -> s b (h d)" ).contiguous() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalMask - - attn_bias = BlockDiagonalMask.from_seqlens( - q_seqlen=seqlens, kv_seqlen=None, device=q.device - ) - - context_layer = xops.memory_efficient_attention_forward( - q, k, v, attn_bias=attn_bias, p=0, scale=None - ) - context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" - ).contiguous() output, _ = self.proj(context_layer) return output @@ -509,7 +493,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: int | None = None, # Only used for Flash Attention - seqlens: list[int] | None = None, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -517,7 +500,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, 
- seqlens=seqlens, ) x = x + self.mlp(self.norm2(x)) @@ -728,18 +710,14 @@ def rot_pos_emb( sin_combined = sin[pos_ids].flatten(1) return cos_combined, sin_combined - def compute_attn_mask_seqlen( - self, cu_seqlens: torch.Tensor - ) -> tuple[int | None, list[int] | None]: - max_seqlen, seqlens = None, None + def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> int | None: + max_seqlen = None if self.attn_backend in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.ROCM_AITER_FA, }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - return max_seqlen, seqlens + return max_seqlen def forward( self, @@ -771,7 +749,7 @@ def forward( x = x.unsqueeze(1) # pre-compute seqlens for attn mask to reduce cuMemcpy operations - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) for blk in self.blocks: x = blk( @@ -780,7 +758,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) # adapter diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index 54ef56f83344..61f218f16d79 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -224,7 +224,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # Only used for Flash Attention - seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -232,7 +231,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) x = x + self.mlp(self.norm2(x)) @@ -500,14 +498,11 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: max_seqlen = torch.zeros([], device=cu_seqlens.device) - seqlens = torch.zeros(1, device=cu_seqlens.device) if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - return max_seqlen, seqlens + return max_seqlen def forward( self, @@ -533,7 +528,7 @@ def forward( hidden_states = hidden_states.unsqueeze(1) rotary_pos_emb_cos = rotary_pos_emb_cos.to(hidden_states.device) rotary_pos_emb_sin = rotary_pos_emb_sin.to(hidden_states.device) - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) hidden_states_list = [] deepstack_visual_indexes = self.deepstack_visual_indexes @@ -545,7 +540,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) if ( deepstack_visual_indexes is not None diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 90c4894d33e8..4cd6fa14c32d 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -235,7 +235,6 @@ def forward( rotary_pos_emb_cos: torch.Tensor, rotary_pos_emb_sin: torch.Tensor, max_seqlen: torch.Tensor, # 
Only used for Flash Attention - seqlens: torch.Tensor, # Only used for xFormers ) -> torch.Tensor: x = x + self.attn( self.norm1(x), @@ -243,7 +242,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) x = x + self.mlp(self.norm2(x)) @@ -391,7 +389,6 @@ def __init__( if self.attn_backend not in { AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.XFORMERS, AttentionBackendEnum.ROCM_AITER_FA, }: raise RuntimeError( @@ -531,17 +528,14 @@ def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor: def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: max_seqlen = torch.zeros([], device=cu_seqlens.device) - seqlens = torch.zeros(1, device=cu_seqlens.device) if ( self.attn_backend == AttentionBackendEnum.FLASH_ATTN or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA ): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - elif self.attn_backend == AttentionBackendEnum.XFORMERS: - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - return max_seqlen, seqlens + return max_seqlen def forward( self, @@ -569,7 +563,7 @@ def forward( cu_seqlens = torch.from_numpy(cu_seqlens) hidden_states = hidden_states.unsqueeze(1) - max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens) cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) deepstack_feature_lists = [] @@ -580,7 +574,6 @@ def forward( rotary_pos_emb_cos=rotary_pos_emb_cos, rotary_pos_emb_sin=rotary_pos_emb_sin, max_seqlen=max_seqlen, - seqlens=seqlens, ) if layer_num in self.deepstack_visual_indexes: deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f9bf242b7194..06793a3d1bb1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -277,12 +277,7 @@ def get_vit_attn_backend( except ImportError: pass - if cls.has_device_capability(100): - # xFormers doesn't support Blackwell, fall back to SDPA - # See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501 - return AttentionBackendEnum.TORCH_SDPA - else: - return AttentionBackendEnum.XFORMERS + return AttentionBackendEnum.TORCH_SDPA @classmethod def get_valid_backends( diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 3ef44e770320..d94da71b289f 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -49,7 +49,6 @@ def __dir__() -> list[str]: # Possible string values of STR_BACKEND_ENV_VAR # register, corresponding to possible backends STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" -STR_XFORMERS_ATTN_VAL: str = "XFORMERS" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_INVALID_VAL: str = "INVALID" diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py deleted file mode 100644 index 5039c44b9c3e..000000000000 --- a/vllm/v1/attention/backends/xformers.py +++ /dev/null @@ -1,420 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with XFormersAttention.""" - -from dataclasses import dataclass -from typing import ClassVar, Optional - -import torch - -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionType, - MultipleOf, -) -from vllm.attention.ops.triton_unified_attention import unified_attention -from 
vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, - split_decodes_and_prefills, -) -from vllm.v1.kv_cache_interface import AttentionSpec - -try: - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import ( - AttentionBias, - PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, - ) - - XFORMERS_AVAILABLE = True -except ImportError: - XFORMERS_AVAILABLE = False - -from vllm import _custom_ops as ops - -logger = init_logger(__name__) - - -class XFormersAttentionBackend(AttentionBackend): - accept_output_buffer: bool = True - supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] - - @staticmethod - def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: - return [MultipleOf(16)] - - @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return [ - 32, - 40, - 48, - 56, - 64, - 72, - 80, - 88, - 96, - 104, - 112, - 120, - 128, - 136, - 144, - 152, - 160, - 168, - 176, - 184, - 192, - 200, - 208, - 216, - 224, - 232, - 240, - 248, - 256, - ] - - @staticmethod - def get_name() -> str: - return "XFORMERS" - - @staticmethod - def get_impl_cls() -> type["XFormersAttentionImpl"]: - return XFormersAttentionImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - if block_size % 16 != 0: - raise ValueError("Block size must be a multiple of 16.") - return (2, num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]: - return XFormersAttentionMetadataBuilder - - @staticmethod - def use_cascade_attention(*args, **kwargs) -> bool: - return False - - -@dataclass -class XFormersAttentionMetadata: - num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int - query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor - slot_mapping: torch.Tensor - - num_prefill_tokens: int = 0 - num_decode_tokens: int = 0 - num_prefills: int = 0 - num_decodes: int = 0 - - # Biases for different attention types. 
- attn_bias: Optional["AttentionBias"] = None - - # Self-attention prefill/decode metadata cache - _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None - _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None - - @property - def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - # Recover cached prefill-phase attention - # metadata structure - return self._cached_prefill_metadata - - q_start_loc = self.query_start_loc[self.num_decodes :] - q_seqlens = torch.diff(q_start_loc) - kv_seqlens = self.seq_lens[self.num_decodes :] - # Construct & cache prefill-phase attention metadata structure - self._cached_prefill_metadata = XFormersAttentionMetadata( - num_actual_tokens=self.num_prefill_tokens, - max_query_len=int(q_seqlens.max().item()), - query_start_loc=q_start_loc - q_start_loc[0], - max_seq_len=int(kv_seqlens.max().item()), - seq_lens=kv_seqlens, - block_table=self.block_table[self.num_decodes :], - slot_mapping=self.slot_mapping[self.num_decode_tokens :], - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - # Recover cached decode-phase attention - # metadata structure - return self._cached_decode_metadata - - q_start_loc = self.query_start_loc - q_seqlens = torch.diff(q_start_loc) - decode_kv_seqlens = self.seq_lens[: self.num_decodes] - # Construct & cache decode-phase attention metadata structure - self._cached_decode_metadata = XFormersAttentionMetadata( - num_actual_tokens=self.num_decode_tokens, - max_query_len=int(q_seqlens[: self.num_decodes].max().item()), - query_start_loc=q_start_loc[: self.num_decodes + 1], - max_seq_len=int(decode_kv_seqlens.max().item()), - seq_lens=decode_kv_seqlens, - block_table=self.block_table[: self.num_decodes], - slot_mapping=self.slot_mapping[: self.num_decode_tokens], - attn_bias=self.attn_bias, - ) - return self._cached_decode_metadata - - -class XFormersAttentionMetadataBuilder( - AttentionMetadataBuilder[XFormersAttentionMetadata] -): - reorder_batch_threshold: int = 1 - - def __init__( - self, - kv_cache_spec: AttentionSpec, - layer_names: list[str], - vllm_config: VllmConfig, - device: torch.device, - ): - super().__init__(kv_cache_spec, layer_names, vllm_config, device) - - assert XFORMERS_AVAILABLE - self.block_size = kv_cache_spec.block_size - self._num_decodes = 0 - self._num_decode_tokens = 0 - - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False, - ) -> XFormersAttentionMetadata: - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills( - common_attn_metadata, decode_threshold=self.reorder_batch_threshold - ) - ) - - num_actual_tokens = common_attn_metadata.num_actual_tokens - q_start_loc = common_attn_metadata.query_start_loc - q_seqlens = torch.diff(q_start_loc) - max_query_len = common_attn_metadata.max_query_len - kv_seqlens = common_attn_metadata.seq_lens - max_seq_len = common_attn_metadata.max_seq_len - block_table = common_attn_metadata.block_table_tensor - slot_mapping = common_attn_metadata.slot_mapping - - bias = None - if num_decodes > 0: - # Construct the decoder bias. 
- decode_q_seqlens = q_seqlens[:num_decodes] - decode_kv_seqlens = kv_seqlens[:num_decodes] - bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( - q_seqlen=decode_q_seqlens.tolist(), - kv_seqlen=decode_kv_seqlens.tolist(), - page_size=self.block_size, - block_tables=block_table[:num_decodes], - device=block_table.device, - ) - - return XFormersAttentionMetadata( - num_actual_tokens=num_actual_tokens, - num_prefill_tokens=num_prefill_tokens, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - num_decodes=num_decodes, - max_query_len=max_query_len, - query_start_loc=q_start_loc, - max_seq_len=max_seq_len, - seq_lens=kv_seqlens, - block_table=block_table, - slot_mapping=slot_mapping, - attn_bias=bias, - ) - - -class XFormersAttentionImpl(AttentionImpl): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None = None, - attn_type: AttentionType = AttentionType.DECODER, - kv_sharing_target_layer_name: str | None = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if alibi_slopes is not None: - raise NotImplementedError("XFormers does not support alibi slopes yet.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - if sliding_window is None: - self.sliding_window = (-1, -1) - else: - self.sliding_window = (sliding_window - 1, 0) - if logits_soft_cap is None: - # Setting logits_soft_cap to 0 means no soft cap. - logits_soft_cap = 0 - self.logits_soft_cap = logits_soft_cap - - if attn_type != AttentionType.DECODER: - raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "XFormersAttentionImpl." - ) - - def forward( - self, - layer: torch.nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: XFormersAttentionMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with XFormers. - - Args: - query: shape = [num_tokens, num_heads, head_size] - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape = - [2, num_blocks, block_size, num_kv_heads, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - assert output is not None, "Output tensor must be provided." - - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for XFormersAttentionImpl" - ) - - if attn_metadata is None: - # Profiling run. - return output.fill_(0) - - # Cache the input KVs. - key_cache, value_cache = kv_cache.unbind(0) - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_actual_tokens = attn_metadata.num_actual_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - if prefill_meta := attn_metadata.prefill_metadata: - descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1]) - unified_attention( - q=query[num_decode_tokens:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[num_decode_tokens:num_actual_tokens], - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - seqused_k=prefill_meta.seq_lens, - max_seqlen_k=prefill_meta.max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=prefill_meta.block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - ) - - if decode_meta := attn_metadata.decode_metadata: - # Query for decode. KV is not needed because it is already cached. - decode_query = query[:num_decode_tokens] - # Reshape query to [1, B_T, G, H, D]. - q = decode_query.view( - 1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size - ) - # Reshape the k and v caches to [1, Bkv_T, G, H, D] - cache_k = key_cache.view( - 1, -1, self.num_kv_heads, 1, self.head_size - ).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) - cache_v = value_cache.view( - 1, -1, self.num_kv_heads, 1, self.head_size - ).expand( - 1, - -1, - self.num_kv_heads, - self.num_queries_per_kv, - self.head_size, - ) - - attn_bias = decode_meta.attn_bias - output[:num_decode_tokens] = xops.memory_efficient_attention_forward( - q, - cache_k, - cache_v, - attn_bias=attn_bias, - p=0.0, - scale=self.scale, - ).view(decode_query.shape) - - # Reshape the output tensor. - return output
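
Note: with the XFORMERS backend deleted, `VLLM_ATTENTION_BACKEND` on NVIDIA CUDA is limited to `FLASH_ATTN` or `FLASHINFER` (see the quickstart.md change above), and `vllm/attention/selector.py` now raises a `ValueError` if the old value is still set. A minimal sketch of how a deployment script might migrate a legacy setting before starting vLLM; the helper function is hypothetical and not part of this PR:

```python
import os

# Values accepted on NVIDIA CUDA after this change; "XFORMERS" now raises
# a ValueError inside vllm.attention.selector at startup.
SUPPORTED_CUDA_BACKENDS = {"FLASH_ATTN", "FLASHINFER"}


def migrate_attention_backend_env() -> str | None:
    """Hypothetical helper: rewrite a legacy XFORMERS setting to a supported one."""
    backend = os.environ.get("VLLM_ATTENTION_BACKEND")
    if backend is None:
        return None
    if backend == "XFORMERS":
        # Fall back to FLASH_ATTN instead of letting vLLM fail at startup.
        backend = "FLASH_ATTN"
        os.environ["VLLM_ATTENTION_BACKEND"] = backend
    assert backend in SUPPORTED_CUDA_BACKENDS, f"unsupported backend: {backend}"
    return backend


if __name__ == "__main__":
    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"
    print(migrate_attention_backend_env())  # FLASH_ATTN
```
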