From e87f29ca044d70f39b6beaec34838a0165459452 Mon Sep 17 00:00:00 2001 From: JartX Date: Mon, 20 Oct 2025 11:09:23 +0200 Subject: [PATCH 1/7] fixbug vit_flash_attn_on_rocm_no_gfx9 use rocm selection for it Signed-off-by: JartX --- vllm/attention/layer.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index a028be6ce7f8..085625dea2e8 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -46,6 +46,12 @@ SlidingWindowSpec, ) +if current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx9 +else: + on_gfx9 = lambda *args, **kwargs: False + + FP8_DTYPE = current_platform.fp8_dtype() logger = init_logger(__name__) USE_XFORMERS_OPS = None @@ -94,13 +100,21 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( attn_backend: _Backend, use_upstream_fa: bool ) -> tuple[_Backend, Callable]: - if ( - attn_backend != _Backend.FLASH_ATTN - and attn_backend != _Backend.ROCM_AITER_FA - and check_upstream_fa_availability(torch.get_default_dtype()) - ): - attn_backend = _Backend.FLASH_ATTN - use_upstream_fa = True + if current_platform.is_rocm(): + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): + attn_backend = _Backend.ROCM_AITER_FA + elif on_gfx9(): + attn_backend = _Backend.FLASH_ATTN + else: + return _Backend.TORCH_SDPA, None + else: + if ( + attn_backend != _Backend.FLASH_ATTN + and attn_backend != _Backend.ROCM_AITER_FA + and check_upstream_fa_availability(torch.get_default_dtype()) + ): + attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN: use_upstream_fa = True From c8735e798fc090f1ec00576be3041a19c8c05695 Mon Sep 17 00:00:00 2001 From: JartX Date: Mon, 20 Oct 2025 11:40:39 +0200 Subject: [PATCH 2/7] refactor precommit Signed-off-by: JartX --- vllm/attention/layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 085625dea2e8..4c9108128d16 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -99,7 +99,7 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( attn_backend: _Backend, use_upstream_fa: bool -) -> tuple[_Backend, Callable]: +) -> tuple[_Backend, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = _Backend.ROCM_AITER_FA @@ -548,6 +548,7 @@ def forward( value = torch.repeat_interleave(value, num_repeat, dim=2) if self.is_flash_attn_backend: + assert self._flash_attn_varlen_func is not None cu_seqlens_q = torch.arange( 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=query.device ) From 633792cdd1a38343a92a5d6e0d91ed54dc1e0d16 Mon Sep 17 00:00:00 2001 From: JartX Date: Wed, 22 Oct 2025 13:34:25 +0200 Subject: [PATCH 3/7] refactor maybe_get_vit_flash_attn_backend Signed-off-by: JartX --- vllm/attention/layer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 1202c6c958a2..44a4f160876e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -104,21 +104,21 @@ def maybe_get_vit_flash_attn_backend( if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = _Backend.ROCM_AITER_FA - elif on_gfx9(): + + elif check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9(): attn_backend = _Backend.FLASH_ATTN + use_upstream_fa = True else: return _Backend.TORCH_SDPA, None - else: - if ( - attn_backend != _Backend.FLASH_ATTN - and attn_backend != _Backend.ROCM_AITER_FA - and check_upstream_fa_availability(torch.get_default_dtype()) + + elif current_platform.is_cuda(): + if attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability( + torch.get_default_dtype() ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True - - if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN: - use_upstream_fa = True + else: + return _Backend.TORCH_SDPA, None if attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}: if attn_backend == _Backend.ROCM_AITER_FA: From 493f6837b1843c06294144724f7e7024239f096e Mon Sep 17 00:00:00 2001 From: JartX Date: Wed, 22 Oct 2025 17:32:21 +0200 Subject: [PATCH 4/7] flassh_attn on rocm if on_gfx9() and find_spec(flash_attn) Signed-off-by: JartX --- vllm/platforms/rocm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 7aab0b76aa06..933ada5f0bf6 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -202,12 +202,16 @@ class RocmPlatform(Platform): @classmethod def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": + from importlib.util import find_spec + from vllm.attention.backends.registry import _Backend if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): return _Backend.ROCM_AITER_FA - if on_gfx9(): + + if on_gfx9() and find_spec("flash_attn") is not None: return _Backend.FLASH_ATTN + return _Backend.TORCH_SDPA @classmethod From f8aa45f061243bfda39bff6d4d09788b16926ba1 Mon Sep 17 00:00:00 2001 From: JartX Date: Fri, 24 Oct 2025 09:31:38 +0200 Subject: [PATCH 5/7] pre-commit Signed-off-by: JartX --- vllm/attention/layer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2af2126abe43..5e672cf0714d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -99,14 +99,19 @@ def check_upstream_fa_availability(dtype: torch.dtype): def maybe_get_vit_flash_attn_backend( - attn_backend: _Backend, use_upstream_fa: bool, + attn_backend: _Backend, + use_upstream_fa: bool, attn_backend_override: _Backend | None = None, ) -> tuple[_Backend, Callable | None]: if current_platform.is_rocm(): if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): attn_backend = _Backend.ROCM_AITER_FA - elif check_upstream_fa_availability(torch.get_default_dtype()) and on_gfx9() and attn_backend_override is None: + elif ( + check_upstream_fa_availability(torch.get_default_dtype()) + and on_gfx9() + and attn_backend_override is None + ): attn_backend = _Backend.FLASH_ATTN use_upstream_fa = True else: From 0dabbf1c94c63094f2c57669fe766ad42fd7eed8 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 25 Oct 2025 18:06:36 +0200 Subject: [PATCH 6/7] qwen2vl and qwen2.5vl contiguous on rocm and torch.sdpa Co-authored-by: tjtanaa Signed-off-by: JartX --- vllm/model_executor/models/qwen2_5_vl.py | 5 +++++ vllm/model_executor/models/qwen2_vl.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 1b3ce3edd47b..07da3536afd3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -423,6 +423,11 @@ def forward( ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. + from vllm.platforms import current_platform + if current_platform.is_rocm(): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 94436fe009f1..bf610ecd3a43 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -453,6 +453,11 @@ def forward( ).contiguous() elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. + from vllm.platforms import current_platform + if current_platform.is_rocm(): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() outputs = [] for i in range(1, len(cu_seqlens)): start_idx = cu_seqlens[i - 1] From c42f2d19f48e66f7bb2f44b18dba4b381435f4e4 Mon Sep 17 00:00:00 2001 From: JartX Date: Sat, 25 Oct 2025 18:19:32 +0200 Subject: [PATCH 7/7] precommit Signed-off-by: JartX precommit Signed-off-by: JartX --- vllm/model_executor/models/qwen2_5_vl.py | 1 + vllm/model_executor/models/qwen2_vl.py | 1 + 2 files changed, 2 insertions(+) diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 07da3536afd3..3deca805c349 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -424,6 +424,7 @@ def forward( elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform + if current_platform.is_rocm(): q = q.contiguous() k = k.contiguous() diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index bf610ecd3a43..c3d265235e2a 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -454,6 +454,7 @@ def forward( elif self.attn_backend == _Backend.TORCH_SDPA: # Execute attention entry by entry for speed & less VRAM. from vllm.platforms import current_platform + if current_platform.is_rocm(): q = q.contiguous() k = k.contiguous()