diff --git a/.github/workflows/self-scheduled-flash-attn-caller.yml b/.github/workflows/self-scheduled-flash-attn-caller.yml index ff990f2808b7..ccaf205456e1 100644 --- a/.github/workflows/self-scheduled-flash-attn-caller.yml +++ b/.github/workflows/self-scheduled-flash-attn-caller.yml @@ -56,5 +56,5 @@ jobs: runner_type: "a10" report_repo_id: hf-internal-testing/transformers_flash_attn_ci commit_sha: ${{ github.sha }} - pytest_marker: "flash_attn_test or flash_attn_3_test" + pytest_marker: "flash_attn_test or flash_attn_3_test or flash_attn_4_test or all_flash_attn_test" secrets: inherit diff --git a/conftest.py b/conftest.py index c194a058b1c4..a7d2341505a2 100644 --- a/conftest.py +++ b/conftest.py @@ -90,6 +90,10 @@ def pytest_configure(config): config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality") config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality") config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality") + config.addinivalue_line("markers", "flash_attn_4_test: mark test which tests flash attention 4 functionality") + config.addinivalue_line( + "markers", "all_flash_attn_test: mark test which tests all mainline flash attentions' functionality" + ) config.addinivalue_line("markers", "training_ci: mark test for training CI validation") config.addinivalue_line("markers", "tensor_parallel_ci: mark test for tensor parallel CI validation") diff --git a/pyproject.toml b/pyproject.toml index 710f64032aa9..f7e72facf021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,8 @@ line-ending = "auto" addopts = "--doctest-glob='**/*.md'" doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" markers = [ + "all_flash_attn_test: marks tests related to all mainline flash attentions at once (deselect with '-m \"not all_flash_attn_test\"')", + "flash_attn_4_test: marks tests related to flash attention 4 (deselect with '-m \"not flash_attn_4_test\"')", "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')", "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 69848473b7ae..922849e89547 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -732,6 +732,7 @@ class AttentionMaskInterface(GeneralInterface): "eager": eager_mask, "flash_attention_2": flash_attention_mask, "flash_attention_3": flash_attention_mask, + "flash_attention_4": flash_attention_mask, "flex_attention": flex_attention_mask, } diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index a7f3dc8286bd..72d432c06eea 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect import os from collections.abc import Callable @@ -23,11 +24,14 @@ from .utils import ( is_flash_attn_2_available, is_flash_attn_3_available, - is_flash_attn_greater_or_equal_2_10, + is_flash_attn_4_available, + is_torch_cuda_available, + is_torch_mlu_available, is_torch_npu_available, is_torch_xpu_available, logging, ) +from .utils.import_utils import PACKAGE_DISTRIBUTION_MAPPING, is_tracing logger = logging.get_logger(__name__) @@ -35,10 +39,8 @@ # TODO Deprecate when all models have the attention interface def flash_attn_supports_top_left_mask(): - if is_flash_attn_3_available(): + if is_flash_attn_2_available() or is_flash_attn_3_available() or is_flash_attn_4_available(): return False - if is_flash_attn_2_available(): - return not is_flash_attn_greater_or_equal_2_10() from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask @@ -48,13 +50,63 @@ def flash_attn_supports_top_left_mask(): # TODO Deprecate when all models have the attention interface def is_flash_attn_available(): return ( - is_flash_attn_3_available() + is_flash_attn_4_available() + or is_flash_attn_3_available() or is_flash_attn_2_available() or is_torch_npu_available() or is_torch_xpu_available() ) +# Mapping from flash attention implementations to their kernel fallback repositories +FLASH_ATTN_KERNEL_FALLBACK = { + "flash_attention_2": "kernels-community/flash-attn2", + "flash_attention_3": "kernels-community/vllm-flash-attn3", +} + + +# Meta information on each mainline FA compatibility: +# 1. The import structure and availability +# 2. Device support (with custom ones that use other workarounds, e.g. kernels) +# 3. Supported major cuda devices, e.g. Hopper, Blackwell. Mostly found in the newest FA versions +FLASH_ATTENTION_COMPATIBILITY_MATRIX = { + 2: { + "flash_attn_version": 2, + "general_availability_check": is_flash_attn_2_available, + "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn") is not None + and "flash_attn" in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"], + "supported_devices": ( + (is_torch_cuda_available, "cuda"), + (is_torch_mlu_available, "mlu"), + (is_torch_npu_available, "npu"), + (is_torch_xpu_available, "xpu"), + ), + "custom_supported_devices": ( + (is_torch_npu_available, "Detect using FlashAttention2 on Ascend NPU."), + ( + is_torch_xpu_available, + f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU.", + ), + ), + }, + 3: { + "flash_attn_version": 3, + "general_availability_check": is_flash_attn_3_available, + "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn_3") is not None, + "supported_devices": ((is_torch_cuda_available, "cuda"),), + "cuda_min_major_version": 8, # Ampere + }, + 4: { + "flash_attn_version": 4, + "general_availability_check": is_flash_attn_4_available, + "pkg_availability_check": lambda *args, **kwargs: importlib.util.find_spec("flash_attn") is not None + and importlib.util.find_spec("flash_attn.cute") is not None, + "supported_devices": ((is_torch_cuda_available, "cuda"),), + "cuda_min_major_version": 9, # Hopper + }, +} + + # `globals()` is not compatible with dynamo, hence we have do define them in global scope ourselves _loaded_implementation = None _flash_fn = None @@ -70,6 +122,8 @@ def is_flash_attn_available(): "dropout": "dropout_p", "sliding_window": "window_size", } +# alternative names within the different flash attention APIs, e.g. for attention sinks +_flash_api_alternative_names = {"s_aux": "learnable_sink"} def _lazy_imports( @@ -87,13 +141,16 @@ def _lazy_imports( """ is_fa2 = is_flash_attn_2_available() is_fa3 = is_flash_attn_3_available() + is_fa4 = is_flash_attn_4_available() pad_input, unpad_input = _pad_input, _unpad_input is_paged = implementation.startswith("paged|") implementation = implementation.split("|")[1] if is_paged else implementation - if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3): + if (implementation == "flash_attention_2" and is_fa2) or ( + implementation is None and is_fa2 and not is_fa3 and not is_fa4 + ): from flash_attn import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache from flash_attn.bert_padding import pad_input, unpad_input elif is_torch_npu_available(): @@ -103,8 +160,12 @@ def _lazy_imports( from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func from .integrations.npu_flash_attention import npu_flash_attn_with_kvcache as flash_attn_with_kvcache else: - if implementation == "flash_attention_3" or (implementation is None and is_fa3): + if implementation == "flash_attention_3" or (implementation is None and is_fa3 and not is_fa4): from flash_attn_interface import flash_attn_func, flash_attn_varlen_func, flash_attn_with_kvcache + elif implementation == "flash_attention_4" or (implementation is None and is_fa4): + from flash_attn.cute import flash_attn_func, flash_attn_varlen_func + + flash_attn_with_kvcache = None # not supported yet # Kernels fallback else: from .integrations.hub_kernels import load_and_register_attn_kernel @@ -157,6 +218,9 @@ def _lazy_define_process_function(flash_function): fa_param = _hf_api_to_flash_mapping.get(param, param) supports_mapping[fa_param] = fa_param in flash_parameters + if (fa_alternative_name := _flash_api_alternative_names.get(param, param)) != fa_param: + supports_mapping[fa_alternative_name] = fa_alternative_name in flash_parameters + return partial(_process_flash_attention_kwargs, supports_mapping=supports_mapping) @@ -229,7 +293,7 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None): seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() + max_seqlen_in_batch = seqlens_in_batch.max() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( @@ -278,9 +342,7 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T """ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - # NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile, - # this might cause a graph break - max_seqlen_in_batch = seqlens_in_batch.max().item() + max_seqlen_in_batch = seqlens_in_batch.max() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, @@ -401,12 +463,6 @@ def prepare_fa_kwargs_from_position_ids(position_ids): # We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing # for some models (e.g. qwen2-vl). max_length_q = cu_seq_lens_q.diff().max() - # NOTE: With torch compile, this will cause a graph break if you don't set - # `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call - # `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass. - # This is a limitation of flash attention API, as the function `flash_attn_varlen_func` - # requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. - max_length_q = max_length_q.item() max_length_k = max_length_q return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k) @@ -516,6 +572,8 @@ def _process_flash_attention_kwargs( softcap: float | None = None, deterministic: bool | None = None, s_aux: torch.Tensor | None = None, + max_seqlen_q: int | torch.IntTensor | None = None, + max_seqlen_k: int | torch.IntTensor | None = None, supports_mapping: dict[str, bool] | None = None, **kwargs, ): @@ -546,6 +604,10 @@ def _process_flash_attention_kwargs( Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled. s_aux (`torch.Tensor`, *optional*): Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head. + max_seqlen_q (`Union[int, torch.IntTensor]`, *optional*): + The maximum sequence length in the query tensor during a varlen forward. + max_seqlen_k (`Union[int, torch.IntTensor]`, *optional*): + The maximum sequence length in the key/value tensor during a varlen forward. Return: flash_kwargs (`dict`): A dict of kwargs that are requested and supported. @@ -573,9 +635,31 @@ def _process_flash_attention_kwargs( if supports_mapping["softcap"] and softcap is not None: flash_kwargs["softcap"] = softcap - # Only within kernel implementation atm - if supports_mapping["s_aux"] and s_aux is not None: - flash_kwargs["s_aux"] = s_aux + if ((legacy_sink_param := supports_mapping["s_aux"]) or supports_mapping["learnable_sink"]) and s_aux is not None: + if legacy_sink_param: + flash_kwargs["s_aux"] = s_aux # e.g. FA3 (vllm) + else: + flash_kwargs["learnable_sink"] = s_aux # FA4 + + # There is a limitation of the flash attention API, as the function `flash_attn_varlen_func` + # may require `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`. + # + # You can either set + # - Env: `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` + # - Before compiling: `torch._dynamo.config.capture_scalar_outputs = True` + # to allow torch compile to handle scalar outputs in those cases. + same_max_seqlen = max_seqlen_q is max_seqlen_k # to avoid 2x device syncs + if supports_mapping["max_seqlen_q"] and max_seqlen_q is not None: + if not isinstance(max_seqlen_q, int) and is_tracing(max_seqlen_q): + max_seqlen_q = max_seqlen_q.item() + flash_kwargs["max_seqlen_q"] = max_seqlen_q + + if supports_mapping["max_seqlen_k"] and max_seqlen_k is not None: + if same_max_seqlen and flash_kwargs["max_seqlen_q"] is not None: + max_seqlen_k = flash_kwargs["max_seqlen_q"] + elif not isinstance(max_seqlen_k, int) and is_tracing(max_seqlen_k): + max_seqlen_k = max_seqlen_k.item() + flash_kwargs["max_seqlen_k"] = max_seqlen_k return flash_kwargs @@ -631,7 +715,8 @@ def _flash_attention_forward( ) # Extract the flash attention kwargs that have been requested (and are supported by the implementation) - flash_kwargs = process_flash_kwargs_fn( + flash_kwargs = partial( + process_flash_kwargs_fn, query_length=query_length, key_length=key_states.size(1), is_causal=is_causal, @@ -673,9 +758,7 @@ def _flash_attention_forward( v, cu_seqlens_q=cu_seq_lens_q, cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, + **flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k), ) if isinstance(out_unpad, tuple): out_unpad = out_unpad[0] @@ -704,9 +787,7 @@ def _flash_attention_forward( v, cu_seqlens_q=cu_seq_lens_q, cu_seqlens_k=cu_seq_lens_k, - max_seqlen_q=max_length_q, - max_seqlen_k=max_length_k, - **flash_kwargs, + **flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k), ) if isinstance(out, tuple): out = out[0] @@ -715,7 +796,7 @@ def _flash_attention_forward( # No padding else: - out = flash_fn(query_states, key_states, value_states, **flash_kwargs) + out = flash_fn(query_states, key_states, value_states, **flash_kwargs()) if isinstance(out, tuple): out = out[0] diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 5fd8ff53b2a9..a4928a331c11 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -15,7 +15,6 @@ import collections import copy import functools -import importlib.metadata import inspect import json import os @@ -84,7 +83,12 @@ verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING -from .modeling_flash_attention_utils import lazy_import_flash_attention, lazy_import_paged_flash_attention +from .modeling_flash_attention_utils import ( + FLASH_ATTENTION_COMPATIBILITY_MATRIX, + FLASH_ATTN_KERNEL_FALLBACK, + lazy_import_flash_attention, + lazy_import_paged_flash_attention, +) from .modeling_rope_utils import ROPE_INIT_FUNCTIONS from .monkey_patching import apply_patches, patch_output_recorders from .pytorch_utils import id_tensor_storage @@ -109,11 +113,8 @@ is_accelerate_available, is_bitsandbytes_available, is_env_variable_true, - is_flash_attn_2_available, - is_flash_attn_3_available, is_kernels_available, is_torch_flex_attn_available, - is_torch_mlu_available, is_torch_npu_available, is_torch_xpu_available, logging, @@ -121,8 +122,10 @@ from .utils.generic import GeneralInterface, is_flash_attention_requested from .utils.hub import DownloadKwargs, create_and_tag_model_card, get_checkpoint_shard_files from .utils.import_utils import ( + is_flash_attn_greater_or_equal, is_huggingface_hub_greater_or_equal, is_sagemaker_mp_enabled, + is_torch_cuda_available, is_tracing, ) from .utils.loading_report import LoadStateDictInfo, log_state_dict_report @@ -154,12 +157,6 @@ _is_quantized = False _is_ds_init_called = False -# Mapping from flash attention implementations to their kernel fallback repositories -FLASH_ATTN_KERNEL_FALLBACK = { - "flash_attention_2": "kernels-community/flash-attn2", - "flash_attention_3": "kernels-community/vllm-flash-attn3", -} - @dataclass(frozen=True) class LoadStateDictConfig: @@ -1553,176 +1550,145 @@ def can_generate(cls) -> bool: # Otherwise, can't generate return False - def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool: + def _flash_attn_import_error( + self, + flash_attn_version: int, + general_availability_check: Callable, + pkg_availability_check: Callable, + supported_devices: tuple[tuple[Callable, str]], + custom_supported_devices: tuple[tuple[Callable, str]] = (), + cuda_min_major_version: int | None = None, + ): """ - Check the availability of Flash Attention 2 for a given model. + Checks whether the specified Flash Attention version is supported and if not, searches for the specific reason + on why it failed - package import and/or device incompatibility issues. Args: - is_init_check (`bool`, *optional*): - Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are - fully instantiated. This is needed as we also check the devices of the weights, which are only available - later after __init__. This allows to raise proper exceptions early before instantiating the full models - if we know that the model does not support the requested attention. + flash_attn_version (`int`): + The requested version of Flash Attention. + general_availability_check (`Callable`): + Checks whether our `is_available` function detects the specific FA version. Failing reasons + are then checked for one-by-one. + pkg_availability_check (`Callable`): + Checks whether the package could theoretically be detected in the environment by the init structures. + This is not a sure-fire check as device compatibility with FA is just as important. + supported_devices (`tuple[tuple[Callable, str]]`): + Essentially a list (for mutable kwargs reasons a tuple) of the supported devices in the format of + `(device_availability_check, device_name)`, i.e. a pair of the associated device's name and whether + it is available in the environment. + custom_supported_devices (`tuple[tuple[Callable, str]]`, *optional*, defaults to `()`): + Essentially a list (for mutable kwargs reasons a tuple) of the custom supported devices in the format of + `(device_availability_check, info_message)`. These custom devices have custom logic outside the torch + ecosystem either via kernels or other packages and hence have early checks for availability. + cuda_min_major_version (`int`, *optional*): + The minimum major cuda version supported for this version of Flash Attention. This is mostly + affecting more recent versions which are more specialized to the features of new hardware. """ - dtype = self.config.dtype - - # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases - if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)): - raise ValueError( - f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where" - f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" - " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" - ) + # Certain devices have custom workarounds e.g. with their own package distribution (NPU) or via kernels (XPU) + for device_availability_check, info_message in custom_supported_devices: + if device_availability_check(): + logger.info(info_message) + return - if not is_flash_attn_2_available(): - preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" - install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + if not general_availability_check(): + preface = f"FlashAttention{flash_attn_version} has been toggled on, but it cannot be used due to the following error:" - # package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored. - if is_torch_npu_available(): - logger.info("Detect using FlashAttention2 on Ascend NPU.") - return True - - if is_torch_xpu_available(): - logger.info( - f"Detect using FlashAttention2 (via kernel `{FLASH_ATTN_KERNEL_FALLBACK['flash_attention_2']}`) on XPU." + # Can the package be seen in the import structure + if not pkg_availability_check(): + raise ImportError( + f"{preface} the package for FlashAttention{flash_attn_version} doesn't seem to be installed." ) - return True - - if importlib.util.find_spec("flash_attn") is None: - raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + # Minimum version (FA2 only) + elif flash_attn_version == 2 and not is_flash_attn_greater_or_equal("2.3.3"): + raise ImportError(f"{preface} FlashAttention{flash_attn_version} requires at least version `2.3.3`.") else: - # Check FA2 installed version compatibility - flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) - if torch.version.cuda: - if flash_attention_version < version.parse("2.1.0"): - raise ImportError( - f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" - ) - elif not torch.cuda.is_available(): - raise ValueError( - f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device." - ) - else: - raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") - elif torch.version.hip: - if flash_attention_version < version.parse("2.0.4"): + # Supported devices availability + device_availability_checks, device_names = zip(*supported_devices) + if not any(device_availability_check() for device_availability_check in device_availability_checks): + raise ImportError( + f"{preface} FlashAttention{flash_attn_version} is not available on CPU. Please make sure you are on any of the supported devices: {device_names}." + ) + # Cuda major versions (more recent FA versions are specialized to newer cuda devices) + elif cuda_min_major_version is not None and is_torch_cuda_available(): + major, _ = torch.cuda.get_device_capability() + if major < cuda_min_major_version: raise ImportError( - f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}" + f"{preface} FlashAttention{flash_attn_version} requires compute capability >= {cuda_min_major_version}, but found {torch.cuda.get_device_capability()} with compute capability {major}.x" ) - else: - raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") - - if dtype is None: - logger.warning_once( - "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour" - ) - elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]: - logger.warning_once( - "Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but" - f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," - ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`' - ) - - # With the early check, the parameters are not yet initialized correctly - if not is_init_check: - param_devices = list({param.device for param in self.parameters()}) - if len(param_devices) == 1 and param_devices[0].type == "cpu": - if torch.cuda.is_available(): - logger.warning_once( - "You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU" - " after initializing it on CPU with `model.to('cuda')`." - ) - elif is_torch_mlu_available(): - logger.warning_once( - "You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU" - " after initializing it on CPU with `model.to('mlu')`." - ) - else: - raise ValueError( - "You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. " - "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " - "or initialising the model on CPU and then moving it to GPU." - ) - # If no error raise by this point, we can return `True` - return True - - def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool: + def _flash_attn_can_dispatch(self, flash_attn_version: int, is_init_check: bool = False) -> bool: """ - Check the availability of Flash Attention 3 for a given model. + Check the availability of Flash Attention for a given model. Args: + flash_attn_version (`int`): + The requested version of Flash Attention. is_init_check (`bool`, *optional*): Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are fully instantiated. This is needed as we also check the devices of the weights, which are only available later after __init__. This allows to raise proper exceptions early before instantiating the full models if we know that the model does not support the requested attention. """ - dtype = self.config.dtype - if not self._supports_flash_attn: raise ValueError( - f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where" + f"{self.__class__.__name__} does not support Flash Attention {flash_attn_version} yet. Please request to add support where" f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new" " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" ) - if not is_flash_attn_3_available(): - preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + if flash_attn_version not in [2, 3, 4]: + raise ValueError(f"Requested Flash Attention {flash_attn_version} which is not supported.") - if importlib.util.find_spec("flash_attn_3") is None: - raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.") + # Check if we can even use the FA version based on the env of the user + self._flash_attn_import_error(**FLASH_ATTENTION_COMPATIBILITY_MATRIX[flash_attn_version]) - if torch.cuda.is_available(): - major, _ = torch.cuda.get_device_capability() - if major < 9: - raise ValueError( - f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0." - ) - else: - raise ImportError(f"{preface} Flash Attention 3 is not available.") - else: - raise ValueError( - f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device." + # Check for attention dropout, which is incompatible with newer FA versions + # (many should not really care about dropout as it is not super effective, hence warning for now) + if flash_attn_version > 2: + if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0: + logger.warning_once( + f"You are attempting to use Flash Attention {flash_attn_version} with dropout. " + "This might lead to unexpected behaviour as this is not supported on recent versions of Flash Attention." ) + # People often move dtypes after init so we only warn in those cases + dtype = self.config.dtype if dtype is None: logger.warning_once( - "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour" + f"You are attempting to use Flash Attention {flash_attn_version} without specifying a dtype. This might lead to unexpected behaviour" ) elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]: logger.warning_once( - "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but" + f"Flash Attention {flash_attn_version} only supports torch.float16 and torch.bfloat16 dtypes, but" f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," - ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", dtype=torch.float16)`' - ) - - if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False): - raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.") - - # Check for attention dropout, which is incompatible with FA3 - if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0: - raise ValueError( - f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3." + f' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_{flash_attn_version}", dtype=torch.float16)`' ) # With the early check, the parameters are not yet initialized correctly if not is_init_check: param_devices = list({param.device for param in self.parameters()}) if len(param_devices) == 1 and param_devices[0].type == "cpu": - if torch.cuda.is_available(): - logger.warning_once( - "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU" - " after initializing it on CPU with `model.to('cuda')`." - ) - else: + found_device = False + for device_availability_check, device_name in FLASH_ATTENTION_COMPATIBILITY_MATRIX[flash_attn_version][ + "supported_devices" + ]: + if device_availability_check(): + found_device = True + logger.warning_once( + f"You are attempting to use Flash Attention {flash_attn_version} with a model not initialized on GPU. Please make sure to have " + "access to a GPU and either initialise the model on a GPU by passing a device_map or initialising the model on CPU and then " + f"moving it to GPU, e.g. with `model.to('{device_name}')`." + ) + break + + if not found_device: raise ValueError( - "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. " + f"You are attempting to use Flash Attention {flash_attn_version} with a model not initialized on GPU and with no GPU available. " "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " "or initialising the model on CPU and then moving it to GPU." ) + # If no error raise by this point, we can return `True` return True def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool: @@ -1823,8 +1789,8 @@ def _check_and_adjust_attn_implementation( compatible_flash_implementations = getattr(self, "_compatible_flash_implementations", None) if ( - compatible_flash_implementations - and is_flash_attention_requested(requested_attention_implementation=base_implementation) + is_flash_attention_requested(requested_attention_implementation=base_implementation) + and compatible_flash_implementations is not None and base_implementation not in compatible_flash_implementations ): default_flash_implementation = ( @@ -1838,18 +1804,26 @@ def _check_and_adjust_attn_implementation( attn_implementation = default_flash_implementation applicable_attn_implementation = attn_implementation - is_paged = attn_implementation is not None and attn_implementation.startswith("paged|") - # If FA not installed, do not fail but use kernels instead - requested_original_flash_attn = attn_implementation is not None and ( - attn_implementation.removeprefix("paged|") == "flash_attention_2" - or attn_implementation.removeprefix("paged|") == "flash_attention_3" - ) + requested_original_flash_attn = False + if is_flash_attention_requested(requested_attention_implementation=attn_implementation): + # If FA not installed, do not fail but use kernels instead if possible + for fa_version in FLASH_ATTENTION_COMPATIBILITY_MATRIX.keys(): + # No kernels support for FA4 for now + if fa_version == 4: + continue + + # Check whether we have an original FA requested but not available in the env + if requested_original_flash_attn := ( + attn_implementation.removeprefix("paged|") == f"flash_attention_{fa_version}" + and not FLASH_ATTENTION_COMPATIBILITY_MATRIX[fa_version]["general_availability_check"]() + ): + break + if ( - requested_original_flash_attn - and self._supports_flash_attn - and not (is_flash_attn_2_available() or is_flash_attn_3_available()) + self._supports_flash_attn + and requested_original_flash_attn and is_kernels_available() and not is_torch_npu_available() ): @@ -1882,10 +1856,8 @@ def _check_and_adjust_attn_implementation( except Exception as e: # raise the proper exception for requested flash attention if requested_original_flash_attn: - if attn_implementation.endswith("2"): - self._flash_attn_2_can_dispatch() - else: - self._flash_attn_3_can_dispatch() + fa_version = int(attn_implementation[-1]) # "flash_attention_(2|3|...)" + self._flash_attn_can_dispatch(flash_attn_version=fa_version, is_init_check=is_init_check) # error properly out if a kernel was specifically requested raise e @@ -1922,7 +1894,10 @@ def get_correct_attn_implementation(self, requested_attention: str | None, is_in ) # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False): - message += ', `"attn_implementation=flash_attention_3"`, `"attn_implementation=flash_attention_2"`, `"attn_implementation=paged|flash_attention_2"`' + message += ", " + for fa_version in FLASH_ATTENTION_COMPATIBILITY_MATRIX.keys(): + message += f'`"attn_implementation=flash_attention_{fa_version}"`, `"attn_implementation=paged|flash_attention_{fa_version}"`, ' + message = message[:-2] # remove trailing comma if self._supports_sdpa: message += ', `"attn_implementation=sdpa"`, `"attn_implementation=paged|sdpa"`' if self._supports_flex_attn: @@ -1930,10 +1905,11 @@ def get_correct_attn_implementation(self, requested_attention: str | None, is_in raise ValueError(message + ".") # Perform relevant checks - if "flash_attention_2" in applicable_attention: - self._flash_attn_2_can_dispatch(is_init_check) - elif "flash_attention_3" in applicable_attention: - self._flash_attn_3_can_dispatch(is_init_check) + if is_flash_attention_requested(requested_attention_implementation=applicable_attention) and ( + fa_matched := re.search(r"^flash_attention_(\d)$", applicable_attention) + ): + fa_version = int(fa_matched.group(1)) # last digit + self._flash_attn_can_dispatch(flash_attn_version=fa_version, is_init_check=is_init_check) elif "flex_attention" in applicable_attention: self._flex_attn_can_dispatch(is_init_check) elif "sdpa" in applicable_attention: @@ -3791,7 +3767,13 @@ def from_pretrained( attn_implementation (`str`, *optional*): - The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. + The attention implementation to use in the model (if relevant). Can be any of + - `"eager"` (manual implementation of the attention) + - `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)) + - `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)) + - `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)) + - `"flash_attention_4"` (using [Dao-AILab/flash-attention/flash_attn/cute](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/cute)). + By default, if available, SDPA will be used. The default is otherwise the manual `"eager"` implementation. Accept HF kernel references in the form: /[@][:] @@ -4847,10 +4829,12 @@ class AttentionInterface(GeneralInterface): # Class instance object, so that a call to `register` can be reflected into all other files correctly, even if # a new instance is created (in order to locally override a given function) _global_mapping = { + "flash_attention_4": flash_attention_forward, "flash_attention_3": flash_attention_forward, "flash_attention_2": flash_attention_forward, "flex_attention": flex_attention_forward, "sdpa": sdpa_attention_forward, + "paged|flash_attention_4": paged_attention_forward, "paged|flash_attention_3": paged_attention_forward, "paged|flash_attention_2": paged_attention_forward, "paged|sdpa": sdpa_attention_paged_forward, diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 9157c32f1626..9f87e09a155a 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -406,7 +406,7 @@ class GptOssPreTrainedModel(PreTrainedModel): "attentions": GptOssAttention, } _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] - _compatible_flash_implementations = ["kernels-community/vllm-flash-attn3"] + _compatible_flash_implementations = ["kernels-community/vllm-flash-attn3", "flash_attention_4"] @torch.no_grad() def _init_weights(self, module): diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index b2e6374dd1f2..247bf6f14983 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -315,7 +315,7 @@ def forward( class GptOssPreTrainedModel(LlamaPreTrainedModel): _keep_in_fp32_modules = ["post_attention_layernorm", "input_layernorm", "norm"] _supports_sdpa = False - _compatible_flash_implementations = ["kernels-community/vllm-flash-attn3"] + _compatible_flash_implementations = ["kernels-community/vllm-flash-attn3", "flash_attention_4"] _can_record_outputs = { "router_logits": OutputRecorder(GptOssTopKRouter, index=0), "hidden_states": GptOssDecoderLayer, diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index d48d1f5fe525..f9e371318872 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -93,6 +93,7 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_flute_available, is_fouroversix_available, is_fp_quant_available, @@ -671,6 +672,37 @@ def require_flash_attn_3(test_case): return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) +def require_flash_attn_4(test_case): + """ + Decorator marking a test that requires Flash Attention 4. + + These tests are skipped when Flash Attention 4 isn't installed. + """ + return unittest.skipUnless(is_flash_attn_4_available(), "test requires Flash Attention 4")(test_case) + + +def require_all_flash_attn(test_case): + flash_attn_available = is_flash_attn_2_available() + kernels_available = is_kernels_available() + try: + from kernels import get_kernel + + get_kernel(FLASH_ATTN_KERNEL_FALLBACK["flash_attention_2"]) + except Exception as _: + kernels_available = False + + return unittest.skipUnless( + all( + ( + flash_attn_available | kernels_available, + is_flash_attn_3_available(), + is_flash_attn_4_available(), + ) + ), + "test requires all mainline Flash Attention packages", + )(test_case) + + def require_peft(test_case): """ Decorator marking a test that requires PEFT. diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index cd19f728541d..699d28c7ff04 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -132,8 +132,8 @@ is_fbgemm_gpu_available, is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, is_flute_available, is_fouroversix_available, is_fp_quant_available, diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 447cac2e9cf6..379b23b58de6 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -240,6 +240,10 @@ def is_flash_attention_requested( else: checked_attention_implementation = requested_attention_implementation + # theoretically can happen, equivalent to default implementation (sdpa/eager) + if checked_attention_implementation is None: + return False + # If a specific version is requested, look for a pattern of type "flash...{version}" if version is not None: return re.match(r".*flash.*" + str(version), checked_attention_implementation) is not None diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 31d437cb206c..d2113c013ceb 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -916,22 +916,15 @@ def is_bitsandbytes_available(min_version: str = BITSANDBYTES_MIN_VERSION) -> bo @lru_cache def is_flash_attn_2_available() -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) + # FA4 is also distributed under "flash_attn", hence we need to check the naming here + is_available = is_available and "flash_attn" in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + if not is_available or not (is_torch_cuda_available() or is_torch_mlu_available()): return False - import torch - + # Only allow versions >= 2.3.3 to avoid very old legacy workarounds that are now 2+ years old try: - torch_version = getattr(torch, "version") - if torch_version.cuda: - return version.parse(flash_attn_version) >= version.parse("2.1.0") - elif torch_version.hip: - # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention - return version.parse(flash_attn_version) >= version.parse("2.0.4") - elif is_torch_mlu_available(): - return version.parse(flash_attn_version) >= version.parse("2.3.3") - else: - return False + return version.parse(flash_attn_version) >= version.parse("2.3.3") except packaging.version.InvalidVersion: return False @@ -942,14 +935,19 @@ def is_flash_attn_3_available() -> bool: @lru_cache -def is_flash_attn_greater_or_equal_2_10() -> bool: - _, flash_attn_version = _is_package_available("flash_attn", return_version=True) - return is_flash_attn_2_available() and version.parse(flash_attn_version) >= version.parse("2.1.0") +def is_flash_attn_4_available() -> bool: + # Check first under base flash then cute + if not (is_torch_cuda_available() and _is_package_available("flash_attn")[0]): + return False + return _is_package_available("flash_attn.cute")[0] @lru_cache def is_flash_attn_greater_or_equal(library_version: str) -> bool: is_available, flash_attn_version = _is_package_available("flash_attn", return_version=True) + # FA4 is also distributed under "flash_attn", hence we need to check the naming here + is_available = is_available and "flash_attn" in PACKAGE_DISTRIBUTION_MAPPING["flash_attn"] + if not is_available: return False try: diff --git a/tests/generation/test_flash_attention_parity.py b/tests/generation/test_flash_attention_parity.py index 969cdddcd38d..7b71ed02dc36 100644 --- a/tests/generation/test_flash_attention_parity.py +++ b/tests/generation/test_flash_attention_parity.py @@ -1,4 +1,4 @@ -# Copyright 2025 Eduard Durech and SGLang team. +# Copyright 2025 Eduard Durech, SGLang, and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,13 @@ # RUN_SLOW=1 pytest -s tests/generation/test_flash_attention_parity.py import unittest +from collections import defaultdict import pytest import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_flash_attn, require_flash_attn_3, require_torch_gpu, slow +from transformers.testing_utils import require_all_flash_attn, require_torch_gpu, slow class FlashAttentionParityTest(unittest.TestCase): @@ -74,12 +75,13 @@ def _benchmark_generation(self, model, inputs, n_warmup=3, n_runs=5): return start_time.elapsed_time(end_time) / n_runs - @pytest.mark.flash_attn_3_test - @require_torch_gpu - @require_flash_attn - @require_flash_attn_3 @slow - def test_flash_attention_2_3_parity(self): + @require_torch_gpu + @require_all_flash_attn + @pytest.mark.all_flash_attn_test + def test_flash_attention_parity(self): + flash_attn_versions = [2, 3, 4] + model_id = "meta-llama/Llama-3.2-1B-Instruct" prompt = ["The ETH AI Center is", "What is life?"] @@ -87,55 +89,88 @@ def test_flash_attention_2_3_parity(self): model = AutoModelForCausalLM.from_pretrained( model_id, dtype=torch.bfloat16, + device_map="auto", attn_implementation="flash_attention_2", - ).to("cuda") + ) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token_id = tokenizer.eos_token_id # 2. Generate with both models inputs = tokenizer(prompt, padding=True, padding_side="left", return_tensors="pt").to("cuda") + logits = {} + logprobs = {} + outputs = defaultdict(list) with torch.no_grad(): - output_2 = model.generate( - **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True - ) - model.set_attn_implementation("flash_attention_3") - output_3 = model.generate( - **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True - ) + + def generate(model, version, outputs, logits, logprobs): + model.set_attn_implementation(f"flash_attention_{version}") + output = model.generate( + **inputs, max_new_tokens=20, do_sample=False, output_scores=True, return_dict_in_generate=True + ) + logit = torch.stack(output.scores) + logprob = torch.nn.functional.log_softmax(logit, dim=-1) + + for i in range(len(prompt)): + outputs[version].append(tokenizer.decode(output.sequences[i], skip_special_tokens=True)) + logits[version] = logit + logprobs[version] = logprob + + for version in flash_attn_versions: + generate(model, version, outputs, logits, logprobs) # 3. Correctness check # 3a. Logits - logits_2 = torch.stack(output_2.scores) - logits_3 = torch.stack(output_3.scores) - torch.testing.assert_close(logits_2, logits_3, atol=1e-3, rtol=1e-3) - logprobs_2 = torch.nn.functional.log_softmax(logits_2, dim=-1) - logprobs_3 = torch.nn.functional.log_softmax(logits_3, dim=-1) - max_logprob_diff = torch.max(torch.abs(logprobs_2 - logprobs_3)).item() + # FA2 as base to compare against + logits_1 = logits[2] + logprobs_1 = logprobs[2] + max_logprob_diffs = [] + for version in range(1, len(flash_attn_versions)): + logits_x = logits[flash_attn_versions[version]] + logprobs_x = logprobs[flash_attn_versions[version]] + max_logprob_diffs.append(torch.max(torch.abs(logprobs_1 - logprobs_x)).item()) + + # Only 80% need to pass the tolerance (big model with several steps) + atol, fraction = 4e-2, 0.8 + logits_ok = (torch.abs(logits_1 - logits_x) <= atol).float().mean().item() + assert logits_ok >= fraction, ( + f"FA{flash_attn_versions[version]} logits pass fraction {logits_ok:.6f} < {fraction:.6f}" + ) # 3b. Generated text - text_2s, text_3s = [], [] - for i in range(len(prompt)): - text_2s.append(tokenizer.decode(output_2.sequences[i], skip_special_tokens=True)) - text_3s.append(tokenizer.decode(output_3.sequences[i], skip_special_tokens=True)) - - rouge_scores = self._calculate_rouge_l(text_2s, text_3s) - for i in range(len(rouge_scores)): - assert rouge_scores[i] > 0.99, f"Generated texts at prompt {i} do not match (ROUGE-L: {rouge_scores[i]})" + # FA2 as base to compare against + texts_1 = outputs[2] + rouge_scores = [] + for version in range(1, len(flash_attn_versions)): + fa_version = flash_attn_versions[version] + texts_x = outputs[fa_version] + rouge_score = self._calculate_rouge_l(texts_1, texts_x) + for idx, score in enumerate(rouge_score): + assert score > 0.99, ( + f"Generated texts at prompt {idx} do not match (ROUGE-L: {score}) comparing FA2 vs FA{fa_version}" + ) + rouge_scores.append(self._calculate_rouge_l(texts_1, texts_x)) # 4. Performance check + times = [] with torch.no_grad(): - time_3 = self._benchmark_generation(model, inputs) - model.set_attn_implementation("flash_attention_2") - time_2 = self._benchmark_generation(model, inputs) - - print(f"\n--- Flash Attention {2, 3} Parity Test on {model_id} ---") - print(f"Prompt: '{prompt}'") - print(f"Generated text with Flash Attention 2: {text_2s}") - print(f"Generated text with Flash Attention 3: {text_3s}") - print(f"ROUGE-L: {rouge_scores}") - print(f"Max absolute difference in logprobs: {max_logprob_diff:.5e}") - print(f"Flash Attention 2 latency: {time_2:.2f} ms") - print(f"Flash Attention 3 latency: {time_3:.2f} ms") - print(f"Speed-up: {time_2 / time_3:.2f}x") + for version in flash_attn_versions: + model.set_attn_implementation(f"flash_attention_{version}") + times.append(self._benchmark_generation(model, inputs)) + + # Summary + print(f"\n--- Flash Attention Parity Test on {model_id} ---") + print(f"Prompts: '{prompt}'") + print("\nGenerated texts:") + for version in flash_attn_versions: + print(f" With FA{version}: {outputs[version]}") + print("\nROUGE-L scores:") + for idx, version in enumerate(range(1, len(flash_attn_versions))): + print(f" Between FA2 and FA{flash_attn_versions[version]}: {rouge_scores[idx]}") + print("\nMax absolute difference in logprobs:") + for idx, version in enumerate(range(1, len(flash_attn_versions))): + print(f" Between FA2 and FA{flash_attn_versions[version]}: {max_logprob_diffs[idx]:.5e}") + print("\nLatency:") + for idx, version in enumerate(flash_attn_versions): + print(f" With FA{version}: {times[idx]}") print("---") diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 8b5d37daf72b..4c4d1210585c 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -43,6 +43,7 @@ require_accelerate, require_flash_attn, require_flash_attn_3, + require_flash_attn_4, require_optimum_quanto, require_torch, require_torch_accelerator, @@ -1797,6 +1798,7 @@ def _test_attention_implementation(self, attn_implementation): "sdpa": "_supports_sdpa", "flash_attention_2": "_supports_flash_attn", "flash_attention_3": "_supports_flash_attn", + "flash_attention_4": "_supports_flash_attn", } for model_class in self.all_generative_model_classes: @@ -1896,6 +1898,14 @@ def test_eager_matches_fa3_generate(self): """Tests that generate has equivalent outputs with FA3 and eager attention implementations.""" self._test_attention_implementation("flash_attention_3") + @pytest.mark.flash_attn_4_test + @require_flash_attn_4 + @require_torch_gpu + @slow + def test_eager_matches_fa4_generate(self): + """Tests that generate has equivalent outputs with FA4 and eager attention implementations.""" + self._test_attention_implementation("flash_attention_4") + @require_flash_attn @require_torch_accelerator @pytest.mark.flash_attn_test @@ -2006,6 +2016,7 @@ def attention_mask_padding_matches_padding_free_with_position_ids( "sdpa": "_supports_sdpa", "flash_attention_2": "_supports_flash_attn", "flash_attention_3": "_supports_flash_attn", + "flash_attention_4": "_supports_flash_attn", } for model_class in self.all_generative_model_classes: @@ -2153,6 +2164,22 @@ def test_flash_attention_3_padding_matches_padding_free_with_position_ids_and_fa attn_implementation="flash_attention_3", fa_kwargs=True ) + @require_flash_attn_4 + @require_torch_gpu + @pytest.mark.flash_attn_4_test + @slow + def test_flash_attention_4_padding_matches_padding_free_with_position_ids(self): + self.attention_mask_padding_matches_padding_free_with_position_ids(attn_implementation="flash_attention_4") + + @require_flash_attn_4 + @require_torch_gpu + @pytest.mark.flash_attn_4_test + @slow + def test_flash_attention_4_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): + self.attention_mask_padding_matches_padding_free_with_position_ids( + attn_implementation="flash_attention_4", fa_kwargs=True + ) + def _get_custom_4d_mask_test_data(self): # Sequence in which all but the last token is the same input_ids = torch.tensor( diff --git a/tests/models/gemma3/test_modeling_gemma3.py b/tests/models/gemma3/test_modeling_gemma3.py index c8ce8c660863..343b125144b7 100644 --- a/tests/models/gemma3/test_modeling_gemma3.py +++ b/tests/models/gemma3/test_modeling_gemma3.py @@ -36,6 +36,7 @@ require_deterministic_for_xpu, require_flash_attn, require_flash_attn_3, + require_flash_attn_4, require_torch, require_torch_accelerator, require_torch_gpu, @@ -401,6 +402,13 @@ def test_flash_attn_2_from_config(self): def test_flash_attn_3_from_config(self): self.flash_attn_from_config(attn_implementation="flash_attention_3", test_fwd_in_train=False) + @require_flash_attn_4 + @require_torch_gpu + @mark.flash_attn_4_test + @slow + def test_flash_attn_4_from_config(self): + self.flash_attn_from_config(attn_implementation="flash_attention_4", test_fwd_in_train=False) + @slow @require_torch_accelerator diff --git a/tests/models/sam3/test_modeling_sam3.py b/tests/models/sam3/test_modeling_sam3.py index 7f679ca2d842..33e6fe37e4b2 100644 --- a/tests/models/sam3/test_modeling_sam3.py +++ b/tests/models/sam3/test_modeling_sam3.py @@ -664,6 +664,20 @@ def test_flash_attn_3_inference_equivalence(self): def test_flash_attn_3_inference_equivalence_right_padding(self): pass + @unittest.skip( + reason="Sam3Model creates attention masks from features (with gradients), " + "which is incompatible with flash attention's expectation of binary masks" + ) + def test_flash_attn_4_inference_equivalence(self): + pass + + @unittest.skip( + reason="Sam3Model creates attention masks from features (with gradients), " + "which is incompatible with flash attention's expectation of binary masks" + ) + def test_flash_attn_4_inference_equivalence_right_padding(self): + pass + @unittest.skip( reason="Sam3Model creates attention masks from features (with gradients), " "which is incompatible with flash attention's expectation of binary masks" diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index cc86d287243a..0e7cceeee7f1 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -89,6 +89,7 @@ require_deepspeed, require_flash_attn, require_flash_attn_3, + require_flash_attn_4, require_kernels, require_non_hpu, require_torch, @@ -3393,6 +3394,22 @@ def test_flash_attn_3_inference_equivalence(self): def test_flash_attn_3_inference_equivalence_right_padding(self): self.flash_attn_inference_equivalence(attn_implementation="flash_attention_3", padding_side="right") + @require_flash_attn_4 + @require_torch_gpu + @mark.flash_attn_4_test + @slow + @is_flaky() + def test_flash_attn_4_inference_equivalence(self): + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_4", padding_side="left") + + @require_flash_attn_4 + @require_torch_gpu + @mark.flash_attn_4_test + @slow + @is_flaky() + def test_flash_attn_4_inference_equivalence_right_padding(self): + self.flash_attn_inference_equivalence(attn_implementation="flash_attention_4", padding_side="right") + def test_attn_implementation_composite_models(self): """ Tests if composite models can receive a dict object as attn_implementation, where each key should be @@ -3797,6 +3814,12 @@ def test_flash_attn_2_can_dispatch_composite_models(self): def test_flash_attn_3_can_dispatch_composite_models(self): self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_3") + @require_flash_attn_4 + @require_torch_gpu + @mark.flash_attn_4_test + def test_flash_attn_4_can_dispatch_composite_models(self): + self.flash_attn_can_dispatch_composite_models(attn_implementation="flash_attention_4") + @require_flash_attn @require_torch_accelerator @require_bitsandbytes @@ -3964,6 +3987,13 @@ def test_flash_attn_2_from_config(self): def test_flash_attn_3_from_config(self): self.flash_attn_from_config(attn_implementation="flash_attention_3") + @require_flash_attn_4 + @require_torch_gpu + @mark.flash_attn_4_test + @slow + def test_flash_attn_4_from_config(self): + self.flash_attn_from_config(attn_implementation="flash_attention_4") + def test_sliding_window_mask(self): """Tests that we can control the sliding window attention behavior of a model.""" config, inputs = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 3ce7f8032e1d..7366845c4d78 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -87,6 +87,7 @@ from transformers.utils.import_utils import ( is_flash_attn_2_available, is_flash_attn_3_available, + is_flash_attn_4_available, is_kernels_available, is_torch_npu_available, ) @@ -760,6 +761,9 @@ def test_model_from_pretrained_attn_implementation(self): if is_flash_attn_3_available(): attn_implementation_available.append("flash_attention_3") + if is_flash_attn_4_available(): + attn_implementation_available.append("flash_attention_4") + for requested_attn_implementation in attn_implementation_available: model = AutoModelForCausalLM.from_pretrained( TINY_MISTRAL, attn_implementation=requested_attn_implementation @@ -785,6 +789,9 @@ def test_model_from_config_attn_implementation(self): if is_flash_attn_3_available(): attn_implementation_available.append("flash_attention_3") + if is_flash_attn_4_available(): + attn_implementation_available.append("flash_attention_4") + for requested_attn_implementation in attn_implementation_available: config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) # Ensure the config was set correctly @@ -2309,6 +2316,9 @@ def test_decoder_only_model_can_be_used_as_encoder(self, attn_implementation: st config = LlamaConfig( num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=16, hidden_size=32, intermediate_size=64, vocab_size=100, @@ -2829,7 +2839,7 @@ def test_not_available_flash(self): _ = AutoModel.from_pretrained( "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" ) - self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + self.assertTrue("the package for FlashAttention2 doesn't seem to be installed." in str(cm.exception)) def test_not_available_flash_with_config(self): if is_flash_attn_2_available(): @@ -2852,7 +2862,7 @@ def test_not_available_flash_with_config(self): attn_implementation="flash_attention_2", ) - self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + self.assertTrue("the package for FlashAttention2 doesn't seem to be installed." in str(cm.exception)) def test_kernels_fallback(self): if not is_kernels_available():