Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/config/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class AttentionConfig:
"""Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend."""

use_system_flash_attn: bool = False
"""Force vllm to use the Flash Attention library installed in the system instead of
the one bundled with vllm. Currently only supports Flash Attention 4."""

use_prefill_decode_attention: bool = False
"""Use separate prefill and decode kernels for attention instead of
the unified triton kernel."""
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,10 @@
MLAAttentionImpl,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version,
should_use_system_flash_attn,
)
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_per_layer_parameters,
Expand Down Expand Up @@ -2202,6 +2205,7 @@ def _flash_attn_varlen_diff_headdims(
kwargs["return_attn_probs"] = return_softmax_lse
if envs.VLLM_BATCH_INVARIANT:
kwargs["num_splits"] = 1
kwargs["use_system_flash_attn"] = should_use_system_flash_attn()

attn_out = self.flash_attn_varlen_func(
q=q,
Expand Down
33 changes: 33 additions & 0 deletions vllm/v1/attention/backends/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ def get_flash_attn_version(
return None


def should_use_system_flash_attn() -> bool:
    """Return True if the system-installed flash-attn library should be used.

    Controlled by ``attention_config.use_system_flash_attn``. The system
    library is only supported together with Flash Attention 4; on any other
    detected version a one-time warning is emitted and False is returned.
    Non-CUDA platforms always return False.
    """
    if not current_platform.is_cuda():
        return False
    # Imported lazily to avoid a circular import with vllm.config.
    from vllm.config import get_current_vllm_config_or_none

    vllm_config = get_current_vllm_config_or_none()
    if vllm_config is None or not vllm_config.attention_config.use_system_flash_attn:
        return False

    # Cache the version lookup: it is needed both for the check and the warning.
    fa_version = get_flash_attn_version()
    if fa_version != 4:
        # Use lazy %-style args so warning_once dedupes on a constant message
        # and formatting is skipped when the log level filters it out.
        logger.warning_once(
            "System Flash Attention is only compatible with Flash Attention 4 "
            "but the detected version is %s. Disabling system flash attention.",
            fa_version,
            scope="local",
        )
        return False
    return True


def flash_attn_supports_fp8() -> bool:
return (
get_flash_attn_version() == 3
Expand Down Expand Up @@ -207,6 +228,18 @@ def is_flash_attn_varlen_func_available() -> bool:
Returns:
bool: True if a working flash_attn_varlen_func implementation is available.
"""
if should_use_system_flash_attn():
# Attempt to import flash_attn_varlen_func from system flash-attn
# Currently only supports Flash Attention 4
import importlib.util

if not importlib.util.find_spec("flash_attn"):
logger.warning_once(
"attention-config.use_system_flash_attn is set, but "
"failed to import system flash_attn. "
)
return False
return True
if current_platform.is_cuda() or current_platform.is_xpu():
# CUDA and XPU always have flash_attn_varlen_func available
return True
Expand Down
13 changes: 12 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
flash_attn_supports_fp8,
get_flash_attn_version,
is_flash_attn_varlen_func_available,
should_use_system_flash_attn,
)
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
Expand Down Expand Up @@ -171,7 +172,11 @@ def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:

@classmethod
def supports_head_size(cls, head_size: int) -> bool:
    """Return whether this backend can handle the given attention head size.

    Head sizes must be a multiple of 8. The upper bound is 512 when the
    system flash-attn library (FA4) is in use, and 256 otherwise.
    """
    if head_size % 8 != 0:
        return False
    max_head_size = 512 if should_use_system_flash_attn() else 256
    return head_size <= max_head_size

@classmethod
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
Expand Down Expand Up @@ -624,6 +629,7 @@ def __init__(
self.vllm_flash_attn_version,
scope="local",
)
self.use_system_flash_attn = should_use_system_flash_attn()
# Cache the batch invariant result for use in forward passes
self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT

Expand Down Expand Up @@ -790,6 +796,7 @@ def forward(
v_descale=v_descale,
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
use_system_flash_attn=self.use_system_flash_attn,
)
return output

Expand Down Expand Up @@ -1017,6 +1024,7 @@ def _forward_encoder_attention(
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
use_system_flash_attn=self.use_system_flash_attn,
)

return output
Expand Down Expand Up @@ -1125,6 +1133,7 @@ def cascade_attention(
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
s_aux: torch.Tensor | None = None,
use_system_flash_attn: bool = False,
) -> torch.Tensor:
assert alibi_slopes is None, "Cascade attention does not support ALiBi."
# TODO: Support sliding window.
Expand Down Expand Up @@ -1163,6 +1172,7 @@ def cascade_attention(
# enabling its effect during the final attention merge.
s_aux=s_aux,
num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
use_system_flash_attn=use_system_flash_attn,
)

descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
Expand All @@ -1188,6 +1198,7 @@ def cascade_attention(
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
use_system_flash_attn=use_system_flash_attn,
)

# Merge prefix and suffix outputs, and store the result in output.
Expand Down
3 changes: 3 additions & 0 deletions vllm/v1/attention/backends/mla/flashattn_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
should_use_system_flash_attn,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
Expand Down Expand Up @@ -301,6 +302,7 @@ def __init__(
raise NotImplementedError(
"FlashAttnMLA V1 with FP8 KV cache not yet supported"
)
self.use_system_flash_attn = should_use_system_flash_attn()

def forward_mqa(
self,
Expand Down Expand Up @@ -349,6 +351,7 @@ def forward_mqa(
cp_world_size=self.dcp_world_size,
cp_rank=self.dcp_rank,
cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
use_system_flash_attn=self.use_system_flash_attn,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In forward_mqa, fa_version is hardcoded to 3 (at line 359). However, the PR adds use_system_flash_attn=self.use_system_flash_attn. If use_system_flash_attn is True, flash_attn_varlen_func will raise a NotImplementedError because it explicitly forbids using system Flash Attention with fa_version=3. To support system Flash Attention 4 for MLA, the fa_version parameter must be updated to 4 when use_system_flash_attn is enabled. Additionally, the FA4 interface in flash_attn_interface.py must be updated to pass the q_v parameter required by MLA.

)

if self.need_to_return_lse_for_decode:
Expand Down
30 changes: 24 additions & 6 deletions vllm/vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def flash_attn_varlen_func(
cp_world_size=1,
cp_rank=0,
cp_tot_seqused_k=None,
use_system_flash_attn: bool = False,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
Expand Down Expand Up @@ -253,6 +254,7 @@ def flash_attn_varlen_func(
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
use_system_flash_attn: bool. Whether to use the Flash Attention library installed in the system.
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The
Expand Down Expand Up @@ -283,6 +285,10 @@ def flash_attn_varlen_func(
dummy_cu_seqlens_k = torch.empty_like(cu_seqlens_q)

if fa_version == 2:
if use_system_flash_attn:
raise NotImplementedError(
"Using system flash-attn is not supported for FA2. Please use FA4."
)
if (
scheduler_metadata is not None
and q_descale is not None
Expand Down Expand Up @@ -324,6 +330,10 @@ def flash_attn_varlen_func(
None,
)
elif fa_version == 3:
if use_system_flash_attn:
raise NotImplementedError(
"Using system flash-attn is not supported for FA3. Please use FA4."
)
assert alibi_slopes is None, "Alibi is not supported in FA3"
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q,
Expand Down Expand Up @@ -369,12 +379,20 @@ def flash_attn_varlen_func(
# FA4 on SM90 doesn't support paged KV; SM100+ does
from vllm.platforms import current_platform

if block_table is not None and current_platform.is_device_capability_family(90):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
if use_system_flash_attn:
try:
from flash_attn.cute.interface import _flash_attn_fwd
except ImportError:
use_system_flash_attn = False
if not use_system_flash_attn:
if block_table is not None and current_platform.is_device_capability_family(
90
):
raise NotImplementedError(
"FA4 with paged KV is not supported on SM90 (Hopper). "
"Use FA3 or upgrade to Blackwell (SM100+)."
)
from vllm.vllm_flash_attn.cute.interface import _flash_attn_fwd
Comment on lines +382 to +395
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Importing modules inside flash_attn_varlen_func introduces unnecessary overhead in the critical path of every attention operation. While Python caches imports, the repeated string lookups and try-except blocks can impact performance in high-throughput scenarios. These imports should be performed at the module level or cached. Furthermore, the FA4 implementation of _flash_attn_fwd (lines 397-415) is missing the q_v parameter, which is essential for MLA support. Without passing q_v, MLA models will produce incorrect results when using FA4.


out, softmax_lse = _flash_attn_fwd(
q,
Expand Down
Loading