From efa7189c5546a94e8c1194b2ccc6669e0f7ad65e Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 06:53:31 +0000 Subject: [PATCH 1/8] _supports_flash_attn_3 --- examples/modular-transformers/modeling_dummy.py | 1 + examples/modular-transformers/modeling_multimodal1.py | 1 + examples/modular-transformers/modeling_multimodal2.py | 1 + examples/modular-transformers/modeling_my_new_model2.py | 1 + examples/modular-transformers/modeling_new_task_model.py | 1 + examples/modular-transformers/modeling_super.py | 1 + src/transformers/modeling_utils.py | 3 +++ src/transformers/models/aria/modular_aria.py | 2 ++ src/transformers/models/bamba/modular_bamba.py | 1 + src/transformers/models/bark/modeling_bark.py | 1 + src/transformers/models/bart/modeling_bart.py | 1 + src/transformers/models/chameleon/modeling_chameleon.py | 1 + src/transformers/models/clip/modeling_clip.py | 1 + src/transformers/models/cohere/modeling_cohere.py | 1 + src/transformers/models/cohere2/modeling_cohere2.py | 1 + src/transformers/models/data2vec/modeling_data2vec_audio.py | 1 + src/transformers/models/dbrx/modeling_dbrx.py | 1 + src/transformers/models/diffllama/modeling_diffllama.py | 1 + src/transformers/models/distilbert/modeling_distilbert.py | 1 + src/transformers/models/emu3/modeling_emu3.py | 1 + .../models/encoder_decoder/modeling_encoder_decoder.py | 1 + src/transformers/models/falcon/modeling_falcon.py | 1 + src/transformers/models/gemma/modeling_gemma.py | 1 + src/transformers/models/gemma2/modeling_gemma2.py | 1 + src/transformers/models/glm/modeling_glm.py | 1 + src/transformers/models/got_ocr2/modeling_got_ocr2.py | 1 + src/transformers/models/gpt2/modeling_gpt2.py | 1 + src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py | 1 + src/transformers/models/gpt_neo/modeling_gpt_neo.py | 1 + src/transformers/models/gpt_neox/modeling_gpt_neox.py | 1 + src/transformers/models/gptj/modeling_gptj.py | 1 + src/transformers/models/granite/modeling_granite.py | 1 + src/transformers/models/granitemoe/modeling_granitemoe.py | 1 + src/transformers/models/helium/modeling_helium.py | 1 + src/transformers/models/hubert/modeling_hubert.py | 1 + src/transformers/models/idefics2/modeling_idefics2.py | 1 + src/transformers/models/idefics3/modeling_idefics3.py | 1 + src/transformers/models/jamba/modeling_jamba.py | 1 + src/transformers/models/jetmoe/modeling_jetmoe.py | 1 + src/transformers/models/llama/modeling_llama.py | 1 + src/transformers/models/llava/modeling_llava.py | 1 + src/transformers/models/llava_next/modeling_llava_next.py | 1 + .../models/llava_next_video/modeling_llava_next_video.py | 1 + .../models/llava_onevision/modeling_llava_onevision.py | 1 + src/transformers/models/m2m_100/modeling_m2m_100.py | 1 + src/transformers/models/mbart/modeling_mbart.py | 1 + src/transformers/models/mimi/modeling_mimi.py | 1 + src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/modernbert/modular_modernbert.py | 1 + src/transformers/models/moonshine/modular_moonshine.py | 1 + src/transformers/models/moshi/modeling_moshi.py | 2 ++ src/transformers/models/musicgen/modeling_musicgen.py | 2 ++ .../models/musicgen_melody/modeling_musicgen_melody.py | 2 ++ src/transformers/models/nemotron/modeling_nemotron.py | 1 + src/transformers/models/olmo/modeling_olmo.py | 1 + src/transformers/models/olmo2/modeling_olmo2.py | 1 + src/transformers/models/olmoe/modeling_olmoe.py | 1 + src/transformers/models/opt/modeling_opt.py | 1 + 
src/transformers/models/paligemma/modeling_paligemma.py | 1 + src/transformers/models/phi/modeling_phi.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + src/transformers/models/phimoe/modeling_phimoe.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 1 + src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 1 + src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 1 + src/transformers/models/rag/modeling_rag.py | 1 + .../models/recurrent_gemma/modeling_recurrent_gemma.py | 1 + src/transformers/models/sew/modeling_sew.py | 1 + src/transformers/models/siglip/modeling_siglip.py | 1 + .../speech_encoder_decoder/modeling_speech_encoder_decoder.py | 1 + src/transformers/models/stablelm/modeling_stablelm.py | 1 + src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + src/transformers/models/unispeech/modeling_unispeech.py | 1 + .../models/unispeech_sat/modeling_unispeech_sat.py | 1 + src/transformers/models/video_llava/modeling_video_llava.py | 1 + src/transformers/models/vipllava/modeling_vipllava.py | 1 + .../vision_encoder_decoder/modeling_vision_encoder_decoder.py | 1 + .../modeling_vision_text_dual_encoder.py | 1 + src/transformers/models/wav2vec2/modeling_wav2vec2.py | 1 + src/transformers/models/whisper/modeling_whisper.py | 1 + src/transformers/models/zamba/modeling_zamba.py | 1 + src/transformers/models/zamba2/modular_zamba2.py | 1 + 85 files changed, 91 insertions(+) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 1b0ad5ad92fe..ecdd473e957e 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -351,6 +351,7 @@ class DummyPreTrainedModel(PreTrainedModel): _no_split_modules = ["DummyDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index ec54af22186e..0b3aaabdeb26 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -351,6 +351,7 @@ class Multimodal1TextPreTrainedModel(PreTrainedModel): _no_split_modules = ["Multimodal1TextDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py index b10b11b671af..ee4ac1793f81 100644 --- a/examples/modular-transformers/modeling_multimodal2.py +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -637,6 +637,7 @@ class Multimodal2VisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 86669310c4f8..609a37fb5a3c 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ 
-351,6 +351,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): _no_split_modules = ["MyNewModel2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 3cea4ef2c455..6d866da27459 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -108,6 +108,7 @@ class NewTaskModelPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 454860458636..4c94d3dba138 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -351,6 +351,7 @@ class SuperPreTrainedModel(PreTrainedModel): _no_split_modules = ["SuperDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 13c8719b3603..e75473435857 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1299,6 +1299,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # Flash Attention 3 support + _supports_flash_attn_3 = False + # SDPA support _supports_sdpa = False diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 7d579d6e37f3..7db59793178c 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1208,6 +1208,7 @@ class AriaTextPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = True _supports_cache_class = True @@ -1348,6 +1349,7 @@ class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast): class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): config_class = AriaConfig _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = False _tied_weights_keys = ["language_model.lm_head.weight"] diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 3972d25b51b9..d4f0489dd4f7 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -810,6 +810,7 @@ class BambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["BambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 56f8ce4d1006..5bf3d0fdb216 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ 
b/src/transformers/models/bark/modeling_bark.py @@ -376,6 +376,7 @@ class BarkPreTrainedModel(PreTrainedModel): config_class = BarkConfig supports_gradient_checkpointing = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights.""" diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index e64ab3b2d041..7b8e340f179a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -748,6 +748,7 @@ class BartPreTrainedModel(PreTrainedModel): _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0a9421409e25..0347cf64b8bc 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1043,6 +1043,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 01c8f4dcbc9a..d064310afbf5 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -643,6 +643,7 @@ class CLIPPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 69e7c579f9ce..c03511f79253 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -417,6 +417,7 @@ class CoherePreTrainedModel(PreTrainedModel): _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 11353a0a990c..8b042db9ef9c 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -433,6 +433,7 @@ class Cohere2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Cohere2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index b1be8ab19660..72e4b1905a70 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -931,6 +931,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): main_input_name = "input_values" 
supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index fceefbe2c752..52189d09a9d9 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -830,6 +830,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 16aeefcb1c88..681d506505fb 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -595,6 +595,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["DiffLlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = False _supports_cache_class = True diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 6aa50397d42c..0c1f9b866a19 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -589,6 +589,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module: nn.Module): diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 7d31b8d3d323..892368026d4f 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1176,6 +1176,7 @@ class Emu3PreTrainedModel(PreTrainedModel): ] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 9ab4b7f2ced1..d091a31c4e2b 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -181,6 +181,7 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_param_buffer_assignment = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e36ea9cef222..cff79004cdf0 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -784,6 +784,7 @@ class FalconPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = 
True diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index aeb742e16dd0..19ccf291ba61 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -383,6 +383,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index e9fd43c49000..c610353313f6 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -435,6 +435,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Gemma2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index f1fddcda107a..25c2220e3e04 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -398,6 +398,7 @@ class GlmPreTrainedModel(PreTrainedModel): _no_split_modules = ["GlmDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/got_ocr2/modeling_got_ocr2.py b/src/transformers/models/got_ocr2/modeling_got_ocr2.py index 957e05bea75a..ac6577fe920f 100644 --- a/src/transformers/models/got_ocr2/modeling_got_ocr2.py +++ b/src/transformers/models/got_ocr2/modeling_got_ocr2.py @@ -593,6 +593,7 @@ class GotOcr2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 931bba9ba965..0a93d3144380 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -464,6 +464,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 4729ee098da3..e218f1a63153 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -666,6 +666,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 8598d51e6871..498e1912c6d4 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ 
b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -492,6 +492,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index f420a8ceb206..281f3db77592 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -369,6 +369,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 8c9de2dbced1..f171214a3f08 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -486,6 +486,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index d0579bb8a7a4..ca6fb4a486d1 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -398,6 +398,7 @@ class GranitePreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index d877b8323b3b..538a480fd631 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -839,6 +839,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteMoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index fc6f862be258..59becff89c40 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -385,6 +385,7 @@ class HeliumPreTrainedModel(PreTrainedModel): _no_split_modules = ["HeliumDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index b986ab863680..0bb500c10ca7 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -1116,6 +1116,7 
@@ class HubertPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index eb676c295a4f..8be9e187d7e4 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -600,6 +600,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index e4cc8bda569f..fb643f950821 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -620,6 +620,7 @@ class Idefics3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fa95b126883a..9facc4a24c6a 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1102,6 +1102,7 @@ class JambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 283174ba3cfd..931b3d7fd82b 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -861,6 +861,7 @@ class JetMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["JetMoeBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 0d65e1417f52..e86ade5d9d31 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -387,6 +387,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 36f212e76844..1898110c59e4 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -135,6 +135,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" 
_supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 06e1cc63940f..4401ac6c06fd 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -244,6 +244,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index f62824947ddf..28277bff7903 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -152,6 +152,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index ed584bda7f5d..8d9c0972813a 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -249,6 +249,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaOnevisionVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support _supports_quantized_cache = True diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index eb207bedd21b..0bd818dbad96 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -779,6 +779,7 @@ class M2M100PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 8412ecef1cf9..3bed530b1f28 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -745,6 +745,7 @@ class MBartPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index af36b2333577..c17c9da585d5 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1436,6 +1436,7 @@ class MimiPreTrainedModel(PreTrainedModel): _no_split_modules = ["MimiDecoderLayer"] 
_skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b300c7c646f2..69843b500f35 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -359,6 +359,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 170d54eca1b2..94eb64342cf1 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -481,6 +481,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index edfdc94346bf..5af53bb31014 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -785,6 +785,7 @@ class ModernBertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["ModernBertEmbeddings", "ModernBertEncoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = False diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 24fa4f0a1ef8..8cf83eef0dcd 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -542,6 +542,7 @@ class MoonshinePreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MoonshineEncoderLayer", "MoonshineDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index ae9a3fd804dc..2f9aec0ec960 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -871,6 +871,7 @@ class MoshiPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True main_input_name = "input_ids" @@ -1919,6 +1920,7 @@ class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, config: MoshiConfig): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 
cab950995a97..6b75100ca3dc 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -707,6 +707,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -1671,6 +1672,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 279a7c046c4d..156bb4e16e0a 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -666,6 +666,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -1596,6 +1597,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 829a3283d0a3..4cab850867e0 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -628,6 +628,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _no_split_modules = ["NemotronDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6b7abaa96af2..9d972c42b850 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -363,6 +363,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 89ef5e1050bb..cc540451c496 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -364,6 +364,7 @@ class Olmo2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Olmo2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 9b0336a32b1c..1c28997861d9 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ 
b/src/transformers/models/olmoe/modeling_olmoe.py @@ -764,6 +764,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 1969acf2f5b1..66f87b831de0 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -523,6 +523,7 @@ class OPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 9172b98c069e..b2f206d4fe89 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -198,6 +198,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _supports_quantized_cache = True _supports_static_cache = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 33d86999fdf8..60b484ef3b4d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -359,6 +359,7 @@ class PhiPreTrainedModel(PreTrainedModel): _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index fc4787b7883b..27669d3cbc47 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -428,6 +428,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 1cea9a2ea28b..1f9059ac0dc0 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -909,6 +909,7 @@ class PhimoePreTrainedModel(PreTrainedModel): _no_split_modules = ["PhimoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index bec0cb46ef29..61a0b0a75f6c 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -372,6 +372,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True 
_supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 8f9c7a5e6bba..f8ae1fc07484 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -386,6 +386,7 @@ class Qwen2_5_VLPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions` diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 320d2093133f..7732051a1d23 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -543,6 +543,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2AudioAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 3e4aa05a22bf..f55977066440 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -792,6 +792,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 91ec520d5146..1e44b86106a4 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -937,6 +937,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = False # TODO (joao): fix. 
torch.compile failing probably due to `cache_positions` diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index d3ca787691c4..6e291ce0e904 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -233,6 +233,7 @@ class RagPreTrainedModel(PreTrainedModel): config_class = RagConfig base_model_prefix = "rag" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True @classmethod diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index e2014079f936..789332485f1b 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -539,6 +539,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["RecurrentGemmaDecoderLayer"] _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = False # we can't compare with eager for now _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index d534f6843466..94d83858a1a9 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -979,6 +979,7 @@ class SEWPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index d8a317493a10..9853b3585d0a 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -685,6 +685,7 @@ class SiglipPreTrainedModel(PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 9fa099d19230..0c400495f13e 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -184,6 +184,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_param_buffer_assignment = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index c401b772db74..5bbfddd9b8e9 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -671,6 +671,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_sdpa = True _supports_quantized_cache = True diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py 
b/src/transformers/models/starcoder2/modeling_starcoder2.py index d64953d72b69..91936e8e384d 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -364,6 +364,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_flex_attn = True _supports_cache_class = True diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 6f7e544b598a..23cce095246e 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -1218,6 +1218,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 8daea82a0e23..de95114e79da 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -1235,6 +1235,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index d8da974b9862..3484ab33b45e 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -136,6 +136,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 71201db2098e..a7a6cff395e1 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -136,6 +136,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_cache_class = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 55c759f8e9ae..a7d01d8cc4f0 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -162,6 +162,7 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_param_buffer_assignment = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py 
b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index d7cceb5d2feb..2c1be11fcb47 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -162,6 +162,7 @@ class VisionTextDualEncoderModel(PreTrainedModel): config_class = VisionTextDualEncoderConfig base_model_prefix = "vision_text_dual_encoder" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 81f2110e721c..f1e7a34da35f 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1330,6 +1330,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index f6ffab062993..0e731ca3b160 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -769,6 +769,7 @@ class WhisperPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 54f57971a82e..3be45dd7fcfa 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -837,6 +837,7 @@ class ZambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["ZambaAttentionDecoderLayer", "ZambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = False _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache _is_stateful = True diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index f2074b76f3da..add996d57592 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -915,6 +915,7 @@ class Zamba2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_flex_attn = True _supports_sdpa = True _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache From b1fc52e593436adb0011acb7cb420f7da6f03d28 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 07:04:15 +0000 Subject: [PATCH 2/8] modeling_utils/import_utils --- src/transformers/modeling_utils.py | 96 ++++++++++++++++++++++++++ src/transformers/utils/__init__.py | 1 + src/transformers/utils/import_utils.py | 19 +++++ 3 files changed, 116 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e75473435857..c1a96760ef70 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ 
-88,6 +88,7 @@ is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1542,6 +1543,8 @@ def _autoset_attn_implementation( message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_flash_attn_3: + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' if cls._supports_flex_attn: @@ -1582,6 +1585,14 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif config._attn_implementation == "flash_attention_3": + cls._check_and_enable_flash_attn_3( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) elif requested_attn_implementation == "flex_attention": config = cls._check_and_enable_flex_attn(config, hard_check_only=True) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): @@ -1778,6 +1789,90 @@ def _check_and_enable_flash_attn_2( config._attn_implementation = "flash_attention_2" return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 3 and compatibility with the current model. + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_3: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_3_available(): + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + # TODO: docs + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3." + + if importlib.util.find_spec("flash_attn_interface") is None: + raise ImportError( + f"{preface} the package flash_attn_interface seems to be not installed. {install_message}" + ) + + if torch.version.cuda: + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + if major < 9: + raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0") + else: + raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0") + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 3 and BetterTransformer API are not compatible. 
Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dtype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialise the model on CPU and then move it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+ ) + if not hard_check_only: + config._attn_implementation = "flash_attention_3" + return config + @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ @@ -5758,6 +5853,7 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): ALL_ATTENTION_FUNCTIONS.update( { "flash_attention_2": flash_attention_forward, + "flash_attention_3": flash_attention_forward, "flex_attention": flex_attention_forward, "sdpa": sdpa_attention_forward, } diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index cf13060ee307..bb90ef8899b6 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -137,6 +137,7 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_flax_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bd95b6f282c0..a2d4f590a4a2 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "flash_attn_interface": + try: + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") + package_exists = True + except ImportError: + # If the package can't be imported, it's not available + package_exists = False else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False @@ -984,6 +992,17 @@ def is_flash_attn_greater_or_equal(library_version: str): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) +@lru_cache() +def is_flash_attn_3_available(): + if not is_flash_attn_2_available(): + return False + + if not _is_package_available("flash_attn_interface"): + return False + + return True + + @lru_cache() def is_torch_greater_or_equal(library_version: str): if not _is_package_available("torch"): From 9a80143ca26066ee3266749a57c7b3850d8843be Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 07:33:06 +0000 Subject: [PATCH 3/8] config._attn_implementation/_use_flash_attention_3 --- .../modular-transformers/modeling_dummy.py | 2 +- .../modeling_multimodal1.py | 2 +- .../modeling_my_new_model2.py | 2 +- .../modeling_new_task_model.py | 2 +- .../modular-transformers/modeling_super.py | 2 +- src/transformers/models/aria/modeling_aria.py | 3 +- src/transformers/models/aria/modular_aria.py | 1 + .../models/bamba/modeling_bamba.py | 2 +- .../models/bamba/modular_bamba.py | 2 +- src/transformers/models/bark/modeling_bark.py | 42 ++++++++++++++++++- src/transformers/models/bart/modeling_bart.py | 8 ++-- .../models/bloom/modeling_bloom.py | 2 +- .../models/chameleon/modeling_chameleon.py | 2 +- src/transformers/models/clip/modeling_clip.py | 3 +- .../models/codegen/modeling_codegen.py | 2 +- .../models/cohere/modeling_cohere.py | 2 +- .../models/cohere2/modeling_cohere2.py | 10 ++--- .../models/cohere2/modular_cohere2.py | 8 ++-- .../data2vec/modeling_data2vec_audio.py | 3 +- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- .../models/diffllama/modeling_diffllama.py | 2 +- .../models/distilbert/modeling_distilbert.py | 3 +- src/transformers/models/emu3/modeling_emu3.py | 2 +- 
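As a usage sketch (not part of the patch): once the modeling_utils/import_utils changes above are in place, the new implementation can be requested like flash_attention_2. This assumes a checkpoint whose architecture sets `_supports_flash_attn_3` and that the `flash_attn_interface` package is installed; the checkpoint name is the one already used in the patch's own warning text, and `is_flash_attn_3_available` is the helper introduced in `transformers.utils` above.

import torch
from transformers import AutoModel
from transformers.utils import is_flash_attn_3_available  # added by the import_utils change above

# Prefer FA3 when the flash_attn_interface kernels are installed, otherwise fall back to FA2.
attn_implementation = "flash_attention_3" if is_flash_attn_3_available() else "flash_attention_2"

model = AutoModel.from_pretrained(
    "openai/whisper-tiny",                   # checkpoint taken from the warning message in the patch
    attn_implementation=attn_implementation, # routed through _autoset_attn_implementation above
    torch_dtype=torch.float16,               # FA3, like FA2, expects fp16/bf16
).to("cuda")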
.../models/falcon/modeling_falcon.py | 3 +- .../models/gemma/modeling_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 10 ++--- .../models/gemma2/modular_gemma2.py | 10 ++--- src/transformers/models/glm/modeling_glm.py | 2 +- src/transformers/models/gpt2/modeling_gpt2.py | 4 +- .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +- .../models/gpt_neo/modeling_gpt_neo.py | 2 +- .../models/gpt_neox/modeling_gpt_neox.py | 3 +- .../models/gpt_neox/modular_gpt_neox.py | 1 + .../modeling_gpt_neox_japanese.py | 2 +- src/transformers/models/gptj/modeling_gptj.py | 3 +- .../models/granite/modeling_granite.py | 2 +- .../models/granitemoe/modeling_granitemoe.py | 2 +- .../models/helium/modeling_helium.py | 2 +- .../models/hubert/modeling_hubert.py | 6 ++- .../models/idefics/modeling_idefics.py | 2 +- .../models/idefics2/modeling_idefics2.py | 5 ++- .../models/idefics3/modeling_idefics3.py | 4 +- .../models/jamba/modeling_jamba.py | 2 +- .../models/jetmoe/modeling_jetmoe.py | 4 +- .../models/llama/modeling_llama.py | 2 +- .../models/longt5/modeling_longt5.py | 2 +- .../models/m2m_100/modeling_m2m_100.py | 12 +++--- .../models/mbart/modeling_mbart.py | 6 +-- src/transformers/models/mimi/modeling_mimi.py | 2 +- .../models/mistral/modeling_mistral.py | 2 +- .../models/mistral/modular_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 2 +- .../models/mllama/modeling_mllama.py | 2 +- .../models/modernbert/modeling_modernbert.py | 10 ++--- .../models/modernbert/modular_modernbert.py | 10 ++--- .../models/moonshine/modeling_moonshine.py | 6 +-- .../models/moonshine/modular_moonshine.py | 4 +- .../models/moshi/modeling_moshi.py | 4 +- src/transformers/models/mt5/modeling_mt5.py | 2 +- .../models/musicgen/modeling_musicgen.py | 4 +- .../modeling_musicgen_melody.py | 2 +- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- .../models/olmo2/modeling_olmo2.py | 2 +- .../models/olmoe/modeling_olmoe.py | 2 +- src/transformers/models/opt/modeling_opt.py | 3 +- .../models/paligemma/modeling_paligemma.py | 2 +- .../models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- .../models/phimoe/modeling_phimoe.py | 2 +- .../models/pix2struct/modeling_pix2struct.py | 2 +- .../models/plbart/modeling_plbart.py | 8 ++-- .../models/pop2piano/modeling_pop2piano.py | 2 +- .../models/qwen2/modeling_qwen2.py | 2 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 2 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 2 +- src/transformers/models/sew/modeling_sew.py | 3 +- .../models/siglip/modeling_siglip.py | 3 +- .../models/stablelm/modeling_stablelm.py | 2 +- .../models/starcoder2/modeling_starcoder2.py | 2 +- .../modeling_switch_transformers.py | 2 +- src/transformers/models/t5/modeling_t5.py | 2 +- src/transformers/models/udop/modeling_udop.py | 2 +- src/transformers/models/umt5/modeling_umt5.py | 2 +- .../models/unispeech/modeling_unispeech.py | 6 ++- .../unispeech_sat/modeling_unispeech_sat.py | 6 ++- .../models/wav2vec2/modeling_wav2vec2.py | 6 ++- .../models/whisper/modeling_whisper.py | 3 +- .../models/zamba/modeling_zamba.py | 27 +++++++++++- .../models/zamba2/modeling_zamba2.py | 2 +- 92 files changed, 231 insertions(+), 135 deletions(-) diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index ecdd473e957e..c88bb7a7c64b 100644 --- 
a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -594,7 +594,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 0b3aaabdeb26..c6368994511b 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -594,7 +594,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 609a37fb5a3c..b91d3ec0c2bc 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -599,7 +599,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 6d866da27459..9fa9372b8e82 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -259,7 +259,7 @@ def _update_causal_mask( inputs_embeds=None, is_training: bool = False, ): - if self.config.text_config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config.text_config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index 4c94d3dba138..f3331c78fee3 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -516,7 +516,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 90755045b635..79ac3e58ea87 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1010,7 +1010,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -1381,6 +1381,7 @@ def __init__(self, config: AriaConfig): self.language_model = AutoModelForCausalLM.from_config(config.text_config) 
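The `_update_causal_mask` hunks above all make the same change: the exact comparison against "flash_attention_2" becomes a substring check, so "flash_attention_3" takes the identical 2D-mask fast path. A minimal sketch of the predicate, using a hypothetical helper name, shows the intent:

def uses_flash_attention(attn_implementation: str) -> bool:
    # Hypothetical helper, equivalent to the inlined
    # `"flash_attention" in config._attn_implementation` checks applied across models.
    return "flash_attention" in attn_implementation

assert uses_flash_attention("flash_attention_2")
assert uses_flash_attention("flash_attention_3")
assert not uses_flash_attention("sdpa")
assert not uses_flash_attention("eager")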
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3" self.post_init() def _create_patch_attention_mask(self, pixel_mask): diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 7db59793178c..801bf47c9899 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -1362,6 +1362,7 @@ def __init__(self, config: AriaConfig): self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3" self.post_init() def _create_patch_attention_mask(self, pixel_mask): diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 6fdce41e5a68..2546ac667439 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1319,7 +1319,7 @@ def _update_causal_mask( past_key_values: HybridMambaAttentionDynamicCache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index d4f0489dd4f7..ce2f77aeb78f 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -1061,7 +1061,7 @@ def _update_causal_mask( past_key_values: HybridMambaAttentionDynamicCache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 5bf3d0fdb216..078bf8b058e4 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -562,6 +562,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias) @@ -704,7 +705,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: attention_mask = attention_mask.view(batch_size, -1) @@ -1159,6 +1160,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = 
config._attn_implementation == "flash_attention_3" self.layernorm_final = nn.LayerNorm(config.hidden_size) @@ -1354,7 +1356,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] @@ -1831,6 +1833,42 @@ def _check_and_enable_flash_attn_2( config.fine_acoustics_config._attn_implementation = config._attn_implementation return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + `_check_and_enable_flash_attn_3` does not by default expand flash attention enabling to the model + sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention + if necessary. + + If you don't know about Flash Attention, check out the official repository of flash attention: + https://github.com/Dao-AILab/flash-attention + + For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this + specific section of the documentation to learn more about it: + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + + The method checks if the current setup is compatible with Flash Attention as it requires the model to be in + half precision and not run on CPU. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model + can initialize the correct attention module. + """ + config = super()._check_and_enable_flash_attn_3( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + config.semantic_config._attn_implementation = config._attn_implementation + config.coarse_acoustics_config._attn_implementation = config._attn_implementation + config.fine_acoustics_config._attn_implementation = config._attn_implementation + return config + __all__ = [ "BarkFineModel", diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7b8e340f179a..ca3010d14cb1 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -980,6 +980,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -1068,7 +1069,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1164,6 +1165,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No )
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -1284,7 +1286,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1304,7 +1306,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index df65f0aeb949..b1793d251b54 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -739,7 +739,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 0347cf64b8bc..36397c5005f3 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1386,7 +1386,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index d064310afbf5..9cbca756dd0a 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -917,6 +917,7 @@ def __init__(self, config: CLIPTextConfig): # For attention mask, it differs between `flash_attention_2` and other attention implementations self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) @@ -954,7 +955,7 @@ def forward( ) # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = 
_prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index dcb24817e303..a63e4793e63a 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -582,7 +582,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index c03511f79253..4aa2841a7c23 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -661,7 +661,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 8b042db9ef9c..29351e6e258a 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -261,8 +261,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + # Here we need to slice as we use a static cache by default, but FA does not support it + if attention_mask is not None and "flash_attention" in self.config._attn_implementation: seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] @@ -358,7 +358,7 @@ def forward( effective_seq_len = max(cache_position.shape[0], self.sliding_window) # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask[:, -effective_seq_len:] # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice # from the left, with an offset if we are beyond the sliding window @@ -697,7 +697,7 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: return attention_mask dtype, device = input_tensor.dtype, input_tensor.device @@ -963,7 +963,7 @@ def prepare_inputs_for_generation( if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" + and "flash_attention" not in self.config._attn_implementation ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 979b5abc2600..d889e94f26c5 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -311,8 +311,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + # Here we need to slice as we use a static cache by default, but FA does not support it + if attention_mask is not None and "flash_attention" in self.config._attn_implementation: seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] @@ -389,7 +389,7 @@ def forward( effective_seq_len = max(cache_position.shape[0], self.sliding_window) # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask[:, -effective_seq_len:] # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice # from the left, with an offset if we are beyond the sliding window @@ -633,7 +633,7 @@ def prepare_inputs_for_generation( if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" + and "flash_attention" not in self.config._attn_implementation ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 72e4b1905a70..aca31d7b9a01 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -794,6 +794,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -810,7 +811,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in 
attention_mask) else None else: diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 52189d09a9d9..3b99e1746938 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1115,7 +1115,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 681d506505fb..ca2a196d632b 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -900,7 +900,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index 0c1f9b866a19..47f57c95301e 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -677,6 +677,7 @@ def __init__(self, config: PretrainedConfig): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing @@ -784,7 +785,7 @@ def forward( embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: if attention_mask is None: diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 892368026d4f..59dd52a84256 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1479,7 +1479,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index cff79004cdf0..ba444369f338 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -839,6 +839,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" # Final Layer Norm @@ 
-1027,7 +1028,7 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 19ccf291ba61..bb0d18cb8d2b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -633,7 +633,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index c610353313f6..a0fbf3a884a7 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -230,8 +230,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + # Here we need to slice as we use a static cache by default, but FA does not support it + if attention_mask is not None and "flash_attention" in self.config._attn_implementation: seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] @@ -296,7 +296,7 @@ def forward( effective_seq_len = max(cache_position.shape[0], self.sliding_window) # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask[:, -effective_seq_len:] # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice # from the left, with an offset if we are beyond the sliding window @@ -709,7 +709,7 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: return attention_mask dtype, device = input_tensor.dtype, input_tensor.device @@ -978,7 +978,7 @@ def prepare_inputs_for_generation( if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" + and "flash_attention" not in self.config._attn_implementation ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 4e3c8487c4d8..654ae399b9ab 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -278,8 +278,8 @@ def forward( } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - # Here we need to slice as we use a static cache by default, but FA2 does not support it - if attention_mask is not None and self.config._attn_implementation == "flash_attention_2": + # Here we need to slice as we use a static cache by default, but FA does not support it + if attention_mask is not None and "flash_attention" in self.config._attn_implementation: seq_len = attention_mask.shape[-1] key_states, value_states = key_states[:, :, :seq_len, :], value_states[:, :, :seq_len, :] @@ -344,7 +344,7 @@ def forward( effective_seq_len = max(cache_position.shape[0], self.sliding_window) # For FA2, the mask is 2D and is of shape [bs, processed_tokens] (not [bs, max_cache_len]), # thus we must slice from the right (at most `effective_seq_len` elements) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask[:, -effective_seq_len:] # Otherwise, the mask is 4D of shape [bs, 1, query_len, max_cache_len] thus we must slice # from the left, with an offset if we are beyond the sliding window @@ -545,7 +545,7 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: return attention_mask dtype, device = input_tensor.dtype, input_tensor.device @@ -710,7 +710,7 @@ def prepare_inputs_for_generation( if ( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 - and not self.config._attn_implementation == "flash_attention_2" + and "flash_attention" not in self.config._attn_implementation ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 25c2220e3e04..1e0f7c0e0a3c 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -642,7 +642,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 0a93d3144380..60c1c29f994d 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -824,7 +824,7 @@ def forward( # Attention mask. _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None - if self._attn_implementation == "flash_attention_2": + if "flash_attention" in self._attn_implementation: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif _use_sdpa: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -861,7 +861,7 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] ) - elif not self._attn_implementation == "flash_attention_2": + elif "flash_attention" not in self._attn_implementation: encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_attention_mask = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index e218f1a63153..ea027d7b7df3 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -811,6 +811,7 @@ def __init__(self, config): self._use_sdpa = config._attn_implementation == "sdpa" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Initialize weights and apply final processing self.post_init() @@ -892,7 +893,7 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 498e1912c6d4..d17af991d331 
100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -792,7 +792,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 281f3db77592..1061d072771a 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -182,6 +182,7 @@ def forward( if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ "sdpa", "flash_attention_2", + "flash_attention_3", ]: logger.warning_once( f"Setting `attention_type` to `eager` because `{attention_type}` does not support" @@ -636,7 +637,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py index 3a7cc49542ef..00ffd9ae1629 100644 --- a/src/transformers/models/gpt_neox/modular_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py @@ -174,6 +174,7 @@ def forward( if (output_attentions or head_mask is not None) and self.config._attn_implementation in [ "sdpa", "flash_attention_2", + "flash_attention_3", ]: logger.warning_once( f"Setting `attention_type` to `eager` because `{attention_type}` does not support" diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index d5153fb3f828..d6cb34aec2fa 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -661,7 +661,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index f171214a3f08..5c6ca995c8f0 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -667,6 +667,7 @@ def __init__(self, config): self.post_init() self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -891,7 +892,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/granite/modeling_granite.py 
b/src/transformers/models/granite/modeling_granite.py index ca6fb4a486d1..c63fa247307f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -645,7 +645,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 538a480fd631..c84e856bcb6c 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1122,7 +1122,7 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 59becff89c40..90c92e443a46 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -629,7 +629,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 0bb500c10ca7..9a3536909bb0 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -940,6 +940,7 @@ def __init__(self, config): self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -956,7 +957,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1028,6 +1029,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1044,7 +1046,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: + if 
self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 8c6f1f059bfc..adf62d4f28b4 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1362,7 +1362,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 8be9e187d7e4..a12f5772a84d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -660,6 +660,7 @@ def __init__(self, config: Idefics2VisionConfig): self.encoder = Idefics2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def get_input_embeddings(self): return self.embeddings @@ -701,7 +702,7 @@ def forward( # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): patch_attention_mask = None - elif not self._use_flash_attention_2: + elif not self._use_flash_attention_2 and not self._use_flash_attention_3: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( @@ -1105,6 +1106,7 @@ def __init__(self, config) -> None: self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1248,6 +1250,7 @@ def __init__(self, config: Idefics2Config): self.image_token_id = self.config.image_token_id self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3" self.post_init() diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index fb643f950821..b1f5d1bed2af 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -679,6 +679,7 @@ def __init__(self, config: Idefics3VisionConfig): self.patch_size = config.patch_size self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings def get_input_embeddings(self): @@ -722,7 +723,7 @@ def forward( # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): patch_attention_mask = None - elif not self._use_flash_attention_2: + 
elif not self._use_flash_attention_2 and not self._use_flash_attention_3: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( @@ -836,6 +837,7 @@ def __init__(self, config: Idefics3Config): self.image_token_id = self.config.image_token_id self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config.text_config._attn_implementation == "flash_attention_3" self.post_init() diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 9facc4a24c6a..601896d8604b 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1357,7 +1357,7 @@ def forward( ) def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 931b3d7fd82b..fe23c6b9c8e6 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1036,7 +1036,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if attention_mask is not None and "flash_attention" in self._attn_implementation and use_cache: batch_size = inputs_embeds.shape[0] is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1124,7 +1124,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index e86ade5d9d31..3ed964c7ff9b 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -631,7 +631,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 84ea0443d2f1..0630ac0488cb 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1599,7 +1599,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 0bd818dbad96..aff373092b9f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py 
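Before the M2M-100 hunks: the seq2seq encoders/decoders touched by this patch (BART, M2M-100, MBart, and others) repeat one pattern, where flash-attention implementations receive the raw 2D padding mask (or None when nothing is padded) while the other paths expand it to 4D. A minimal sketch of that branch, with hypothetical helper and argument names, reusing the existing `_prepare_4d_attention_mask` utility:

from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

def expand_encoder_attention_mask(attention_mask, dtype, use_flash_attention_2: bool, use_flash_attention_3: bool):
    # Hypothetical helper mirroring the branch added in the forward methods below:
    # flash-attention kernels consume the 2D padding mask directly (or None when there
    # is no padding), while the eager path needs the expanded [bsz, 1, tgt_len, src_len] mask.
    if attention_mask is None:
        return None
    if use_flash_attention_2 or use_flash_attention_3:
        return attention_mask if 0 in attention_mask else None
    return _prepare_4d_attention_mask(attention_mask, dtype)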
@@ -954,6 +954,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False @@ -1035,7 +1036,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1136,6 +1137,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) @@ -1247,7 +1249,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1267,7 +1269,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1407,9 +1409,9 @@ def __init__(self, config: M2M100Config): self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) - if config._attn_implementation == "flash_attention_2": + if "flash_attention" in config._attn_implementation: logger.warning_once( - "Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention." + "Attention with Flash Attention 2/3 does not support `layer_head_mask`. If you need this feature, please use standard attention." 
) # Initialize weights and apply final processing diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 3bed530b1f28..470fe1a59466 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1044,7 +1044,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask if 0 in attention_mask else None elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1261,7 +1261,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: @@ -1281,7 +1281,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index c17c9da585d5..a0715ae79e24 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1068,7 +1068,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 69843b500f35..31ed19bbea69 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -603,7 +603,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index d1531c58a8a6..f7f350bf458f 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -120,7 +120,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if 
"flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 94eb64342cf1..b767374b55a2 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -737,7 +737,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 4a705083f3ba..b1783e76fc23 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1077,7 +1077,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 2fa5a08acc48..1e678873d88a 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -496,7 +496,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation == "flash_attention_2": + if "flash_attention" in config._attn_implementation: self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -516,7 +516,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -935,7 +935,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: @@ -1110,7 +1110,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: @@ -1169,7 +1169,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: with 
nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 5af53bb31014..0a92964e9c19 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -671,7 +671,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): rope_theta = config.local_rope_theta max_position_embeddings = config.local_attention - if config._attn_implementation == "flash_attention_2": + if "flash_attention" in config._attn_implementation: self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) @@ -691,7 +691,7 @@ def forward( qkv = self.Wqkv(hidden_states) bs = hidden_states.shape[0] - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) else: qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) @@ -1039,7 +1039,7 @@ def forward( attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool) repad = False - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if indices is None and cu_seqlens is None and max_seqlen is None: repad = True if inputs_embeds is None: @@ -1214,7 +1214,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict self._maybe_set_compile() - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if indices is None and cu_seqlens is None and max_seqlen is None: if batch_size is None and seq_len is None: if inputs_embeds is not None: @@ -1273,7 +1273,7 @@ def forward( if labels is not None: loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size) - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad(): logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len) diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index e8b8194516e3..add992185914 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -660,7 +660,7 @@ def forward( mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1]) downsample_stride = 64 * 3 * 2 # conv strides attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len] - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask if (attention_mask == 0.0).any() else None # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -916,7 +916,7 @@ def forward( mask_len = encoder_hidden_states.shape[-2] downsample_stride = 64 * 3 * 2 # conv strides encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len] - if self.config._attn_implementation == 
"flash_attention_2": + if "flash_attention" in self.config._attn_implementation: encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -994,7 +994,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 8cf83eef0dcd..cf3b95fbd9b1 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -658,7 +658,7 @@ def forward( mask_len = self._get_feat_extract_output_lengths(attention_mask.shape[-1]) downsample_stride = 64 * 3 * 2 # conv strides attention_mask = attention_mask[..., ::downsample_stride][..., :mask_len] - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: attention_mask = attention_mask if (attention_mask == 0.0).any() else None # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -812,7 +812,7 @@ def forward( mask_len = encoder_hidden_states.shape[-2] downsample_stride = 64 * 3 * 2 # conv strides encoder_attention_mask = encoder_attention_mask[..., ::downsample_stride][..., :mask_len] - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: encoder_attention_mask = encoder_attention_mask if (encoder_attention_mask == 0.0).any() else None # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 2f9aec0ec960..593cc42bee5f 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1299,7 +1299,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: @@ -1613,7 +1613,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 129255a90b5b..3c53ee2caea1 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1191,7 +1191,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return 
attention_mask return None diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 6b75100ca3dc..09d63570d7d3 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1003,7 +1003,7 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - if self.attn_implementation == "flash_attention_2": + if "flash_attention" in self.attn_implementation: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1021,7 +1021,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": + if "flash_attention" in self.attn_implementation: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 156bb4e16e0a..8a65cb9e4ebc 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -950,7 +950,7 @@ def forward( input_shape = inputs_embeds.size()[:-1] - if self.attn_implementation == "flash_attention_2": + if "flash_attention" in self.attn_implementation: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4cab850867e0..25df00ecb332 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -879,7 +879,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 9d972c42b850..e7123d84ab22 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -607,7 +607,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index cc540451c496..914dd69663ce 100644 --- 
a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -608,7 +608,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 1c28997861d9..c4af122caa7a 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -1039,7 +1039,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 66f87b831de0..e1a47dd93db0 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -646,6 +646,7 @@ def __init__(self, config: OPTConfig): self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False @@ -672,7 +673,7 @@ def _update_causal_mask( """ batch_size, seq_length = input_shape mask_seq_length = past_key_values_length + seq_length - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask = ( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index b2f206d4fe89..46761bd65d8c 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -346,7 +346,7 @@ def _update_causal_mask( input_tensor, is_training: bool = False, ): - if self.config.text_config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config.text_config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 9c589036815b..85140de345b2 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -678,7 +678,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 60b484ef3b4d..013d64d1084d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -605,7 +605,7 @@ def 
_update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 27669d3cbc47..87400629db8b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -673,7 +673,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 1f9059ac0dc0..7ffe499d2720 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1183,7 +1183,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 71cf2f255511..029983922dfc 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1586,7 +1586,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index e2f11d97b8fd..14018ed65d0f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -686,6 +686,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -774,7 +775,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -871,6 +872,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) 
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -991,7 +993,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1011,7 +1013,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 351482a75e58..59cc0df8f677 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -999,7 +999,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 61a0b0a75f6c..d95ac0094b02 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -616,7 +616,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index f8ae1fc07484..546f34d790ea 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1250,7 +1250,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index f55977066440..09e1a429b78d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1071,7 +1071,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" 
in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 1e44b86106a4..041a3254d59d 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1196,7 +1196,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 94d83858a1a9..db0e3d41a12c 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -869,6 +869,7 @@ def __init__(self, config): self.upsample = SEWUpsampling(config) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -883,7 +884,7 @@ def forward( if attention_mask is not None: expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # make sure padded tokens output 0 hidden_states[~expand_attention_mask] = 0.0 # 2d mask is passed through the layers diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 9853b3585d0a..857644bef614 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -939,6 +939,7 @@ def __init__(self, config: SiglipTextConfig): self.head = nn.Linear(embed_dim, embed_dim) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) @@ -971,7 +972,7 @@ def forward( # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. 
# expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 5bbfddd9b8e9..6ff43da0204a 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -934,7 +934,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 91936e8e384d..6a1dd4aaddbb 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -599,7 +599,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and past_key_values is not None: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index a09392c85671..b3ae5bb3afd9 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1135,7 +1135,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a91c81ba79b7..4dc55a4585c9 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1204,7 +1204,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 80c6d37ba9ff..d7b6ddf53655 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1537,7 +1537,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py 
index 25d7a74eabfa..a379a992e8f5 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -848,7 +848,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 23cce095246e..2790ed280e1e 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -972,6 +972,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -988,7 +989,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1060,6 +1061,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1076,7 +1078,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index de95114e79da..33bb7f6a5bec 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -989,6 +989,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1005,7 +1006,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1077,6 +1078,7 @@ def __init__(self, config): ) self.gradient_checkpointing = 
False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1093,7 +1095,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index f1e7a34da35f..de3e9fc2681e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1005,6 +1005,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1021,7 +1022,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1092,6 +1093,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1108,7 +1110,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 0e731ca3b160..ff753812f354 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1113,6 +1113,7 @@ def __init__(self, config: WhisperConfig): [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) @@ -1375,7 +1376,7 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask 
return None diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 3be45dd7fcfa..54f6f8cea929 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -897,6 +897,31 @@ def _check_and_enable_flash_attn_2( return config + + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + Overloads `PreTrainedModel._check_and_enable_flash_attn_3` so as to DISABLE Flash Attention 3 by default on Zamba models. + Flash attention 3 is currently not supported in the HuggingFace implementation of Zamba v1. + """ + config = super()._check_and_enable_flash_attn_3( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + # if using the default path -> swap flash_attention_3 for eager + if not hard_check_only and config._attn_implementation == "flash_attention_3": + config._attn_implementation = "eager" + + return config + + ZAMBA_INPUTS_DOCSTRING = r""" Args: @@ -1144,7 +1169,7 @@ def forward( # Copied from transformers.models.jamba.modeling_jamba.JambaModel._update_causal_mask def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 8a5642e0f5d1..315a92697abf 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -1532,7 +1532,7 @@ def forward( return output if return_dict else output.to_tuple() def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None From a5261895e0dc66f02eb647750536a202ec6235d9 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 07:34:46 +0000 Subject: [PATCH 4/8] testing_utils --- src/transformers/testing_utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 14fef2988488..9abf389d4f6e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -79,6 +79,7 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flax_available, is_flute_available, is_fsdp_available, @@ -592,6 +593,14 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_flash_attn_3(test_case): + """ + Decorator marking a test that requires Flash Attention 3. + These tests are skipped when Flash Attention 3 isn't installed. + """ + return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) + + def require_torch_sdpa(test_case): """ Decorator marking a test that requires PyTorch's SDPA.
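For reference, a minimal sketch of how the new `require_flash_attn_3` decorator added above could be used in a test module; the test class and method names below are illustrative assumptions, not part of this patch series.

import unittest

from transformers.testing_utils import require_flash_attn_3


class FlashAttention3SmokeTest(unittest.TestCase):  # hypothetical test class, not from the patch
    @require_flash_attn_3
    def test_runs_only_when_fa3_is_installed(self):
        # Skipped automatically when Flash Attention 3 isn't installed,
        # mirroring how `require_flash_attn` gates Flash Attention 2 tests.
        self.assertTrue(True)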
From a9717e79099b894530f74079f7dbc9c5e5b8d643 Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 07:35:07 +0000 Subject: [PATCH 5/8] make --- src/transformers/models/zamba/modeling_zamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 54f6f8cea929..756586ea3171 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -922,7 +922,6 @@ def _check_and_enable_flash_attn_3( return config - ZAMBA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): From 1b5f20c03d0390dace2b370983efad1ac006df6d Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 14 Feb 2025 07:58:22 +0000 Subject: [PATCH 6/8] sliding_window --- src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py | 2 +- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 546f34d790ea..2f6e7d521276 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -1018,7 +1018,7 @@ def __init__(self, config: Qwen2_5_VLConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + if config.use_sliding_window and "flash_attention" not in config._attn_implementation: logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 041a3254d59d..5dd44d114092 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -830,7 +830,7 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + if config.use_sliding_window and "flash_attention" not in config._attn_implementation: logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." From ea85044ce7f531b30f3c8f0d94f34640ab8f2654 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 18 Feb 2025 09:49:54 +0000 Subject: [PATCH 7/8] Update modeling_granitemoe.py --- src/transformers/models/granitemoe/modeling_granitemoe.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index dd0011c7b1c4..1fb2ba700c18 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1118,11 +1118,6 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. 
- # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if "flash_attention" in self.config._attn_implementation: if attention_mask is not None and (attention_mask == 0.0).any() return attention_mask From af0d015a2d26960782062d2475fa54d8d893a938 Mon Sep 17 00:00:00 2001 From: hlky Date: Tue, 18 Feb 2025 09:51:16 +0000 Subject: [PATCH 8/8] Update modeling_granitemoe.py --- src/transformers/models/granitemoe/modeling_granitemoe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 1fb2ba700c18..8821a0e8fd89 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1119,7 +1119,7 @@ def _update_causal_mask( output_attentions: bool, ): if "flash_attention" in self.config._attn_implementation: - if attention_mask is not None and (attention_mask == 0.0).any() + if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None
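As a closing illustration of the recurring change across these patches: the exact-match check against "flash_attention_2" is relaxed to a substring test so that both Flash Attention 2 and Flash Attention 3 take the same mask-handling path. A standalone sketch of the predicate (not taken from the patch itself):

# Standalone sketch: which `_attn_implementation` strings the old equality check
# and the new substring check accept.
for impl in ("flash_attention_2", "flash_attention_3", "sdpa", "flex_attention", "eager"):
    old_match = impl == "flash_attention_2"
    new_match = "flash_attention" in impl
    print(f"{impl:18s} old={old_match!s:5s} new={new_match}")
# Prints True/True for flash_attention_2, False/True for flash_attention_3,
# and False/False for sdpa, flex_attention, and eager.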