2 changes: 1 addition & 1 deletion .buildkite/test-amd.yaml
@@ -318,7 +318,7 @@ steps:

 - label: V1 Test entrypoints # 35min
   timeout_in_minutes: 50
-  mirror_hardwares: [amdexperimental]
+  mirror_hardwares: [amdexperimental, amdproduction]
   agent_pool: mi325_1
   # grade: Blocking
   source_file_dependencies:
30 changes: 19 additions & 11 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -43,10 +43,7 @@
 )

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import (
-    check_upstream_fa_availability,
-    maybe_get_vit_flash_attn_backend,
-)
+from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
     vit_xformers_attn_wrapper,
@@ -318,6 +315,7 @@ def __init__(
         use_data_parallel: bool = False,
         attn_backend: _Backend = _Backend.TORCH_SDPA,
         use_upstream_fa: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -358,8 +356,14 @@ def __init__(
             maybe_get_vit_flash_attn_backend(
                 self.attn_backend,
                 self.use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
         )
+        # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
+        from vllm.platforms import current_platform
+
+        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
+            self.use_upstream_fa = True
         self.is_flash_attn_backend = self.attn_backend in {
             _Backend.FLASH_ATTN,
             _Backend.ROCM_AITER_FA,
@@ -484,6 +488,7 @@ def __init__(
         use_data_parallel: bool = False,
         attn_backend: _Backend = _Backend.TORCH_SDPA,
         use_upstream_fa: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -499,6 +504,7 @@
             use_data_parallel=use_data_parallel,
             attn_backend=attn_backend,
             use_upstream_fa=use_upstream_fa,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Qwen2_5_VisionMLP(
             dim,
@@ -698,13 +704,14 @@ def __init__(
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if (
-            self.attn_backend != _Backend.FLASH_ATTN
-            and self.attn_backend != _Backend.ROCM_AITER_FA
-            and check_upstream_fa_availability(torch.get_default_dtype())
-        ):
-            self.attn_backend = _Backend.FLASH_ATTN
-            use_upstream_fa = True
+
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa,
+                attn_backend_override=attn_backend_override,
+            )
+        )

         if self.attn_backend not in {
             _Backend.FLASH_ATTN,
@@ -730,6 +737,7 @@
                 use_data_parallel=use_data_parallel,
                 attn_backend=self.attn_backend,
                 use_upstream_fa=use_upstream_fa,
+                attn_backend_override=attn_backend_override,
             )
             for layer_idx in range(depth)
         ]
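
For context, here is a minimal sketch of the backend-selection flow this diff introduces, using only the calls visible above (maybe_get_vit_flash_attn_backend, _Backend, current_platform). The helper name resolve_vit_attn_backend and the standalone-function framing are illustrative, not part of the PR.

# Illustrative sketch only (not part of the diff): how attn_backend_override is
# threaded into the vision attention backend selection, per the calls shown above.
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import maybe_get_vit_flash_attn_backend
from vllm.platforms import current_platform


def resolve_vit_attn_backend(
    attn_backend: _Backend,
    use_upstream_fa: bool,
    attn_backend_override: _Backend | None = None,
):
    # The helper may swap the backend (honoring the override when given) and
    # returns the varlen flash-attention function to call at runtime.
    attn_backend, flash_attn_varlen_func = maybe_get_vit_flash_attn_backend(
        attn_backend,
        use_upstream_fa,
        attn_backend_override=attn_backend_override,
    )
    # On ROCm, a FLASH_ATTN backend means the upstream flash_attn package is used.
    if current_platform.is_rocm() and attn_backend == _Backend.FLASH_ATTN:
        use_upstream_fa = True
    return attn_backend, flash_attn_varlen_func, use_upstream_fa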