
Commit 1d0c166

FightingZhen authored and bvantuan committed
[bugfix] [WIP] fix apply_rotary_emb error on Ascend NPU (huggingface#38491)
[bugfix] fix apply_rotary_emb error on Ascend NPU
1 parent 567dddf commit 1d0c166

File tree

5 files changed (+35 -23 lines)


src/transformers/integrations/npu_flash_attention.py

Lines changed: 17 additions & 0 deletions
@@ -23,6 +23,7 @@
 
 import torch_npu
 from einops import rearrange, repeat
+from torch_npu import npu_rotary_mul
 
 
 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -247,3 +248,19 @@ def npu_flash_attn_varlen_func(
     )[0]
 
     return output
+
+
+def npu_apply_rotary_emb(x, cos, sin, **kwargs):
+    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
+        cos = cos.repeat(1, 2)
+        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        cos = cos.unsqueeze(0).unsqueeze(2)
+
+    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
+    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
+        sin = sin.repeat(1, 2)
+        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
+        sin = sin.unsqueeze(0).unsqueeze(2)
+
+    return npu_rotary_mul(x, cos, sin)
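The wrapper above bridges two calling conventions: the upstream flash_attn rotary helper takes cos/sin tables of shape [seqlen, rotary_dim/2], while torch_npu.npu_rotary_mul multiplies against the full last dimension of x ([batch, seqlen, heads, head_dim]), so the tables are repeated to head_dim and broadcast as [1, S, 1, D]. The following is a minimal pure-PyTorch sketch of that shape handling, runnable on CPU without torch_npu; it assumes npu_rotary_mul implements the non-interleaved ("rotate_half") RoPE convention, and the helper names are illustrative rather than part of the patch.

# Minimal pure-PyTorch sketch of the shape handling above, runnable on CPU.
# Assumption: npu_rotary_mul applies the non-interleaved ("rotate_half") RoPE
# convention; this reference only checks that the repeat/unsqueeze broadcasting
# lines up, it is not the NPU kernel itself.
import torch


def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # (a, b) -> (-b, a) on the last dimension
    a, b = t.chunk(2, dim=-1)
    return torch.cat((-b, a), dim=-1)


def reference_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Same preprocessing as npu_apply_rotary_emb: repeat [S, D/2] tables to
    # [S, D], then unsqueeze to [1, S, 1, D] so they broadcast against
    # x of shape [B, S, H, D].
    if cos.dim() == 2 and cos.shape[-1] == x.shape[-1] // 2:
        cos = cos.repeat(1, 2).unsqueeze(0).unsqueeze(2)
    if sin.dim() == 2 and sin.shape[-1] == x.shape[-1] // 2:
        sin = sin.repeat(1, 2).unsqueeze(0).unsqueeze(2)
    return x * cos + rotate_half(x) * sin


x = torch.randn(2, 16, 4, 64)   # [batch, seqlen, heads, head_dim]
freqs = torch.randn(16, 32)     # [seqlen, head_dim // 2]
print(reference_rotary(x, freqs.cos(), freqs.sin()).shape)  # torch.Size([2, 16, 4, 64])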

src/transformers/modeling_flash_attention_utils.py

Lines changed: 1 addition & 2 deletions
@@ -40,9 +40,8 @@
 
 # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
 if is_torch_npu_available():
-    from torch_npu import npu_rotary_mul as apply_rotary_emb  # noqa
-
     from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa
     from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
     from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
 
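With this change, the NPU shape fix sits behind the same name the GPU path exposes: any module that does `from ...modeling_flash_attention_utils import apply_rotary_emb` transparently picks up npu_apply_rotary_emb on Ascend. A self-contained sketch of that bind-once aliasing pattern follows; the function names are illustrative stand-ins, not transformers internals.

# Illustrative stand-ins only; the real binding happens at import time in
# modeling_flash_attention_utils.py based on is_torch_npu_available().
def _apply_rotary_emb_flash_attn(x, cos, sin):
    # placeholder for flash_attn.layers.rotary.apply_rotary_emb
    return "flash-attn kernel"


def _apply_rotary_emb_npu(x, cos, sin):
    # placeholder for npu_apply_rotary_emb from the NPU integration
    return "npu_rotary_mul wrapper"


def _is_torch_npu_available() -> bool:
    # placeholder for the device check used in transformers
    return False


# Bind once; every call site just uses `apply_rotary_emb` with no device branches.
apply_rotary_emb = _apply_rotary_emb_npu if _is_torch_npu_available() else _apply_rotary_emb_flash_attn

print(apply_rotary_emb(None, None, None))  # "flash-attn kernel"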

src/transformers/models/esm/modeling_esm.py

Lines changed: 4 additions & 3 deletions
@@ -23,6 +23,7 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -31,11 +32,11 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...utils import auto_docstring, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ...utils import auto_docstring, logging
 from .configuration_esm import EsmConfig
 
 
-if is_flash_attn_2_available():
+if is_flash_attn_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
 
@@ -413,7 +414,7 @@ def __init__(self, config, position_embedding_type=None):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
         self.dropout_prob = config.attention_probs_dropout_prob
 
     def forward(
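The ESM edits swap the flash-attn-2-specific guards for the backend-agnostic helpers is_flash_attn_available and flash_attn_supports_top_left_mask, so the same branch also covers kernels (such as the NPU ones) that do not come from the `flash_attn` package. Judging only from the expression being replaced, a helper like flash_attn_supports_top_left_mask() reduces to a version check along the lines of the hedged sketch below; the real helper may include additional backend handling.

# Hedged sketch inferred from the replaced expression
# (`not is_flash_attn_greater_or_equal_2_10()`); not the actual transformers
# implementation. flash_attn < 2.1 emitted top-left aligned causal masks,
# >= 2.1 switched to bottom-right.
from packaging import version


def _installed_flash_attn_version() -> str:
    # stand-in for querying the installed flash-attn package version
    return "2.3.0"


def flash_attn_supports_top_left_mask_sketch() -> bool:
    return version.parse(_installed_flash_attn_version()) < version.parse("2.1.0")


print(flash_attn_supports_top_left_mask_sketch())  # False for 2.3.0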

src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 9 additions & 12 deletions
@@ -34,19 +34,17 @@
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_flash_attention_utils import (
+    FlashAttentionKwargs,
+    flash_attn_supports_top_left_mask,
+    is_flash_attn_available,
+)
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import (
-    auto_docstring,
-    check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-    logging,
-)
+from ...utils import auto_docstring, check_torch_load_is_safe, logging
 from ...utils.hub import cached_file
 from .configuration_qwen2_5_omni import (
     Qwen2_5OmniAudioEncoderConfig,
@@ -61,9 +59,8 @@
 )
 
 
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
@@ -653,7 +650,7 @@ def __init__(self, *args, **kwargs):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
 
     def forward(
         self,
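Here the direct `flash_attn` imports are routed through modeling_flash_attention_utils, which resolves to the CUDA kernels or the NPU wrappers depending on the environment, while the existing `else` branch keeps the names bound to None when neither is available. A self-contained sketch of that None-sentinel optional-dependency pattern, with illustrative helper names:

# Illustrative sketch of the None-sentinel fallback used when the
# flash-attention stack is not installed; helper names are placeholders.
from typing import Callable, Optional


def _try_import_varlen_func() -> Optional[Callable]:
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
        return flash_attn_varlen_func
    except ImportError:
        return None


flash_attn_varlen_func = _try_import_varlen_func()

if flash_attn_varlen_func is None:
    # callers check the sentinel and fall back to a non-flash attention path
    print("flash-attention kernels unavailable; using the eager attention path")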

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 4 additions & 6 deletions
@@ -43,22 +43,20 @@
 
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
+from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...utils import (
     auto_docstring,
     check_torch_load_is_safe,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
     logging,
 )
 from ...utils.hub import cached_file
 
 
-if is_flash_attn_2_available():
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
-    from flash_attn.layers.rotary import apply_rotary_emb
+if is_flash_attn_available():
+    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
 else:
     flash_attn_varlen_func = None
     apply_rotary_emb = None
@@ -1667,7 +1665,7 @@ def __init__(self, *args, **kwargs):
         # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
         # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
 
     def forward(
         self,

0 commit comments
