Skip to content
15 changes: 4 additions & 11 deletions vllm/attention/utils/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,13 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:

# 2. override if passed by environment
if envs.VLLM_FLASH_ATTN_VERSION is not None:
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3]
assert envs.VLLM_FLASH_ATTN_VERSION in [2, 3, 4]
fa_version = envs.VLLM_FLASH_ATTN_VERSION

# 3. fallback for unsupported combinations
if device_capability.major == 10 and fa_version == 3:
if requires_alibi and fa_version in [3, 4]:
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform "
"defaulting to FA version 2."
)
fa_version = 2

if requires_alibi and fa_version == 3:
logger.warning_once(
"Cannot use FA version 3 with ALiBi, defaulting to FA version 2."
"Cannot use FA version %d with ALiBi, defaulting to FA version 2.",
fa_version,
)
fa_version = 2

Expand Down
6 changes: 6 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,12 @@ def get_attn_backend_cls(
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if envs.VLLM_FLASH_ATTN_VERSION == 4 and is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention 4 backend on V1 engine.")
return FLASH_ATTN_V1

if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype
):
Expand Down
88 changes: 63 additions & 25 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head dimensions this backend can handle.

    FA4 currently implements only a reduced set of head dims; all other
    FA versions support the full list.
    """
    # FIXME (zyongye): relax this once FA4 supports more head_dim values.
    fa4_head_sizes = [64, 96, 128]
    default_head_sizes = [32, 64, 96, 128, 160, 192, 224, 256]
    use_fa4 = envs.VLLM_FLASH_ATTN_VERSION == 4
    return fa4_head_sizes if use_fa4 else default_head_sizes

@staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]:
    """Return the KV-cache block sizes the selected FA kernel accepts.

    FA4's kernel is limited to a fixed block size of 128; other FA
    versions accept any multiple of 16.
    """
    # NOTE(review): this VLLM_FLASH_ATTN_VERSION == 4 branch is only
    # reachable if get_flash_attn_version() admits 4 in its env-var
    # assertion — verify the version-selection logic accepts 4 (and
    # falls back cleanly on unsupported hardware) so this path is not
    # dead code guarded by an AssertionError.
    if envs.VLLM_FLASH_ATTN_VERSION != 4:
        return [MultipleOf(16)]
    return [128]

@classmethod
Expand Down Expand Up @@ -486,8 +491,8 @@ def __init__(

self.sinks = sinks
if self.sinks is not None:
assert self.vllm_flash_attn_version == 3, (
"Sinks are only supported in FlashAttention 3"
assert self.vllm_flash_attn_version in [3, 4], (
"Sinks are only supported in FlashAttention 3 and 4"
)
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
Expand Down Expand Up @@ -606,6 +611,9 @@ def forward(
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

if self.dcp_world_size > 1:
assert get_flash_attn_version() != 4, (
"Distributed Context Parallel doesn't support FA4 yet"
)
self._forward_with_dcp(
query[:num_actual_tokens],
key[:num_actual_tokens],
Expand All @@ -620,29 +628,59 @@ def forward(
)
return output
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
if envs.VLLM_FLASH_ATTN_VERSION == 4:
from flash_attn.cute.interface import _flash_attn_fwd

window_size = (
None
if self.sliding_window[0] == -1
else self.sliding_window[0],
None
if self.sliding_window[1] == -1
else self.sliding_window[1],
)
output, lse, *rest = _flash_attn_fwd(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=None,
seqused_q=None,
seqused_k=seqused_k,
page_table=block_table,
softmax_scale=self.scale,
causal=attn_metadata.causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
learnable_sink=self.sinks,
softcap=self.logits_soft_cap,
return_lse=False,
out=output[:num_actual_tokens],
)
else:
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output

# Cascade attention (rare case).
Expand Down