Skip to content
Merged
9 changes: 5 additions & 4 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ ENV UV_LINK_MODE=copy
# Verify GCC version
RUN gcc --version

# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/00-cuda-compat.conf && ldconfig
# Workaround for triton/pytorch issues
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
Expand Down Expand Up @@ -423,6 +423,7 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
libsm6 \
libxext6 \
libgl1 \
git \
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
mkdir -p -m 0755 /etc/apt/keyrings ; \
Expand Down Expand Up @@ -473,8 +474,8 @@ ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy

# Ensure CUDA compatibility library is loaded
RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/00-cuda-compat.conf && ldconfig
# Workaround for triton/pytorch issues
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# ============================================================
# SLOW-CHANGING DEPENDENCIES BELOW
Expand Down
2 changes: 2 additions & 0 deletions requirements/cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ torchaudio==2.9.1
torchvision==0.24.1 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
# FlashInfer should be updated together with the Dockerfile
flashinfer-python==0.5.3
# FA4
flash-attn-cute @ git+https://github.com/Dao-AILab/flash-attention.git@2580b5a4882562640f3cfbffd2bb8d2de9268f9f#subdirectory=flash_attn/cute
12 changes: 11 additions & 1 deletion vllm/config/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,17 @@ def compute_hash(self) -> str:
def validate_backend_before(cls, value: Any) -> Any:
    """Enable parsing of the `backend` enum type from string."""
    # Accept a case-insensitive backend name and map it onto the enum.
    parsed = AttentionBackendEnum[value.upper()] if isinstance(value, str) else value

    # FA4 is a ViT/MM-encoder-only attention tag; it must not be used for
    # the KV-cache attention config, so reject it here regardless of whether
    # the caller passed a string or the enum member directly.
    if parsed == AttentionBackendEnum.FLASH_ATTN_CUTE:
        raise ValueError(
            "AttentionConfig.backend does not support FLASH_ATTN_CUTE "
            "(FA4 / flash_attn.cute). This is a ViT/MM-encoder-only attention "
            "tag. Use --mm-encoder-attn-backend / "
            "MultiModalConfig.mm_encoder_attn_backend instead."
        )

    return parsed

def _set_from_env_if_set(self, field_name: str, env_var_name: str) -> None:
Expand Down
41 changes: 40 additions & 1 deletion vllm/model_executor/layers/attention/mm_encoder_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_fa4_flash_attn_wrapper,
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
)
Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(
AttentionBackendEnum.ROCM_AITER_FA,
}

self.is_fa4_backend = self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE

self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
)
Expand Down Expand Up @@ -182,6 +185,40 @@ def _forward_fa(
output = output.reshape(bsz, q_len, -1)
return output

def _forward_fa4(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
) -> torch.Tensor:
    """FA4 (flash_attn.cute) attention for multimodal encoder (no KV cache)."""
    # Varlen metadata must be provided as a pair: either both or neither.
    assert (cu_seqlens is None) == (
        max_seqlen is None
    ), "cu_seqlens and max_seqlen should be both set or both None."

    batch, seq_q = query.shape[:2]
    seq_kv = key.shape[1]
    # Remember the incoming layout before any reshaping so we can restore it.
    was_flattened = query.dim() != 4

    query, key, value = self.maybe_reshape_qkv_to_4d(
        query, key, value, batch, seq_q, seq_kv
    )

    attn_out = vit_fa4_flash_attn_wrapper(
        q=query,
        k=key,
        v=value,
        batch_size=batch,
        scale=self.scale,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    # Collapse heads back to the original (batch, seq, hidden) layout.
    return attn_out.reshape(batch, seq_q, -1) if was_flattened else attn_out

def forward_native(
self,
query: torch.Tensor,
Expand All @@ -200,7 +237,9 @@ def forward_cuda(
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
if self.is_fa4_backend:
return self._forward_fa4(query, key, value, cu_seqlens, max_seqlen)
elif self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
Expand Down
5 changes: 4 additions & 1 deletion vllm/model_executor/layers/rotary_embedding/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,10 @@ def __init__(

self.apply_rotary_emb_flash_attn = None
if find_spec("flash_attn") is not None:
from flash_attn.ops.triton.rotary import apply_rotary
try:
from flash_attn.ops.triton.rotary import apply_rotary
except ImportError:
apply_rotary = None

self.apply_rotary_emb_flash_attn = apply_rotary

Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen2_5_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def __init__(

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
Expand Down Expand Up @@ -785,6 +786,7 @@ def compute_attn_mask_seqlen(
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.ROCM_AITER_FA,
}:
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/qwen3_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def __init__(

if self.attn_backend not in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASH_ATTN_CUTE,
AttentionBackendEnum.TORCH_SDPA,
AttentionBackendEnum.ROCM_AITER_FA,
}:
Expand Down Expand Up @@ -538,6 +539,7 @@ def compute_attn_mask_seqlen(
max_seqlen = torch.zeros([], device=cu_seqlens.device)
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
or self.attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE
or self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA
):
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
Expand Down
148 changes: 148 additions & 0 deletions vllm/model_executor/warmup/fa4_warmup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Warmup FA4 (flash_attn.cute) kernels for ViT/MM encoder attention.

We specifically warm up the FlashAttention Cute-DSL (FA4) compile cache by
running a few representative varlen attention calls that differ only in
sequence length. This helps avoid JIT compilation in the hot path.

This warmup is:
- Blackwell-only (compute capability 10.x)
- Opt-in (only when mm_encoder_attn_backend == FLASH_ATTN_CUTE)
- Scoped to Qwen3-VL / Qwen3-VL-MoE vision transformer workloads
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.registry import AttentionBackendEnum

if TYPE_CHECKING:
from vllm.v1.worker.gpu_worker import Worker

logger = init_logger(__name__)


def _get_default_qwen3_vit_warmup_seqlens(
max_positions: int | None = None,
) -> list[int]:
candidates = [
16**2, # 256
24**2, # 576
32**2, # 1024
48**2, # 2304
64**2, # 4096
96**2, # 9216
128**2, # 16384
192**2, # 36864
256**2, # 65536
]
if max_positions is None:
return candidates
return [s for s in candidates if s <= max_positions]


def should_fa4_vit_warmup(worker: Worker) -> bool:
    """Fast predicate used by `kernel_warmup` to gate FA4 warmup."""
    # CUDA-only, and restricted to Blackwell (compute capability 10.x).
    if not current_platform.is_cuda():
        return False

    capability = current_platform.get_device_capability()
    if capability is None or capability.major != 10:
        return False

    # Only warm up when FA4 was explicitly selected for the MM encoder.
    mm_cfg = getattr(worker.model_config, "multimodal_config", None)
    if mm_cfg is None:
        return False
    return mm_cfg.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN_CUTE


def fa4_vit_warmup(worker: Worker) -> None:
    """Warm up FA4 kernels for Qwen3-VL(-MoE) ViT attention."""

    # Config gating: only warm up when explicitly selected for mm encoder.
    if not should_fa4_vit_warmup(worker):
        return

    # Dependency gating.
    from vllm.v1.attention.backends.fa4_utils import (
        is_flash_attn_cute_available,
        supports_dtype,
        warn_if_unoptimized_head_size,
    )

    if not is_flash_attn_cute_available():
        logger.warning(
            "Skipping FA4 warmup: `flash_attn.cute.interface` is not available."
        )
        return

    vision_tower = getattr(worker.get_model(), "visual", None)
    if vision_tower is None:
        # Not a Qwen3-VL(-MoE) style model, or vision tower disabled.
        logger.warning("Skipping FA4 warmup: vision tower disabled or not found.")
        return

    # Derive head shape and dtype from the actual vision attention module.
    try:
        attn_module = vision_tower.blocks[0].attn  # Qwen2_5_VisionAttention
        head_size = int(attn_module.hidden_size_per_attention_head)
        num_heads = int(attn_module.num_attention_heads_per_partition)
        scale = float(attn_module.hidden_size_per_attention_head**-0.5)
        dtype = vision_tower.dtype
    except Exception:
        # If the model structure is unexpected, skip warmup.
        return

    if not supports_dtype(dtype):
        # If dtype is not supported, the FA4 backend should not have been selected.
        logger.warning_once(
            "Skipping FA4 warmup: dtype %s is not supported by flash_attn.cute.",
            dtype,
        )
        return

    warn_if_unoptimized_head_size(head_size)

    seqlens = tuple(_get_default_qwen3_vit_warmup_seqlens())

    logger.info_once(
        "Warming up FA4 (flash_attn.cute) ViT kernels for seqlens=%s "
        "(head_size=%d, num_heads=%d, dtype=%s).",
        seqlens,
        head_size,
        num_heads,
        dtype,
    )

    # Run a small number of representative calls that only vary seqlen.
    # Compilation key can be found under `flash_attn/cute/interface.py`.
    from vllm.v1.attention.backends.fa4_utils import flash_attn_varlen_func

    device = torch.device("cuda")
    with torch.inference_mode():
        for seqlen in seqlens:
            query = torch.empty(
                (seqlen, num_heads, head_size), device=device, dtype=dtype
            )
            key = torch.empty_like(query)
            val = torch.empty_like(query)
            cu_seqlens = torch.tensor([0, seqlen], device=device, dtype=torch.int32)

            # This call will populate FA4's internal compile cache (Cute-DSL).
            _ = flash_attn_varlen_func(
                query,
                key,
                val,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seqlen,
                max_seqlen_k=seqlen,
                softmax_scale=scale,
                causal=False,
            )
5 changes: 5 additions & 0 deletions vllm/model_executor/warmup/kernel_warmup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.warmup.deep_gemm_warmup import deep_gemm_warmup
from vllm.model_executor.warmup.fa4_warmup import fa4_vit_warmup, should_fa4_vit_warmup
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import is_deep_gemm_supported
from vllm.utils.flashinfer import has_flashinfer
Expand Down Expand Up @@ -72,6 +73,10 @@ def _is_flashinfer_backend(backend):
create_mixed_batch=True,
)

# FA4 (flash_attn.cute) warmup for ViT/MM encoder attention.
if should_fa4_vit_warmup(worker):
fa4_vit_warmup(worker)


def flashinfer_autotune(runner: "GPUModelRunner") -> None:
"""
Expand Down
35 changes: 35 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def get_attn_backend_cls(
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Return the attention backends supported for ViT/MM-encoder attention."""
    supported = (
        AttentionBackendEnum.TORCH_SDPA,
        AttentionBackendEnum.FLASH_ATTN_CUTE,
        AttentionBackendEnum.FLASH_ATTN,
    )
    return list(supported)

Expand All @@ -371,11 +372,45 @@ def get_vit_attn_backend(
dtype: torch.dtype,
backend: Optional["AttentionBackendEnum"] = None,
) -> "AttentionBackendEnum":
cc = cls.get_device_capability()

if backend is not None:
assert backend in cls.get_supported_vit_attn_backends(), (
f"Backend {backend} is not supported for vit attention. "
f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
)

# FA4 is Blackwell-only and opt-in (via --mm-encoder-attn-backend).
if backend == AttentionBackendEnum.FLASH_ATTN_CUTE:
if cc is None or cc.major != 10:
raise ValueError(
"FLASH_ATTN_CUTE (FA4 / flash_attn.cute) is only supported on "
"Blackwell GPUs (compute capability 10.x)."
)

from vllm.v1.attention.backends.fa4_utils import (
is_flash_attn_cute_available,
warn_if_unoptimized_head_size,
)
from vllm.v1.attention.backends.fa4_utils import (
supports_dtype as fa4_supports_dtype,
)

if not fa4_supports_dtype(dtype):
raise ValueError(
"FLASH_ATTN_CUTE (FA4 / flash_attn.cute) only supports "
"float16/bfloat16 for ViT attention."
)

if not is_flash_attn_cute_available():
raise ImportError(
"FLASH_ATTN_CUTE (FA4 / flash_attn.cute) selected, but "
"`flash_attn.cute.interface` is not available in this "
"environment."
)

warn_if_unoptimized_head_size(head_size)

logger.info_once(f"Using backend {backend} for vit attention")
return backend

Expand Down
Loading