diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index a3ca4bea484d..a78166ed040b 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -2,11 +2,10 @@ import torch -from ..modeling_flash_attention_utils import _flash_attention_forward -from ..utils import is_flash_attn_greater_or_equal_2_10 +from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask -_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10() +_use_top_left_mask = flash_attn_supports_top_left_mask() def flash_attention_forward( diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py new file mode 100644 index 000000000000..fa26379126fc --- /dev/null +++ b/src/transformers/integrations/npu_flash_attention.py @@ -0,0 +1,233 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torch.nn.functional as F + +from ..utils.import_utils import is_torch_npu_available + + +if is_torch_npu_available(): + import torch_npu + from einops import rearrange, repeat + + +# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default. +# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask. +TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2 +DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3 + +SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)) +if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]: + raise ValueError( + "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) " + "or 3 (down-right aligned causal mask)." + ) + + +def is_npu_fa2_top_left_aligned_causal_mask(): + return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather( + rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + # dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py +def unpad_input(hidden_states, attention_mask, unused_mask=None): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask. + indices: (total_nnz), the indices of masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask. + """ + all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask + seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32) + used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + used_seqlens_in_batch, + ) + + +def npu_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + **kwargs, +): + keep_prob = 1.0 - dropout_p + + if not causal: + head_num = q.shape[2] + output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] + else: + attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device) + head_num = q.shape[2] + output = torch_npu.npu_fusion_attention( + q, + k, + v, + head_num, + "BSND", + keep_prob=keep_prob, + scale=softmax_scale, + atten_mask=attn_mask_npu, + sparse_mode=SPARSE_MODE, + )[0] + + return output + + +def npu_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + **kwargs, +): + keep_prob = 1.0 - dropout_p + + if not causal: + head_num = q.shape[1] + output = torch_npu.npu_fusion_attention( + q, + k, + v, + head_num, + pse=None, + atten_mask=None, + scale=softmax_scale, + keep_prob=keep_prob, + input_layout="TND", + actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), + )[0] + else: + attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device) + head_num = q.shape[1] + output = torch_npu.npu_fusion_attention( + q, + k, + v, + head_num, + pse=None, + padding_mask=None, + atten_mask=attn_mask_npu, + scale=softmax_scale, + keep_prob=keep_prob, + input_layout="TND", + actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()), + actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), + sparse_mode=SPARSE_MODE, + )[0] + + return output diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 4da8b451f1af..70516946c894 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -19,19 +19,68 @@ import torch import torch.nn.functional as F -from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging +from .utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, + is_torch_npu_available, + logging, +) logger = logging.get_logger(__name__) +flash_attn_func = None if is_flash_attn_2_available(): from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.layers.rotary import apply_rotary_emb # noqa + +# patch functions in package `flash-attn` when using flash-attention on Ascend NPU. +if is_torch_npu_available(): + from torch_npu import npu_rotary_mul as apply_rotary_emb # noqa + + from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input + from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func + from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func + + +if flash_attn_func: _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) +def is_flash_attn_available(): + """Determine whether flash-attention can be used or not.""" + + # if package `flash-attn` is available, flash-attention can be used natively. + if is_flash_attn_2_available(): + return True + + # flash-attention can be used on Ascend NPU without package `flash-attn` + if is_torch_npu_available(): + return True + + return False + + +def flash_attn_supports_top_left_mask(): + """Determine whether flash-attention uses top-left or down-right mask""" + + if is_flash_attn_2_available(): + # top-left mask is used in package `flash-attn` with version lower than 2.1.0 + return not is_flash_attn_greater_or_equal_2_10() + + if is_torch_npu_available(): + # down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask. + from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask + + return is_npu_fa2_top_left_aligned_causal_mask() + + return False + + def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]: """ Retrieves indexing data required to repad unpadded (ragged) tensors. diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c279ae391c1a..98aa3206a7a3 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2276,11 +2276,13 @@ def _check_and_enable_flash_attn_2( install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." if importlib.util.find_spec("flash_attn") is None: + # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit. if is_torch_npu_available(): - recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. " - raise ImportError( - f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}" - ) + if not hard_check_only: + config._attn_implementation = "flash_attention_2" + + logger.info("Detect using FlashAttention2 on Ascend NPU.") + return config else: raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index e9a2c5f81a9c..06986e834c38 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -47,10 +47,7 @@ replace_return_docstrings, ) from ...utils.deprecation import deprecate_kwarg -from ...utils.import_utils import ( - is_causal_conv1d_available, - is_mamba_2_ssm_available, -) +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index f287234893ea..ccfae9505882 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -29,14 +29,13 @@ SuppressTokensLogitsProcessor, ) from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput from ...modeling_utils import PreTrainedModel, get_parameter_device from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, is_accelerate_available, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from ..auto import AutoModel @@ -54,7 +53,7 @@ ) -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -203,7 +202,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _split_heads(self, tensor, num_heads, attn_head_size): """ diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index ac9945df71b7..a51952849183 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -32,6 +32,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -47,15 +48,13 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_bart import BartConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -300,7 +299,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 33193365359f..aff97487bb1e 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -28,7 +28,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -39,8 +39,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, is_torchdynamo_compiling, logging, @@ -55,10 +53,6 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "ChameleonConfig" @@ -385,7 +379,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Ignore copy def forward( diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index c13d9097e6ce..f6da323a8a08 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import is_torch_greater_or_equal_than_2_2 @@ -32,8 +33,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, torch_int, @@ -41,7 +40,7 @@ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -414,7 +413,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 4d226567b565..df312b047d01 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -40,15 +41,13 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, ) from .configuration_data2vec_audio import Data2VecAudioConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -495,7 +494,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 08f11ed6ffd4..5303f84edcb5 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -25,13 +25,12 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -46,7 +45,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -331,7 +330,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index c86fffad7aab..7cadf5673ff3 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -32,7 +32,11 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import FlashAttentionKwargs, _flash_attention_forward +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, + _flash_attention_forward, + flash_attn_supports_top_left_mask, +) from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -48,7 +52,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -254,7 +257,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index e116db288246..68102347904a 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -23,11 +23,8 @@ from torch import nn from ...cache_utils import Cache, StaticCache -from ...modeling_flash_attention_utils import _flash_attention_forward -from ...utils import ( - is_flash_attn_greater_or_equal_2_10, - logging, -) +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ...utils import logging from ..gemma.modeling_gemma import GemmaForCausalLM from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -176,7 +173,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index c869355a046b..b78050a01aeb 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -30,6 +30,7 @@ from ...configuration_utils import PretrainedConfig from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -49,15 +50,13 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_distilbert import DistilBertConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -251,7 +250,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index 792d48268127..5fc8e43afa15 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -32,7 +32,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, replace_return_docstrings, ) @@ -50,10 +49,6 @@ from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - - _CONFIG_FOR_DOC = "Emu3Config" _CHECKPOINT_FOR_DOC = "BAAI/Emu3-Chat-hf" diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 61c27f5f2967..e6cb36353dc5 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -29,6 +29,7 @@ from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -42,8 +43,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from ...utils.deprecation import deprecate_kwarg @@ -53,7 +52,7 @@ if TYPE_CHECKING: from ...configuration_utils import PretrainedConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -472,7 +471,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d388f70a777b..d9a84d6f463d 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -24,6 +24,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -36,14 +37,12 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from .configuration_gpt_bigcode import GPTBigCodeConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -284,7 +283,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index ff963fd665df..12e45d918d88 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, @@ -40,8 +41,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, is_torch_fx_available, logging, @@ -55,7 +54,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -287,7 +286,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 57feb2c53607..fca8b0e5a167 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -38,8 +39,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, is_torch_fx_proxy, logging, @@ -54,7 +53,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -279,7 +278,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 192c99df7f57..c05a10504085 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -24,7 +24,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, @@ -36,7 +36,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -534,7 +533,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index d97620e62ba3..fd2cdf2843d3 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -29,14 +29,13 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -411,7 +410,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 333c6df5dbcb..3f6e3a5a8739 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -26,21 +26,20 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_hubert import HubertConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -569,7 +568,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 603ab05d7990..81b77e57c116 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -27,13 +27,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -42,7 +41,7 @@ from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -279,7 +278,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, @@ -874,7 +873,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Ignore copy def forward( diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 82401437e23c..c3644ab6c785 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -26,13 +26,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -40,7 +39,7 @@ from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -279,7 +278,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index de8780fe09e0..4fba6ba89b9b 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -33,6 +33,7 @@ from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -48,14 +49,12 @@ from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_mamba_ssm_available, ) from .configuration_jamba import JambaConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -391,7 +390,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 88a003a44cf8..c7457db3e9b5 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -36,8 +37,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -52,7 +51,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -685,7 +684,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index e16a299dfcd2..cba5ead876c2 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -31,6 +31,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,15 +44,13 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_m2m_100 import M2M100Config -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -354,7 +353,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index c1cd512f85b6..522c5c8cfce7 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -31,6 +31,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -46,15 +47,13 @@ add_end_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_mbart import MBartConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -297,7 +296,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index b77203ca0ed0..4746755a4f51 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel @@ -32,15 +33,13 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_mimi import MimiConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -604,7 +603,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 193f43054cc2..cc352e76fb5f 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -30,6 +30,7 @@ GenerationMixin, ) from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -42,8 +43,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -53,7 +52,7 @@ from .configuration_moshi import MoshiConfig, MoshiDepthConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -573,7 +572,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index e424845245f4..add0eeba6544 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -40,6 +40,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -51,8 +52,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -61,7 +60,7 @@ from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward if TYPE_CHECKING: @@ -330,7 +329,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ec6074f48a96..08af10d1a3d8 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -35,6 +35,7 @@ StoppingCriteriaList, ) from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, ModelOutput, @@ -43,8 +44,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -53,7 +52,7 @@ from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward if TYPE_CHECKING: @@ -346,7 +345,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 0ff0544debce..b1cae89b0bae 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -27,7 +27,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -42,7 +42,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -324,7 +323,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Ignore copy def forward( diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 009f7dc57ab0..e6246ac6001a 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -23,6 +23,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -33,8 +34,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -42,7 +41,7 @@ from .configuration_olmoe import OlmoeConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -404,7 +403,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5ab6efd7dd39..729f71a49aec 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -27,6 +27,7 @@ from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -38,8 +39,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -53,7 +52,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -208,7 +207,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 238658add5f7..1065e801a455 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -28,19 +28,13 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, is_torchdynamo_compiling, logging, replace_return_docstrings, ) from ...utils.deprecation import deprecate_kwarg -from .configuration_paligemma import PaliGemmaConfig - - -if is_flash_attn_2_available(): - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa - from ..auto import AutoModel, AutoModelForCausalLM +from .configuration_paligemma import PaliGemmaConfig logger = logging.get_logger(__name__) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 8fe0137057dd..7701a76e6c7e 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +from ...modeling_flash_attention_utils import is_flash_attn_available from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -36,7 +37,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, logging, replace_return_docstrings, ) @@ -45,7 +45,7 @@ from .configuration_phimoe import PhimoeConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 095a0769ed3b..62fb093bc91d 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -37,33 +37,20 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.layers.rotary import apply_rotary_emb - -else: - flash_attn_varlen_func = None - apply_rotary_emb = None +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None logger = logging.get_logger(__name__) @@ -819,7 +806,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py index a77344c976d9..10100d8f3336 100644 --- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py @@ -47,18 +47,14 @@ from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature from ...image_utils import ImageInput, VideoInput +from ...modeling_flash_attention_utils import is_flash_attn_available from ...processing_utils import ProcessingKwargs, Unpack, VideosKwargs from ...tokenization_utils_base import PreTokenizedInput, TextInput -from ...utils import is_flash_attn_2_available, logging +from ...utils import logging -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - from flash_attn.layers.rotary import apply_rotary_emb - -else: - flash_attn_varlen_func = None - apply_rotary_emb = None +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func logger = logging.get_logger(__name__) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 8a9122c48708..37d38c429a7c 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,13 +25,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -40,7 +39,7 @@ from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -210,7 +209,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() @deprecate_kwarg("key_value_states", version="4.52") @deprecate_kwarg("past_key_value", version="4.52") diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 5c6fc1c71559..40263884b130 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -31,6 +31,7 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( MoeCausalLMOutputWithPast, MoeModelOutputWithPast, @@ -44,8 +45,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -53,7 +52,7 @@ from .configuration_qwen2_moe import Qwen2MoeConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -412,7 +411,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 2bcfd3608ded..b0c485a9e9e2 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -33,26 +33,21 @@ from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig -if is_flash_attn_2_available(): - from flash_attn import flash_attn_varlen_func - - from ...modeling_flash_attention_utils import _flash_attention_forward -else: - flash_attn_varlen_func = None +if is_flash_attn_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_varlen_func logger = logging.get_logger(__name__) @@ -632,7 +627,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index bff2c96667d5..069691cee293 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, Optional, Tuple, Union import torch @@ -588,7 +589,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 117356dffd48..a697289993cb 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -685,7 +686,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, causal_mask, position_ids, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 0ef4852fa5f2..e16a17dbf2bc 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -27,20 +27,19 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, ) from .configuration_sew import SEWConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -569,7 +568,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 2da27ae36a99..34b0ee370cd3 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -28,14 +28,13 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, torch_int, @@ -43,7 +42,7 @@ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -451,7 +450,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( diff --git a/src/transformers/models/siglip2/modeling_siglip2.py b/src/transformers/models/siglip2/modeling_siglip2.py index 922bc9304dd6..4fbe92148380 100644 --- a/src/transformers/models/siglip2/modeling_siglip2.py +++ b/src/transformers/models/siglip2/modeling_siglip2.py @@ -32,21 +32,20 @@ from ...activations import ACT2FN from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -342,7 +341,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( diff --git a/src/transformers/models/smolvlm/modeling_smolvlm.py b/src/transformers/models/smolvlm/modeling_smolvlm.py index 3e34025fac90..a0f43033d721 100644 --- a/src/transformers/models/smolvlm/modeling_smolvlm.py +++ b/src/transformers/models/smolvlm/modeling_smolvlm.py @@ -30,13 +30,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) @@ -44,7 +43,7 @@ from .configuration_smolvlm import SmolVLMConfig, SmolVLMVisionConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -253,7 +252,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 84f4b7f82d75..a11e0627b8f1 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -30,6 +30,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -42,8 +43,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -58,7 +57,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -461,7 +460,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 46ed990b29bd..7172618182a4 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput, Wav2Vec2BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -35,15 +36,13 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, ) from .configuration_unispeech import UniSpeechConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -601,7 +600,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index f19d5f3789ff..0472cde71565 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -42,8 +43,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, replace_return_docstrings, @@ -51,7 +50,7 @@ from .configuration_unispeech_sat import UniSpeechSatConfig -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -618,7 +617,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index b7b37e47a670..be6f5eeda6e3 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -44,8 +45,6 @@ add_start_docstrings, add_start_docstrings_to_model_forward, cached_file, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_peft_available, is_safetensors_available, logging, @@ -61,7 +60,7 @@ from safetensors.torch import load_file as safe_load_file -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward logger = logging.get_logger(__name__) @@ -664,7 +663,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 5f7cf59e8e9f..5c85dc8f0128 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -27,6 +27,7 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,8 +40,6 @@ from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, is_torch_flex_attn_available, logging, replace_return_docstrings, @@ -54,7 +53,7 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -if is_flash_attn_2_available(): +if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -366,7 +365,7 @@ def __init__(self, *args, **kwargs): # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() def forward( self, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index f98403ded287..5aa01be0d7c6 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1016,11 +1016,6 @@ def is_flash_attn_2_available(): if not (torch.cuda.is_available() or is_torch_mlu_available()): return False - # Ascend does not support "flash_attn". - # If "flash_attn" is left in the env, is_flash_attn_2_available() should return False. - if is_torch_npu_available(): - return False - if torch.version.cuda: return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") elif torch.version.hip: diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 212969ce150a..8abef0422705 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -48,6 +48,7 @@ is_torch_available, logging, ) +from transformers.modeling_flash_attention_utils import is_flash_attn_available from transformers.testing_utils import ( TOKEN, CaptureLogger, @@ -79,6 +80,7 @@ is_flash_attn_2_available, is_flax_available, is_tf_available, + is_torch_npu_available, is_torch_sdpa_available, is_torchdynamo_available, ) @@ -653,7 +655,7 @@ def test_model_from_pretrained_attn_implementation(self): if is_torch_sdpa_available(): attn_implementation_available.append("sdpa") - if is_flash_attn_2_available(): + if is_flash_attn_available(): attn_implementation_available.append("flash_attention_2") for requested_attn_implementation in attn_implementation_available: @@ -677,7 +679,7 @@ def test_model_from_config_attn_implementation(self): if is_torch_sdpa_available(): attn_implementation_available.append("sdpa") - if is_flash_attn_2_available(): + if is_flash_attn_available(): attn_implementation_available.append("flash_attention_2") for requested_attn_implementation in attn_implementation_available: @@ -2676,6 +2678,11 @@ def test_not_available_flash(self): if is_flash_attn_2_available(): self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash") + if is_torch_npu_available(): + self.skipTest( + reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." + ) + with self.assertRaises(ImportError) as cm: _ = AutoModel.from_pretrained( "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_2" @@ -2686,6 +2693,11 @@ def test_not_available_flash_with_config(self): if is_flash_attn_2_available(): self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash") + if is_torch_npu_available(): + self.skipTest( + reason="FlashAttention2 is supported on Ascend NPU without using package `flash-attn`, ignore this test case." + ) + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") with self.assertRaises(ImportError) as cm: