[transformer] support flash att by 'torch scaled dot attention' #2351

Merged
8 commits merged on Feb 21, 2024
4 changes: 2 additions & 2 deletions requirements.txt
@@ -13,8 +13,8 @@ flake8-pyi==20.5.0
mccabe
pycodestyle==2.6.0
pyflakes==2.2.0
torch==2.1.2
torchaudio==2.1.2
torch>=2.1.2
torchaudio>=2.1.2
tqdm
deepspeed<0.13.0
librosa
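The pins are relaxed from exact versions to lower bounds because `torch.nn.functional.scaled_dot_product_attention` only ships with recent PyTorch releases (2.0+), and the `scale` keyword used below in `attention.py` needs 2.1+. A minimal guard along these lines (a sketch, not part of this PR) makes the requirement explicit at import time:

```python
import torch

# scaled_dot_product_attention ships with PyTorch 2.0+, and the `scale`
# keyword used in wenet/transformer/attention.py needs 2.1+; fail early
# with a clear message on older installs.
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    raise ImportError("use_sdpa=True requires torch>=2.1.2")
```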
50 changes: 0 additions & 50 deletions test/wenet/ssl/w2vbert/test_w2vbert.py

This file was deleted.

111 changes: 111 additions & 0 deletions test/wenet/transformer/test_attention.py
@@ -0,0 +1,111 @@
import torch
import pytest
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask


@pytest.mark.parametrize("args", [
{
"n_feat": 256,
"n_head": 4,
"dropout_rate": 0.0
},
{
"n_feat": 512,
"n_head": 8,
"dropout_rate": 0.0
},
{
"n_feat": 1280,
"n_head": 20,
"dropout_rate": 0.0
},
{
"n_feat": 512,
"n_head": 4,
"dropout_rate": 0.0
},
])
def test_sdpa(args):
torch.manual_seed(777)
mha_module = MultiHeadedAttention(use_sdpa=False, **args)
torch.manual_seed(777)
mha_module_with_sdpa = MultiHeadedAttention(use_sdpa=True, **args)
mha_module.eval()
mha_module_with_sdpa.eval()

q = torch.rand(10, 100, args['n_feat'])
k = torch.rand(10, 100, args['n_feat'])
v = torch.rand(10, 100, args['n_feat'])
input_lens = torch.tensor([100, 90, 80, 79, 60, 51, 40, 30, 10, 5])
mask = make_non_pad_mask(input_lens).unsqueeze(1)
att_mask = add_optional_chunk_mask(q,
mask,
use_dynamic_chunk=True,
decoding_chunk_size=0,
static_chunk_size=0,
use_dynamic_left_chunk=True,
num_decoding_left_chunks=-1)
output, cache = mha_module(q, k, v, mask=att_mask)

att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(
q, k, v, mask=att_mask_bias)
assert torch.allclose(
output * mask.transpose(1, 2),
output_with_sdpa * mask.transpose(1, 2),
atol=9e-7,
)
assert torch.allclose(cache, cache_with_sdpa)

n_blocks = 12
torch.manual_seed(777)
mha_layers = [
TransformerEncoderLayer(
args['n_feat'],
MultiHeadedAttention(use_sdpa=False, **args),
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
0.0,
normalize_before=True,
) for _ in range(n_blocks)
]

torch.manual_seed(777)
mha_layers_with_sdpa = [
TransformerEncoderLayer(
args['n_feat'],
MultiHeadedAttention(use_sdpa=True, **args),
PositionwiseFeedForward(
args['n_feat'],
2048,
0.0,
WENET_ACTIVATION_CLASSES['swish'](),
),
0.0,
normalize_before=True,
) for _ in range(n_blocks)
]

for i in range(n_blocks):
output, _, cache, _ = mha_layers[i](q, att_mask, None, mask)
output_with_sdpa, _, cache_with_sdpa, _ = mha_layers_with_sdpa[i](
q, att_mask_bias, None, mask)

assert torch.allclose(
output * mask.transpose(1, 2),
output_with_sdpa * mask.transpose(1, 2),
atol=9e-7,
rtol=9e-4,
)
# assert torch.allclose(cache, cache_with_sdpa)

q = output
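The assertions above check that the hand-rolled attention path and the SDPA path agree once padded positions are zeroed out. The same equivalence can be seen in isolation with plain tensors; this standalone sketch (not part of the PR) uses the same additive-bias trick as `att_mask_bias` above:

```python
import math
import torch

B, H, T, d = 2, 4, 16, 64
q = torch.rand(B, H, T, d)
k = torch.rand(B, H, T, d)
v = torch.rand(B, H, T, d)

# Boolean keep-mask -> additive bias: kept positions get 0, masked ones get
# the most negative float, exactly like `att_mask_bias` in the test above.
keep = torch.ones(T, T).tril().bool().view(1, 1, T, T)
bias = (1.0 - keep.float()) * torch.finfo(torch.float).min

# Manual attention: softmax(QK^T / sqrt(d) + bias) V
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d) + bias
manual = torch.softmax(scores, dim=-1) @ v

# Fused equivalent, as used by the new use_sdpa branch in attention.py.
sdpa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=bias, dropout_p=0.0, scale=1 / math.sqrt(d))

assert torch.allclose(manual, sdpa, atol=1e-6)
```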
24 changes: 21 additions & 3 deletions wenet/transformer/attention.py
@@ -36,7 +36,8 @@ def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True):
key_bias: bool = True,
use_sdpa: bool = False):
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
@@ -49,6 +50,9 @@ def __init__(self,
self.linear_out = nn.Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)

self.use_sdpa = use_sdpa
self.dropout_rate = dropout_rate

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -192,8 +196,22 @@ def forward(
# non-trivial to calculate `next_cache_start` here.
new_cache = torch.cat((k, v), dim=-1)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
return self.forward_attention(v, scores, mask), new_cache
else:
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask.unsqueeze(1),
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
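With `use_sdpa=True` the module hands the whole softmax(QK^T / sqrt(d)) V computation to `torch.nn.functional.scaled_dot_product_attention`, which dispatches to FlashAttention, memory-efficient attention, or a plain math kernel depending on device, dtype, and mask. A hedged sketch (not from this PR) of pinning dispatch to the Flash kernel with the `torch.backends.cuda.sdp_kernel` context manager from torch 2.1 (newer releases also expose `torch.nn.attention.sdpa_kernel`):

```python
import torch

if torch.cuda.is_available():
    q = k = v = torch.rand(4, 8, 128, 64, device="cuda", dtype=torch.float16)

    # Allow only the FlashAttention backend; if the inputs are not eligible
    # (for example when an arbitrary additive attn_mask is passed instead of
    # is_causal=True), SDPA raises instead of silently falling back.
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_math=False,
                                        enable_mem_efficient=False):
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True)
```

With the padded or chunked additive masks this PR passes, PyTorch will typically select the memory-efficient or math backend instead, since the torch 2.1 Flash kernel does not accept an arbitrary `attn_mask`.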
18 changes: 14 additions & 4 deletions wenet/transformer/decoder.py
@@ -26,6 +26,7 @@
WENET_ATTENTION_CLASSES,
WENET_ACTIVATION_CLASSES,
)
from wenet.utils.common import mask_to_bias
from wenet.utils.mask import (subsequent_mask, make_pad_mask)


@@ -73,6 +74,7 @@ def __init__(
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
):
super().__init__()
attention_dim = encoder_output_size
@@ -98,10 +100,10 @@ def __init__(
attention_dim,
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, key_bias),
self_attention_dropout_rate, key_bias, use_sdpa),
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
key_bias) if src_attention else None,
key_bias, use_sdpa) if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
dropout_rate, activation),
dropout_rate,
@@ -111,6 +113,7 @@ def __init__(

self.gradient_checkpointing = gradient_checkpointing
self.tie_word_embedding = tie_word_embedding
self.use_sdpa = use_sdpa

def forward(
self,
@@ -152,6 +155,10 @@ def forward(
device=tgt_mask.device).unsqueeze(0)
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
if self.use_sdpa:
tgt_mask = mask_to_bias(tgt_mask, tgt.dtype)
memory_mask = mask_to_bias(memory_mask, memory_mask.dtype)

x, _ = self.embed(tgt)
if self.gradient_checkpointing and self.training:
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
@@ -290,6 +297,7 @@ def __init__(
key_bias: bool = True,
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
):

super().__init__()
@@ -309,7 +317,8 @@ def __init__(
normalize_before,
key_bias=key_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding)
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa)

self.right_decoder = TransformerDecoder(
vocab_size,
@@ -326,7 +335,8 @@ def __init__(
normalize_before,
key_bias=key_bias,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding)
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa)

def forward(
self,
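Because the SDPA path consumes an additive float bias rather than a boolean mask, the decoder converts `tgt_mask` and `memory_mask` with `mask_to_bias` before the attention layers see them. A minimal sketch of what such a helper does, assuming the same formula as `att_mask_bias` in `test_attention.py` (the actual helper lives in `wenet.utils.common` and may differ in detail):

```python
import torch

def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Boolean keep-mask -> additive attention bias.

    True (keep) becomes 0.0 and False (mask out) becomes the most negative
    value representable in `dtype`, mirroring the
    `(1.0 - att_mask.float()) * torch.finfo(torch.float).min` expression in
    the test above. Sketch only, not the repository implementation.
    """
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min
```

With `use_sdpa` enabled, the tensor that used to be a boolean mask is now a float bias, which is also why `search.py` below has to convert `encoder_mask` and `hyps_mask` before calling the decoder.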
14 changes: 11 additions & 3 deletions wenet/transformer/encoder.py
@@ -31,6 +31,7 @@
)
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
from wenet.utils.common import mask_to_bias


class BaseEncoder(torch.nn.Module):
@@ -53,6 +54,7 @@ def __init__(
global_cmvn: torch.nn.Module = None,
use_dynamic_left_chunk: bool = False,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
):
"""
Args:
@@ -84,6 +86,7 @@ def __init__(
key_bias: whether use bias in attention.linear_k, False for whisper models.
gradient_checkpointing: rerunning a forward-pass segment for each
checkpointed segment during backward.
use_sdpa: whether to use SDPA; currently only the transformer encoder is supported
"""
super().__init__()
self._output_size = output_size
@@ -103,6 +106,7 @@ def __init__(
self.use_dynamic_chunk = use_dynamic_chunk
self.use_dynamic_left_chunk = use_dynamic_left_chunk
self.gradient_checkpointing = gradient_checkpointing
self.use_sdpa = use_sdpa

def output_size(self) -> int:
return self._output_size
@@ -149,6 +153,8 @@ def forward(
decoding_chunk_size,
self.static_chunk_size,
num_decoding_left_chunks)
if self.use_sdpa:
chunk_masks = mask_to_bias(chunk_masks, xs.dtype)
if self.gradient_checkpointing and self.training:
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
mask_pad)
@@ -355,6 +361,7 @@ def __init__(
key_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
):
""" Construct TransformerEncoder

@@ -365,15 +372,16 @@
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing)
use_dynamic_left_chunk, gradient_checkpointing,
use_sdpa)
activation = WENET_ACTIVATION_CLASSES[activation_type]()
self.encoders = torch.nn.ModuleList([
TransformerEncoderLayer(
output_size,
WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
output_size,
attention_dropout_rate,
key_bias),
key_bias, use_sdpa),
PositionwiseFeedForward(output_size, linear_units,
dropout_rate, activation),
dropout_rate, normalize_before) for _ in range(num_blocks)
@@ -433,7 +441,7 @@ def __init__(
positional_dropout_rate, attention_dropout_rate,
input_layer, pos_enc_layer_type, normalize_before,
static_chunk_size, use_dynamic_chunk, global_cmvn,
use_dynamic_left_chunk, gradient_checkpointing)
use_dynamic_left_chunk, gradient_checkpointing, False)
activation = WENET_ACTIVATION_CLASSES[activation_type]()

# self-attention module definition
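On the encoder side the flag threads from `TransformerEncoder` down to `BaseEncoder` and each self-attention layer, and the chunk mask is converted to a bias in `forward`; the conformer encoder still passes `False`, presumably because its relative-position attention is not covered yet. A hedged usage sketch (argument and return names are assumptions based on the surrounding wenet code, not spelled out in this diff):

```python
import torch
from wenet.transformer.encoder import TransformerEncoder

# Assumed signatures: the first constructor argument is the input feature
# size, and forward takes (feats, feats_lengths) and returns (output, mask).
encoder = TransformerEncoder(input_size=80, output_size=256, use_sdpa=True)
encoder.eval()

feats = torch.rand(2, 200, 80)             # (batch, frames, mel bins)
feats_lengths = torch.tensor([200, 160])
with torch.no_grad():
    out, out_mask = encoder(feats, feats_lengths)
print(out.shape)                            # (2, subsampled frames, 256)
```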
7 changes: 6 additions & 1 deletion wenet/transformer/search.py
@@ -19,7 +19,8 @@
import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens,
mask_to_bias)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
mask_finished_scores, subsequent_mask)
@@ -289,6 +290,8 @@ def attention_beam_search(
]).unsqueeze(1).to(device) # (B*N, 1)
end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
cache: Optional[List[torch.Tensor]] = None
if model.decoder.use_sdpa:
encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
# 2. Decoder forward step by step
for i in range(prefix_len, maxlen + 1):
# Stop if all batch and all beam produce eos
@@ -297,6 +300,8 @@
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
if model.decoder.use_sdpa:
hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
Reviewer comment (Member): if the decoder is a bitransformer, accessing the use_sdpa attribute here needs one more module level: model.decoder.left_decoder.use_sdpa.
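One way to address this review point (a sketch, not code from the PR): resolve the flag with a small helper so `attention_beam_search` works for both the plain `TransformerDecoder` and the bidirectional decoder, whose flag would live on `left_decoder`:

```python
import torch

def decoder_uses_sdpa(decoder: torch.nn.Module) -> bool:
    # Hypothetical helper, not part of this PR: prefer a use_sdpa attribute
    # on the decoder itself, otherwise look at left_decoder for the
    # bidirectional case the reviewer mentions.
    if hasattr(decoder, "use_sdpa"):
        return bool(decoder.use_sdpa)
    left = getattr(decoder, "left_decoder", None)
    return bool(getattr(left, "use_sdpa", False))
```

The two `model.decoder.use_sdpa` checks above could then call this helper instead.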

# logp: (B*N, vocab)
logp, cache = model.decoder.forward_one_step(encoder_out, encoder_mask,
hyps, hyps_mask, cache)