diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 113602645e89..ac34f279d0b5 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import List, Optional +from typing import Callable, List, Optional import torch import torch.nn as nn @@ -68,9 +68,39 @@ def check_upstream_fa_availability(dtype: torch.dtype): ) and current_platform.has_device_capability(80): from transformers.utils import is_flash_attn_2_available return is_flash_attn_2_available() + if current_platform.is_rocm(): + from importlib.util import find_spec + return find_spec("flash_attn") is not None return False +def maybe_get_vit_flash_attn_backend( + attn_backend: _Backend, + use_upstream_fa: bool) -> tuple[_Backend, Callable]: + if attn_backend != _Backend.FLASH_ATTN and \ + attn_backend != _Backend.ROCM_AITER_FA and \ + check_upstream_fa_availability(torch.get_default_dtype()): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + + if current_platform.is_rocm() and \ + attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True + + if (attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}): + if attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + if use_upstream_fa: + from flash_attn import flash_attn_varlen_func + else: + from vllm.vllm_flash_attn import flash_attn_varlen_func + else: + flash_attn_varlen_func = None + + return attn_backend, flash_attn_varlen_func + + class Attention(nn.Module, AttentionLayerBase): """Attention layer. @@ -410,13 +440,9 @@ def __init__( # to upstream flash attention if available. # If vllm native fa is selected, we use it directly. use_upstream_fa = False - if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( - dtype): - backend = _Backend.FLASH_ATTN - use_upstream_fa = True - if current_platform.is_rocm() or current_platform.is_xpu(): - # currently, only torch_sdpa is supported on rocm/xpu + if current_platform.is_xpu(): + # currently, only torch_sdpa is supported on xpu self.attn_backend = _Backend.TORCH_SDPA else: @@ -428,17 +454,25 @@ def __init__( _Backend.FLASH_ATTN, } else _Backend.TORCH_SDPA + self.attn_backend, self._flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + use_upstream_fa, + ) + if (self.attn_backend == _Backend.XFORMERS and not check_xformers_availability()): self.attn_backend = _Backend.TORCH_SDPA - if self.attn_backend == _Backend.FLASH_ATTN: - if use_upstream_fa: - from flash_attn import flash_attn_varlen_func - self._flash_attn_varlen_func = flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - self._flash_attn_varlen_func = flash_attn_varlen_func + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + # this condition is just to make sure that the + # use_upstream_fa in the log is correct + if current_platform.is_rocm() \ + and self.attn_backend == _Backend.FLASH_ATTN: + use_upstream_fa = True logger.info_once( f"MultiHeadAttention attn_backend: {self.attn_backend}, " @@ -466,7 +500,7 @@ def forward( key = torch.repeat_interleave(key, num_repeat, dim=2) value = torch.repeat_interleave(value, num_repeat, dim=2) - if self.attn_backend == _Backend.FLASH_ATTN: + if self.is_flash_attn_backend: cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, @@ -507,14 +541,6 @@ def forward( from torch_xla.experimental.custom_kernel import flash_attention out = flash_attention(query, key, value, sm_scale=self.scale) out = out.transpose(1, 2) - elif self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - - # ROCm Flash Attention expects (batch, seq, heads, head_dim) - out = flash_attn_varlen_func(query, - key, - value, - softmax_scale=self.scale) else: # ViT attention hasn't supported this backend yet raise NotImplementedError( diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py index 2445f0d784f4..86888c10ee39 100644 --- a/vllm/model_executor/models/dots_ocr.py +++ b/vllm/model_executor/models/dots_ocr.py @@ -10,7 +10,8 @@ from transformers.models.qwen2_vl import Qwen2VLProcessor from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import utils as dist_utils from vllm.distributed.parallel_state import ( @@ -267,10 +268,12 @@ def __init__(self, self.attn_backend = get_vit_attn_backend( self.hidden_size_per_attention_head, torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.ROCM_AITER_FA @@ -306,25 +309,18 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q_ = q.reshape(bs * q.shape[1], q.shape[2], q.shape[3]) k_ = k.reshape(bs * k.shape[1], k.shape[2], k.shape[3]) v_ = v.reshape(bs * v.shape[1], v.shape[2], v.shape[3]) - output = flash_attn_varlen_func(q_, - k_, - v_, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q_, + k_, + v_, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = output.view(bs, -1, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) @@ -611,7 +607,8 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 0b8e24407602..8da7b9f2c3e0 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -35,7 +35,8 @@ from transformers import BatchFeature from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -176,14 +177,18 @@ def __init__( dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, - _Backend.ROCM_AITER_FA + _Backend.FLASH_ATTN, + _Backend.TORCH_SDPA, + _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"Ernie45-VL does not support {self.attn_backend} backend now." @@ -239,27 +244,18 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", @@ -516,7 +512,8 @@ def compute_attn_mask_seqlen( self, cu_seqlens: torch.Tensor ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index 315a057e6a7d..e6e294a14349 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -47,7 +47,8 @@ from transformers.video_utils import VideoMetadata from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import (get_tensor_model_parallel_world_size, parallel_state) @@ -263,19 +264,26 @@ def __init__( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability(torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA, }: raise RuntimeError( f"GLM-4V does not support {self.attn_backend} backend now.") + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: # [s, b, 3 * head * head_dim] seq_len, bs, _ = qkv.shape @@ -316,17 +324,11 @@ def forward( qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) q, k = torch.chunk(qk_rotated, 2, dim=0) - if self.attn_backend == _Backend.FLASH_ATTN: - # from vllm_flash_attn.flash_attn_interface import ( - # flash_attn_varlen_func) - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func + if self.is_flash_attn_backend: q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func( + output = self.flash_attn_varlen_func( q, k, v, @@ -774,7 +776,8 @@ def compute_attn_mask_seqlen( ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() return max_seqlen, seqlens diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index a70df3b72be4..3c46516c7905 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -39,7 +39,8 @@ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -302,6 +303,11 @@ def __init__( disable_tp=use_data_parallel) self.attn_backend = attn_backend self.use_upstream_fa = use_upstream_fa + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } @@ -354,25 +360,18 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", @@ -618,6 +617,7 @@ def __init__( self.attn_backend = get_vit_attn_backend( head_size=head_dim, dtype=torch.get_default_dtype()) if self.attn_backend != _Backend.FLASH_ATTN and \ + self.attn_backend != _Backend.ROCM_AITER_FA and \ check_upstream_fa_availability( torch.get_default_dtype()): self.attn_backend = _Backend.FLASH_ATTN diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2ff79765d4be..48dec351bd90 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -42,7 +42,8 @@ Qwen2VLVideoProcessor) from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import (check_upstream_fa_availability, + maybe_get_vit_flash_attn_backend) from vllm.config import VllmConfig from vllm.distributed import parallel_state, tensor_model_parallel_all_gather from vllm.distributed import utils as dist_utils @@ -319,11 +320,12 @@ def __init__( head_size=self.hidden_size_per_attention_head, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, @@ -331,6 +333,7 @@ def __init__( }: raise RuntimeError( f"Qwen2-VL does not support {self.attn_backend} backend now.") + self.is_flash_attn_backend = self.attn_backend in { _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA } @@ -383,25 +386,18 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) - output = flash_attn_varlen_func(q, - k, - v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen, - max_seqlen_k=max_seqlen, - dropout_p=0.0, - causal=False) + output = self.flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) context_layer = rearrange(output, "(b s) h d -> s b (h d)", diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fc8557131c3e..da6ca7940700 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -323,6 +323,7 @@ def __init__( head_size=head_dim, dtype=torch.get_default_dtype()) use_upstream_fa = False if self.attn_backend != _Backend.FLASH_ATTN and \ + self.attn_backend != _Backend.ROCM_AITER_FA and \ check_upstream_fa_availability( torch.get_default_dtype()): self.attn_backend = _Backend.FLASH_ATTN @@ -476,7 +477,8 @@ def compute_attn_mask_seqlen( cu_seqlens: torch.Tensor, ) -> tuple[Optional[int], Optional[list[int]]]: max_seqlen, seqlens = None, None - if self.attn_backend == _Backend.FLASH_ATTN: + if (self.attn_backend == _Backend.FLASH_ATTN + or self.attn_backend == _Backend.ROCM_AITER_FA): max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() elif self.attn_backend == _Backend.XFORMERS: seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() diff --git a/vllm/model_executor/models/siglip2navit.py b/vllm/model_executor/models/siglip2navit.py index d111a10809e7..5bea5b1daf4d 100644 --- a/vllm/model_executor/models/siglip2navit.py +++ b/vllm/model_executor/models/siglip2navit.py @@ -14,7 +14,7 @@ from transformers.configuration_utils import PretrainedConfig from vllm.attention.backends.registry import _Backend -from vllm.attention.layer import check_upstream_fa_availability +from vllm.attention.layer import maybe_get_vit_flash_attn_backend from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -240,11 +240,12 @@ def __init__( self.attn_backend = get_vit_attn_backend( head_size=self.head_dim, dtype=torch.get_default_dtype()) self.use_upstream_fa = False - if self.attn_backend != _Backend.FLASH_ATTN and \ - check_upstream_fa_availability( - torch.get_default_dtype()): - self.attn_backend = _Backend.FLASH_ATTN - self.use_upstream_fa = True + + self.attn_backend, self.flash_attn_varlen_func \ + = maybe_get_vit_flash_attn_backend( + self.attn_backend, + self.use_upstream_fa, + ) if self.attn_backend not in { _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, @@ -286,14 +287,7 @@ def forward( max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if self.is_flash_attn_backend: - if self.attn_backend == _Backend.ROCM_AITER_FA: - from aiter import flash_attn_varlen_func - else: - if self.use_upstream_fa: - from flash_attn import flash_attn_varlen_func - else: - from vllm.vllm_flash_attn import flash_attn_varlen_func - attn_output = flash_attn_varlen_func( + attn_output = self.flash_attn_varlen_func( queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(seq_length, -1) elif self.attn_backend == _Backend.TORCH_SDPA: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index e12967ad2587..de3df03d1fa0 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -189,8 +189,6 @@ def get_vit_attn_backend(cls, head_size: int, from vllm.attention.backends.registry import _Backend if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()): - # Note: AITER FA is only supported for Qwen-VL models. - # TODO: Add support for other VL models in their model class. return _Backend.ROCM_AITER_FA if on_gfx9(): return _Backend.FLASH_ATTN