From 22de4f2d285ff390abda463522f0a0d02b75ded1 Mon Sep 17 00:00:00 2001 From: JartX Date: Wed, 29 Oct 2025 23:37:35 +0100 Subject: [PATCH 1/7] fa_upstream_detection for rdna3 rocm Signed-off-by: JartX --- vllm/platforms/rocm.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d3535c9781c4..bd3bfcea21bb 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -72,7 +72,6 @@ "0x74a0": "AMD_Instinct_MI300A", "0x74a1": "AMD_Instinct_MI300X", "0x74b5": "AMD_Instinct_MI300X", # MI300X VF - "0x74a2": "AMD_Instinct_MI308X", "0x74a5": "AMD_Instinct_MI325X", "0x74b9": "AMD_Instinct_MI325X", # MI325X VF "0x74a9": "AMD_Instinct_MI300X_HF", @@ -206,6 +205,7 @@ class RocmPlatform(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + import os from importlib.util import find_spec from vllm.attention.backends.registry import _Backend @@ -216,6 +216,16 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": if on_gfx9() and find_spec("flash_attn") is not None: return _Backend.FLASH_ATTN + if ( + os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE" + and os.environ.get("GPU_ARCHS") == "gfx1100" + and find_spec("flash_attn") is not None + and on_gfx1x() + ): + logger.info("Using Vit FLASH_ATTN upstream V1 engine on gfx1x rdna3") + return _Backend.FLASH_ATTN + + logger.info("Using Vit TORCH_SDPA V1 engine") return _Backend.TORCH_SDPA @classmethod @@ -414,7 +424,7 @@ def verify_quantization(cls, quant: str) -> None: "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" " is not set, enabling VLLM_USE_TRITON_AWQ." 
) - os.environ["VLLM_USE_TRITON_AWQ"] = "1" + envs.VLLM_USE_TRITON_AWQ = True @classmethod def get_punica_wrapper(cls) -> str: From 25028ebde0cd1bc1d93e1dc3ea88953baff53ca0 Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 01:06:17 +0100 Subject: [PATCH 2/7] working_on_fa_rdna3 Signed-off-by: JartX --- docker/Dockerfile.rocm | 29 ++++++++++++++++++++---- vllm/attention/layer.py | 17 ++++++++++++-- vllm/attention/ops/vit_attn_wrappers.py | 14 +++++++++++- vllm/model_executor/models/qwen2_5_vl.py | 10 +++++++- vllm/platforms/rocm.py | 4 +++- 5 files changed, 64 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index adb0879f20d4..499ff96a17b2 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -11,10 +11,29 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH:-${PYTORCH_ROCM_ARCH}} # Install some basic utilities RUN apt-get update -q -y && apt-get install -q -y \ sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev \ - apt-transport-https ca-certificates wget curl + apt-transport-https ca-certificates wget curl git + # Remove sccache RUN python3 -m pip install --upgrade pip RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" + +# BUILD FA only if both FLASH_ATTENTION_TRITON_AMD_ENABLE and GPU_ARCHS are passed/declared and non-empty +ARG FLASH_ATTENTION_TRITON_AMD_ENABLE +ARG GPU_ARCHS + +RUN if [ -n "${FLASH_ATTENTION_TRITON_AMD_ENABLE}" ] && [ -n "${GPU_ARCHS}" ]; then \ + echo "Compiling Flash Attention with GPU_ARCHS=${GPU_ARCHS}..." ; \ + export FLASH_ATTENTION_TRITON_AMD_ENABLE="${FLASH_ATTENTION_TRITON_AMD_ENABLE}"; \ + export GPU_ARCHS="${GPU_ARCHS}"; \ + git clone --single-branch --branch main_perf https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && python3 setup.py install \ + && cd .. 
\ + && rm -rf flash-attention ; \ + else \ + echo "Skipping Flash Attention compilation (FLASH_ATTENTION_TRITON_AMD_ENABLE and/or GPU_ARCHS not set)." ; \ + fi + ARG COMMON_WORKDIR WORKDIR ${COMMON_WORKDIR} @@ -27,9 +46,9 @@ FROM base AS fetch_vllm_1 ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" ARG VLLM_BRANCH="main" ONBUILD RUN git clone ${VLLM_REPO} \ - && cd vllm \ - && git fetch -v --prune -- origin ${VLLM_BRANCH} \ - && git checkout FETCH_HEAD \ + && cd vllm \ + && git fetch -v --prune -- origin ${VLLM_BRANCH} \ + && git checkout FETCH_HEAD \ && if [ ${VLLM_REPO} != "https://github.com/vllm-project/vllm.git" ] ; then \ git remote add upstream "https://github.com/vllm-project/vllm.git" \ && git fetch upstream ; fi @@ -116,4 +135,4 @@ ENV SAFETENSORS_FAST_GPU=1 # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 -CMD ["/bin/bash"] +CMD ["/bin/bash"] \ No newline at end of file diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 22eaa22b8b38..2bf463a2c8d3 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -48,9 +48,10 @@ ) if current_platform.is_rocm(): - from vllm.platforms.rocm import on_gfx9 + from vllm.platforms.rocm import on_gfx9, on_gfx1x else: on_gfx9 = lambda *args, **kwargs: False + on_gfx1x = lambda *args, **kwargs: False FP8_DTYPE = current_platform.fp8_dtype() @@ -103,13 +104,25 @@ def maybe_get_vit_flash_attn_backend( use_upstream_fa: bool, attn_backend_override: _Backend | None = None, ) -> tuple[_Backend, Callable | None]: + import os + from importlib.util import find_spec + if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = _Backend.ROCM_AITER_FA + elif ( + os.environ.get("FLASH_ATTENTION_TRITON_AMD_ENABLE") == "TRUE" + and os.environ.get("GPU_ARCHS") == "gfx1100" + and find_spec("flash_attn") is not None + and on_gfx1x() + and attn_backend_override is None + ): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = 
True + elif ( check_upstream_fa_availability(torch.get_default_dtype()) - and on_gfx9() and attn_backend_override is None ): attn_backend = _Backend.FLASH_ATTN diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index f71f49a1a31b..a792472fa8c5 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -14,8 +14,15 @@ import einops import torch +import vllm.envs as envs from vllm.utils.torch_utils import direct_register_custom_op +from vllm.platforms import current_platform + +if current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx9, on_gfx1x +else: + on_gfx9 = lambda *args, **kwargs: False def xformers_attn_seqlens_wrapper( @@ -64,7 +71,12 @@ def flash_attn_maxseqlen_wrapper( is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: - if is_rocm_aiter: + if ( + current_platform.is_rocm() + and on_gfx9 + and envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_MHA + ): from aiter import flash_attn_varlen_func else: if use_upstream_fa: diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 41cb7084057d..71c98d40aac8 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -416,6 +416,14 @@ def forward( q, k = torch.chunk(qk_rotated, 2, dim=0) if self.is_flash_attn_backend: + from importlib.util import find_spec + + if not isinstance(max_seqlen, torch.Tensor): + max_seqlen = torch.tensor( + max_seqlen, device=q.device, dtype=torch.int32 + ) + self.use_upstream_fa = find_spec("flash_attn") is not None + context_layer = vit_flash_attn_wrapper( q, k, @@ -423,7 +431,7 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.ROCM_AITER_FA, + self.attn_backend == _Backend.FLASH_ATTN, self.use_upstream_fa, ) elif self.attn_backend == _Backend.TORCH_SDPA: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index bd3bfcea21bb..2d13d9923799 
100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -222,7 +222,9 @@ def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": and find_spec("flash_attn") is not None and on_gfx1x() ): - logger.info("Using Vit FLASH_ATTN upstream V1 engine on gfx1x rdna3") + logger.info( + "Using ViT FlashAttention (upstream) on V1 engine (gfx1x / RDNA3)." + ) return _Backend.FLASH_ATTN logger.info("Using Vit TORCH_SDPA V1 engine") From 58f0be71dc6da30ba38d87720ac1a1ffea912419 Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 01:20:46 +0100 Subject: [PATCH 3/7] remove is_rocm_aiter Signed-off-by: JartX --- vllm/attention/ops/vit_attn_wrappers.py | 5 +---- vllm/model_executor/models/qwen2_5_vl.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index a792472fa8c5..1cd31d087fd9 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -68,7 +68,6 @@ def flash_attn_maxseqlen_wrapper( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: if ( @@ -108,7 +107,6 @@ def flash_attn_maxseqlen_wrapper_fake( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: b, s, h, d = q.shape @@ -129,9 +127,8 @@ def vit_flash_attn_wrapper( cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, batch_size: int, - is_rocm_aiter: bool, use_upstream_fa: bool, ) -> torch.Tensor: return torch.ops.vllm.flash_attn_maxseqlen_wrapper( - q, k, v, cu_seqlens, max_seqlen, batch_size, is_rocm_aiter, use_upstream_fa + q, k, v, cu_seqlens, max_seqlen, batch_size, use_upstream_fa ) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 71c98d40aac8..5ea24d63a098 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ 
b/vllm/model_executor/models/qwen2_5_vl.py @@ -431,7 +431,6 @@ def forward( cu_seqlens, max_seqlen, batch_size, - self.attn_backend == _Backend.FLASH_ATTN, self.use_upstream_fa, ) elif self.attn_backend == _Backend.TORCH_SDPA: From 90d3b7c58e81b89e9504585e37eff2c932dce0d2 Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 01:27:58 +0100 Subject: [PATCH 4/7] missing () on on_gfx9() Signed-off-by: JartX --- vllm/attention/ops/vit_attn_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py index 1cd31d087fd9..bfedd0829d00 100644 --- a/vllm/attention/ops/vit_attn_wrappers.py +++ b/vllm/attention/ops/vit_attn_wrappers.py @@ -72,7 +72,7 @@ def flash_attn_maxseqlen_wrapper( ) -> torch.Tensor: if ( current_platform.is_rocm() - and on_gfx9 + and on_gfx9() and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA ): From cf3682234230c37d40248cffbe995cd4021b7811 Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 02:23:31 +0100 Subject: [PATCH 5/7] default FLASH_ATTENTION_TRITON_AMD_ENABLE GPU_ARCHS if passed on build-args Signed-off-by: JartX --- docker/Dockerfile.rocm | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 499ff96a17b2..208eea777dee 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -34,6 +34,9 @@ RUN if [ -n "${FLASH_ATTENTION_TRITON_AMD_ENABLE}" ] && [ -n "${GPU_ARCHS}" ]; t echo "Skipping Flash Attention compilation (FLASH_ATTENTION_TRITON_AMD_ENABLE and/or GPU_ARCHS not set)." 
; \ fi +ENV FLASH_ATTENTION_TRITON_AMD_ENABLE=${FLASH_ATTENTION_TRITON_AMD_ENABLE} +ENV GPU_ARCHS=${GPU_ARCHS} + ARG COMMON_WORKDIR WORKDIR ${COMMON_WORKDIR} From 2af8555c3d3bd0a73744f603df68a3faaa9e4477 Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 13:53:30 +0100 Subject: [PATCH 6/7] readd code and 7900XTX device id Signed-off-by: JartX --- vllm/platforms/rocm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 2d13d9923799..6c729cc41ebb 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -72,10 +72,12 @@ "0x74a0": "AMD_Instinct_MI300A", "0x74a1": "AMD_Instinct_MI300X", "0x74b5": "AMD_Instinct_MI300X", # MI300X VF + "0x74a2": "AMD_Instinct_MI308X", "0x74a5": "AMD_Instinct_MI325X", "0x74b9": "AMD_Instinct_MI325X", # MI325X VF "0x74a9": "AMD_Instinct_MI300X_HF", "0x74bd": "AMD_Instinct_MI300X_HF", + "0x744c": "AMD_7900XTX_RDNA3" } # Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES` @@ -426,7 +428,7 @@ def verify_quantization(cls, quant: str) -> None: "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" " is not set, enabling VLLM_USE_TRITON_AWQ." ) - envs.VLLM_USE_TRITON_AWQ = True + os.environ["VLLM_USE_TRITON_AWQ"] = "1" @classmethod def get_punica_wrapper(cls) -> str: From c03438b174b7c6bb89a6f51b77b2936f38c9e8fd Mon Sep 17 00:00:00 2001 From: JartX Date: Thu, 30 Oct 2025 13:54:00 +0100 Subject: [PATCH 7/7] readd code and 7900XTX device id Signed-off-by: JartX --- docker/Dockerfile.rocm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 208eea777dee..bb157c5b6074 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -138,4 +138,4 @@ ENV SAFETENSORS_FAST_GPU=1 # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"]