17 changes: 17 additions & 0 deletions src/transformers/integrations/npu_flash_attention.py
@@ -23,6 +23,7 @@

import torch_npu
from einops import rearrange, repeat
+from torch_npu import npu_rotary_mul


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -247,3 +248,19 @@ def npu_flash_attn_varlen_func(
    )[0]

    return output
+
+
+def npu_apply_rotary_emb(x, cos, sin, **kwargs):
+    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
+        cos = cos.repeat(1, 2)
+        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        cos = cos.unsqueeze(0).unsqueeze(2)
+
+    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
+        sin = sin.repeat(1, 2)
+        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        sin = sin.unsqueeze(0).unsqueeze(2)
+
+    return npu_rotary_mul(x, cos, sin)
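
For reference, a shape-only sketch of the cos/sin adaptation performed by npu_apply_rotary_emb, written in plain PyTorch so it can be checked without an Ascend NPU. The [batch, seq, heads, head_dim] layout assumed for x is an inference from the [1,S,1,D] comment above, and the final npu_rotary_mul call is left as a comment because it requires torch_npu:

import torch

B, S, H, D = 2, 16, 4, 64     # illustrative batch, sequence, heads, head_dim
x = torch.randn(B, S, H, D)   # flash-attn style layout (assumed)
cos = torch.randn(S, D // 2)  # flash-attn convention: half-sized rotary tables
sin = torch.randn(S, D // 2)

if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
    cos = cos.repeat(1, 2).unsqueeze(0).unsqueeze(2)  # [S, D/2] -> [S, D] -> [1, S, 1, D]
    sin = sin.repeat(1, 2).unsqueeze(0).unsqueeze(2)

print(cos.shape)  # torch.Size([1, 16, 1, 64]), broadcastable against x
# On an Ascend NPU the wrapper would now call: torch_npu.npu_rotary_mul(x, cos, sin)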
3 changes: 1 addition & 2 deletions src/transformers/modeling_flash_attention_utils.py
@@ -40,9 +40,8 @@

# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
-    from torch_npu import npu_rotary_mul as apply_rotary_emb  # noqa
-
    from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func

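This hunk is the core of the patch: apply_rotary_emb used to be aliased directly to torch_npu's npu_rotary_mul, which, as the new wrapper's comments indicate, wants full-width cos/sin broadcastable as [1,S,1,D], while callers following the flash-attn convention pass 2-D half-width tables. Routing the alias through npu_apply_rotary_emb closes that gap. A simplified sketch of the resulting dispatch (the real module also patches padding helpers and other symbols omitted here):

from transformers.utils import is_flash_attn_2_available, is_torch_npu_available

if is_torch_npu_available():
    # Ascend NPU: flash-attn style names are backed by torch_npu kernels, and rotary
    # embedding goes through the shape-adapting wrapper defined above.
    from transformers.integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb
    from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
    from transformers.integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
elif is_flash_attn_2_available():
    # CUDA/ROCm: the genuine flash-attn package provides the same symbols.
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.layers.rotary import apply_rotary_emb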
7 changes: 4 additions & 3 deletions src/transformers/models/esm/modeling_esm.py
@@ -23,6 +23,7 @@
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,11 +32,11 @@
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ...utils import auto_docstring, logging
from .configuration_esm import EsmConfig


-if is_flash_attn_2_available():
+if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

@@ -413,7 +414,7 @@ def __init__(self, config, position_embedding_type=None):
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
        self.dropout_prob = config.attention_probs_dropout_prob

    def forward(
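Across the touched models, the per-model boilerplate `not is_flash_attn_greater_or_equal_2_10()` is replaced by the shared flash_attn_supports_top_left_mask() helper from modeling_flash_attention_utils. The sketch below is only an assumption about what such a helper centralizes, not the actual implementation; in particular, how the NPU backend reports its mask alignment is guessed from the integration module's "down-right aligned causal mask by default" comment:

from transformers.utils import (
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torch_npu_available,
)

def flash_attn_supports_top_left_mask() -> bool:
    # flash_attn < 2.1 produces a top-left aligned causal mask; >= 2.1 is bottom-right.
    if is_torch_npu_available():
        return False  # assumption: NPU kernels default to a down-right aligned mask
    if is_flash_attn_2_available():
        return not is_flash_attn_greater_or_equal_2_10()
    return False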
21 changes: 9 additions & 12 deletions src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -34,19 +34,17 @@
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import (
+    FlashAttentionKwargs,
+    flash_attn_supports_top_left_mask,
+    is_flash_attn_available,
+)
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-    check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-    logging,
-)
+from ...utils import auto_docstring, check_torch_load_is_safe, logging
from ...utils.hub import cached_file
from .configuration_qwen2_5_omni import (
    Qwen2_5OmniAudioEncoderConfig,
@@ -61,9 +59,8 @@
)


-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None
@@ -653,7 +650,7 @@ def __init__(self, *args, **kwargs):
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
10 changes: 4 additions & 6 deletions src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -43,22 +43,20 @@

from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_rope_utils import rope_config_validation
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
from ...utils import (
    auto_docstring,
    check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
    logging,
)
from ...utils.hub import cached_file


-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None
@@ -1667,7 +1665,7 @@ def __init__(self, *args, **kwargs):
        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,