Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/transformers/integrations/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import torch

from ..modeling_flash_attention_utils import _flash_attention_forward
from ..utils import is_flash_attn_greater_or_equal_2_10
from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask


_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
_use_top_left_mask = flash_attn_supports_top_left_mask()


def flash_attention_forward(
Expand Down
233 changes: 233 additions & 0 deletions src/transformers/integrations/npu_flash_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.nn.functional as F

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
import torch_npu
from einops import rearrange, repeat


# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
raise ValueError(
"Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
"or 3 (down-right aligned causal mask)."
)


def is_npu_fa2_top_left_aligned_causal_mask():
return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices):
ctx.save_for_backward(indices)
assert input.ndim >= 2
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
second_dim = other_shape.numel()
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
# return input[indices]
return torch.gather(
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
).reshape(-1, *other_shape)

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
assert grad_output.ndim >= 2
other_shape = grad_output.shape[1:]
grad_output = rearrange(grad_output, "b ... -> b (...)")
grad_input = torch.zeros(
[ctx.first_axis_dim, grad_output.shape[1]],
device=grad_output.device,
dtype=grad_output.dtype,
)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
# grad_input[indices] = grad_output
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
class IndexPutFirstAxis(torch.autograd.Function):
@staticmethod
def forward(ctx, values, indices, first_axis_dim):
ctx.save_for_backward(indices)
assert indices.ndim == 1
assert values.ndim >= 2
output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
output[indices] = values
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
return output

@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
grad_values = grad_output[indices]
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def pad_input(hidden_states, indices, batch, seqlen):
"""
Arguments:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
batch: int, batch size for the padded sequence.
seqlen: int, maximum sequence length for the padded sequence.
Return:
hidden_states: (batch, seqlen, ...)
"""
# dim = hidden_states.shape[-1]
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
# output[indices] = hidden_states
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
return rearrange(output, "(b s) ... -> b s ...", b=batch)


# Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/bert_padding.py
def unpad_input(hidden_states, attention_mask, unused_mask=None):
"""
Arguments:
hidden_states: (batch, seqlen, ...)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
Return:
hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
max_seqlen_in_batch: int
seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
"""
all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
# so we write custom forward and backward to make it a bit faster.
return (
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
indices,
cu_seqlens,
max_seqlen_in_batch,
used_seqlens_in_batch,
)


def npu_flash_attn_func(
q,
k,
v,
dropout_p=0.0,
softmax_scale=None,
causal=False,
**kwargs,
):
keep_prob = 1.0 - dropout_p

if not causal:
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
else:
attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
head_num = q.shape[2]
output = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num,
"BSND",
keep_prob=keep_prob,
scale=softmax_scale,
atten_mask=attn_mask_npu,
sparse_mode=SPARSE_MODE,
)[0]

return output


def npu_flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
**kwargs,
):
keep_prob = 1.0 - dropout_p

if not causal:
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num,
pse=None,
atten_mask=None,
scale=softmax_scale,
keep_prob=keep_prob,
input_layout="TND",
actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
)[0]
else:
attn_mask_npu = torch.triu(torch.ones([2048, 2048]), diagonal=1).bool().to(q.device)
head_num = q.shape[1]
output = torch_npu.npu_fusion_attention(
q,
k,
v,
head_num,
pse=None,
padding_mask=None,
atten_mask=attn_mask_npu,
scale=softmax_scale,
keep_prob=keep_prob,
input_layout="TND",
actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
sparse_mode=SPARSE_MODE,
)[0]

return output
51 changes: 50 additions & 1 deletion src/transformers/modeling_flash_attention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,68 @@
import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging
from .utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
logging,
)


logger = logging.get_logger(__name__)
flash_attn_func = None


if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.layers.rotary import apply_rotary_emb # noqa


# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
if is_torch_npu_available():
from torch_npu import npu_rotary_mul as apply_rotary_emb # noqa

from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func


if flash_attn_func:
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


def is_flash_attn_available():
"""Determine whether flash-attention can be used or not."""

# if package `flash-attn` is available, flash-attention can be used natively.
if is_flash_attn_2_available():
return True

# flash-attention can be used on Ascend NPU without package `flash-attn`
if is_torch_npu_available():
return True

return False


def flash_attn_supports_top_left_mask():
"""Determine whether flash-attention uses top-left or down-right mask"""

if is_flash_attn_2_available():
# top-left mask is used in package `flash-attn` with version lower than 2.1.0
return not is_flash_attn_greater_or_equal_2_10()

if is_torch_npu_available():
# down-right mask is used on Ascend NPU by default, set env `NPU_FA2_SPARSE_MODE=2` to activate top-left mask.
from .integrations.npu_flash_attention import is_npu_fa2_top_left_aligned_causal_mask

return is_npu_fa2_top_left_aligned_causal_mask()

return False


def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.
Expand Down
10 changes: 6 additions & 4 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2276,11 +2276,13 @@ def _check_and_enable_flash_attn_2(
install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."

if importlib.util.find_spec("flash_attn") is None:
# package `flash-attn` can not be installed on Ascend NPU, ignore related validation logic and early exit.
if is_torch_npu_available():
recommend_message_npu = "You should use attn_implementation='sdpa' instead when using NPU. "
raise ImportError(
f"{preface} the package flash_attn is not supported on Ascend NPU. {recommend_message_npu}"
)
if not hard_check_only:
config._attn_implementation = "flash_attention_2"

logger.info("Detect using FlashAttention2 on Ascend NPU.")
return config
else:
raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")

Expand Down
5 changes: 1 addition & 4 deletions src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.import_utils import (
is_causal_conv1d_available,
is_mamba_2_ssm_available,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_bamba import BambaConfig


Expand Down
7 changes: 3 additions & 4 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@
SuppressTokensLogitsProcessor,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
from ..auto import AutoModel
Expand All @@ -54,7 +53,7 @@
)


if is_flash_attn_2_available():
if is_flash_attn_available():
from ...modeling_flash_attention_utils import _flash_attention_forward


Expand Down Expand Up @@ -203,7 +202,7 @@ def __init__(self, *args, **kwargs):
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Expand Down
Loading