diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index adb0879f20d4..bb157c5b6074 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -11,10 +11,32 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} # Install some basic utilities RUN apt-get update -q -y && apt-get install -q -y \ sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \ - apt-transport-https ca-certificates wget curl + apt-transport-https ca-certificates wget curl git + # Remove sccache RUN python3 -m pip install --upgrade pip RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" + +# BUILD FA only if both FLASH_ATTENTION_TRITON_AMD_ENABLE and GPU_ARCHS are passed/declared and non-empty +ARG FLASH_ATTENTION_TRITON_AMD_ENABLE +ARG GPU_ARCHS + +RUN if [ -n "${FLASH_ATTENTION_TRITON_AMD_ENABLE}" ] && [ -n "${GPU_ARCHS}" ]; then \ + echo "Compiling Flash Attention with GPU_ARCHS=${GPU_ARCHS}..." ; \ + export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE}"; \ + export GPU_ARCHS="${GPU_ARCHS}"; \ + git clone --single-branch --branch main_perf https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && python3 setup.py install \ + && cd .. \ + && rm -rf flash-attention ; \ + else \ + echo "Skipping Flash Attention compilation (FLASH_ATTENTION_TRITON_AMD_ENABLE and/or GPU_ARCHS not set)." 
; \ + fi + +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE=${FLASH_ATTENTION_TRITON_AMD_ENABLE} +ENV GPU_ARCHS=${GPU_ARCHS} + ARG COMMON_WORKDIR WORKDIR ${COMMON_WORKDIR} @@ -27,9 +49,9 @@ FROM base AS fetch_vllm_1 ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" ARG VLLM_BRANCH="main" ONBUILD RUN git clone ${VLLM_REPO} \ - && cd vllm \ - && git fetch -v --prune -- origin ${VLLM_BRANCH} \ - && git checkout FETCH_HEAD \ + && cd vllm \ + && git fetch -v --prune -- origin ${VLLM_BRANCH} \ + && git checkout FETCH_HEAD \ && if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \ git remote add upstream "https://github.com/vllm-project/vllm.git" \ && git fetch upstream ; fi diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17e025155a43..c7c53bfacce7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -48,9 +48,10 @@ ) if current_platform.is_rocm(): - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx1x else: on_gfx9 = lambda *args, **kwargs: False + on_gfx1x = lambda *args, **kwargs: False FP8_DTYPE = current_platform.fp8_dtype() @@ -103,13 +104,25 @@ def maybe_get_vit_flash_attn_backend( use_upstream_fa: bool, attn_backend_override: _Backend | None = None, ) -> tuple[_Backend, Callable | None]: + import os + from importlib.util import find_spec + if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = _Backend.ROCM_AITER_FA + elif ( + os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE" + and os.environ.get("GPU_ARCHS") == "gfx1100" + and find_spec("flash_attn") is not None + and on_gfx1x() + and attn_backend_override is None + ): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True + elif ( check_upstream_fa_availability(torch.get_default_dtype()) - and on_gfx9() and attn_backend_override is None ): attn_backend = _Backend.FLASH_ATTN diff --git a/vllm/attention/ops/vit_attn_wrappers.py 
b/vllm/attention/ops/vit_attn_wrappers.py index 6cefe7441668..84a0817c3058 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -14,8 +14,15 @@ import einops import torch +import vllm.envs as envs from vllm.utils.torch_utils import direct_register_custom_op +from vllm.platforms import current_platform + +if current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx9, on_gfx1x +else: + on_gfx9 = lambda *args, **kwargs: False def xformers_attn_seqlens_wrapper( @@ -61,10 +68,14 @@ def flash_attn_maxseqlen_wrapper( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: - if is_rocm_aiter: + if ( + current_platform.is_rocm() + and on_gfx9() + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_MHA + ): from aiter import flash_attn_varlen_func else: if use_upstream_fa: @@ -96,7 +107,6 @@ def flash_attn_maxseqlen_wrapper_fake( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: b, s, h, d = q.shape @@ -117,9 +127,8 @@ def vit_flash_attn_wrapper( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: return torch.ops.vllm.flash_attn_maxseqlen_wrapper( - q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa + q, k, v, cu_seqlens, max_seqlen, batch_size, use_upstream_fa ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 3585783e4ccc..9e5b1205b741 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -422,6 +422,14 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: + from importlib.util import find_spec + + if not isinstance(max_seqlen, torch.Tensor): + max_seqlen = torch.tensor( + max_seqlen, device=q.device, 
dtype=torch.int32 + ) + self.use_upstream_fa = find_spec("flash_attn") is not None + context_layer = vit_flash_attn_wrapper( q, k, @@ -429,7 +437,6 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, self.use_upstream_fa, ) elif self.attn_backend == _Backend.TORCH_SDPA: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 0c03a5564db8..329dddbfb554 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -77,6 +77,7 @@ "0x74b9": "AMD_Instinct_MI325X", # MI325X VF "0x74a9": "AMD_Instinct_MI300X_HF", "0x74bd": "AMD_Instinct_MI300X_HF", + "0x744c": "AMD_7900XTX_RDNA3", } # Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES` @@ -202,6 +203,7 @@ class RocmPlatform(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + import os from importlib.util import find_spec from vllm.attention.backends.registry import _Backend @@ -212,6 +214,18 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": if on_gfx9() and find_spec("flash_attn") is not None: return _Backend.FLASH_ATTN + if ( + os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE" + and os.environ.get("GPU_ARCHS") == "gfx1100" + and find_spec("flash_attn") is not None + and on_gfx1x() + ): + logger.info( + "Using ViT FlashAttention (upstream) on V1 engine (gfx1x / RDNA3)." + ) + return _Backend.FLASH_ATTN + + logger.info("Using Vit TORCH_SDPA V1 engine") return _Backend.TORCH_SDPA @classmethod