diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 4039316c36c4..20502cbf0feb 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -4,6 +4,7 @@ from typing import Any from vllm.logger import init_logger +from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.platforms import current_platform logger = init_logger(__name__) @@ -111,6 +112,16 @@ def get_flash_attn_version( ) fa_version = 2 + # FA4 currently uses batch-shape-dependent scheduling + # heuristics on SM100+, which breaks batch invariance. + if vllm_is_batch_invariant() and fa_version == 4: + logger.warning_once( + "Cannot use FA version 4 with batch invariance, " + "defaulting to FA version 2.", + scope="local", + ) + fa_version = 2 + # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict # supported head dimensions. # See: https://github.com/Dao-AILab/flash-attention/issues/1959