[transformer] support flash att by 'torch scaled dot attention' #2351
Conversation
Force-pushed from dc06095 to 03068b8
Force-pushed from 0d44db5 to 0f967f8
When att_mask is a bool mask, the current torch implementation produces NaN, so it is changed to a float mask here. With the float mask, the parts of the result where mask == 0 still differ from the custom implementation; after zeroing them out the two results match.
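A minimal sketch of that workaround (illustrative only, not the PR's exact code; the shapes and padding setup below are made up for the example):

```python
import torch
import torch.nn.functional as F

B, H, T, D = 2, 4, 16, 64
q = torch.rand(B, H, T, D)
k = torch.rand(B, H, T, D)
v = torch.rand(B, H, T, D)

# Padding-style keep-mask: True = valid frame. The second item has only 8
# valid frames, so its padded query rows are fully masked.
lengths = torch.tensor([16, 8])
valid = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)         # (B, T)
bool_mask = (valid.unsqueeze(1) & valid.unsqueeze(2)).unsqueeze(1)  # (B, 1, T, T)

# A bool attn_mask with fully-masked rows makes SDPA produce NaN for those
# rows; an additive float bias (0 = keep, finfo.min = drop) avoids that.
float_mask = (1.0 - bool_mask.float()) * torch.finfo(torch.float32).min
out = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)

# Outputs at padded query positions still differ from the custom attention,
# so zero them out before comparing the two implementations.
out = out * valid.unsqueeze(1).unsqueeze(-1).to(out.dtype)
```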
Force-pushed from cf76768 to 6f1fb43
In test_w2vbert.py, fairseq2 requires torch < 2.1.1, which conflicts with this PR, so the test is removed for now; the unit test will be added back in the upcoming w2vbert2 weight-migration PR.
Force-pushed from 803416e to 896f05c
Force-pushed from 896f05c to 41cd68f
A100, forward-only benchmark:

torch 2.1.1
  float32
    custom attention: 78.41721597802022 ms
    sdpa attention: 61.47173937250135 ms
  torch.float16
    custom attention: 29.283620386028534 ms
    sdpa attention: 16.00699585556984 ms

torch 2.2
  torch.float16
    custom attention: 29.052001880462903 ms
    sdpa attention: 13.475210270261538 ms
  torch.float16 + activation checkpoint (forward only)
    custom attention: 29.876373025612057 ms
    sdpa attention: 13.504191947530318 ms
  torch.bfloat16
    custom attention: 32.66699460367317 ms
    sdpa attention: 13.980973962767248 ms
```python
import torch.utils.benchmark as benchmark
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
import torch.utils.checkpoint as ckpt
from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask
import torch


def benchmarks():
    dtype = torch.float16
    # dtype = torch.float32
    with torch.cuda.amp.autocast(enabled=dtype is not None,
                                 dtype=dtype,
                                 cache_enabled=False):
        args = {"n_feat": 512, "n_head": 4, "dropout_rate": 0.0}
        n_blocks = 12

        # Same seed for both stacks: only the attention implementation differs.
        torch.manual_seed(777)
        mha_layers = [
            TransformerEncoderLayer(
                args['n_feat'],
                MultiHeadedAttention(use_sdpa=False, **args),
                PositionwiseFeedForward(
                    args['n_feat'],
                    2048,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ]
        torch.manual_seed(777)
        mha_layers_with_sdpa = [
            TransformerEncoderLayer(
                args['n_feat'],
                MultiHeadedAttention(use_sdpa=True, **args),
                PositionwiseFeedForward(
                    args['n_feat'],
                    2048,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ]

        use_cuda = True
        if use_cuda:
            for i in range(n_blocks):
                mha_layers[i].cuda()
                mha_layers_with_sdpa[i].cuda()
        device = torch.device('cuda' if use_cuda else 'cpu')

        q = torch.rand(10, 1000, args['n_feat']).to(device)
        k = torch.rand(10, 1000, args['n_feat']).to(device)
        v = torch.rand(10, 1000, args['n_feat']).to(device)
        input_lens = torch.tensor(
            [1000, 900, 800, 790, 600, 510, 400, 300, 100, 50])
        mask = make_non_pad_mask(input_lens).unsqueeze(1).to(device)
        att_mask = add_optional_chunk_mask(q,
                                           mask,
                                           use_dynamic_chunk=True,
                                           decoding_chunk_size=0,
                                           static_chunk_size=0,
                                           use_dynamic_left_chunk=True,
                                           num_decoding_left_chunks=-1)
        # The sdpa path takes an additive float bias instead of a bool mask.
        att_mask_bias = (1.0 - att_mask.to(dtype)) * torch.finfo(dtype).min

        # benchmark (min_run_time is in seconds)
        num_iter = 100
        timer = benchmark.Timer(
            stmt="forward(n_blocks, layers, att_mask, mask, input)",
            globals={
                'forward': forward,
                'n_blocks': n_blocks,
                'layers': mha_layers,
                'att_mask': att_mask,
                'mask': None,
                'input': q,
            })
        elapsed = timer.blocked_autorange(min_run_time=num_iter)
        print('custom attention: {} ms'.format(elapsed.mean * 1000))

        timer = benchmark.Timer(
            stmt="forward(n_blocks, layers, att_mask, mask, input)",
            globals={
                'forward': forward,
                'n_blocks': n_blocks,
                'layers': mha_layers_with_sdpa,
                'att_mask': att_mask_bias,
                'mask': None,
                'input': q,
            })
        elapsed = timer.blocked_autorange(min_run_time=num_iter)
        print('sdpa attention: {} ms'.format(elapsed.mean * 1000))


def forward(n_blocks, layers, att_mask, mask, input, enable_ckpt=False):
    if not enable_ckpt:
        for i in range(n_blocks):
            output, _, _, _ = layers[i](input, att_mask, None, mask)
            input = output
    else:
        for i in range(n_blocks):
            output, _, _, _ = ckpt.checkpoint(layers[i].__call__, input,
                                              att_mask, None, mask)
            input = output


benchmarks()
```
Forward + backward benchmark, A100:

float16
  custom attention: 79.05162322356945 ms
  sdpa attention: 47.92393183424359 ms
float16 + activation checkpoint
  custom attention: 109.72641133095908 ms
  sdpa attention: 64.0403646742925 ms
float32
  custom attention: 152.07037702202797 ms
  sdpa attention: 129.51342441523687 ms

```python
import torch.utils.benchmark as benchmark
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES
import torch.utils.checkpoint as ckpt
from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask
import torch


class DummyTransformer(torch.nn.Module):

    def __init__(self,
                 output_size,
                 linear_unit=2048,
                 n_blocks=12,
                 use_sdpa=False,
                 use_ckpt=False) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                MultiHeadedAttention(use_sdpa=use_sdpa,
                                     n_feat=512,
                                     n_head=4,
                                     dropout_rate=0.0),
                PositionwiseFeedForward(
                    output_size,
                    linear_unit,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ])
        self.use_ckpt = use_ckpt

    def forward(self, input, att_mask, mask):
        if not self.use_ckpt:
            for layer in self.layers:
                input, _, _, _ = layer(input, att_mask, None, mask)
        else:
            for layer in self.layers:
                input, _, _, _ = ckpt.checkpoint(layer.__call__, input,
                                                 att_mask, None, mask)
        return input


def benchmarks():
    dtype = torch.float32
    output_size = 512
    torch.manual_seed(777)
    use_cuda = True
    device = torch.device('cuda' if use_cuda else 'cpu')

    q = torch.rand(10, 1000, output_size,
                   requires_grad=True).to(device).to(dtype)
    k = torch.rand(10, 1000, output_size,
                   requires_grad=True).to(device).to(dtype)
    v = torch.rand(10, 1000, output_size,
                   requires_grad=True).to(device).to(dtype)
    input_lens = torch.tensor(
        [1000, 900, 800, 790, 600, 510, 400, 300, 100, 50])
    mask = make_non_pad_mask(input_lens).unsqueeze(1).to(device)
    att_mask = add_optional_chunk_mask(q,
                                       mask,
                                       use_dynamic_chunk=True,
                                       decoding_chunk_size=0,
                                       static_chunk_size=0,
                                       use_dynamic_left_chunk=True,
                                       num_decoding_left_chunks=-1)
    # The sdpa path takes an additive float bias instead of a bool mask.
    att_mask_bias = (1.0 - att_mask.to(dtype)) * torch.finfo(dtype).min

    custom_model = DummyTransformer(output_size, use_sdpa=False, use_ckpt=False)
    # Re-seed before building the sdpa model.
    torch.manual_seed(777)
    sdpa_model = DummyTransformer(output_size, use_sdpa=True, use_ckpt=False)
    if use_cuda:
        custom_model.cuda()
        sdpa_model.cuda()

    def _time(model,
              input,
              att_mask,
              mask,
              dtype,
              num_iter=10,
              name='custom attention'):
        with torch.cuda.amp.autocast(enabled=dtype is not None,
                                     dtype=dtype,
                                     cache_enabled=False):
            timer = benchmark.Timer(
                stmt="loss = model(input, att_mask, mask); loss.sum().backward()",
                globals={
                    'model': model,
                    'input': input,
                    'att_mask': att_mask,
                    'mask': mask,
                })
            elapsed = timer.blocked_autorange(min_run_time=num_iter)
            print('{}: {} ms'.format(name, elapsed.mean * 1000))

    num_iter = 100
    _time(
        custom_model,
        q,
        att_mask,
        mask,
        dtype,
        num_iter=num_iter,
    )
    # Use the float bias mask for the sdpa model, as in the forward-only script.
    _time(sdpa_model,
          q,
          att_mask_bias,
          mask,
          dtype,
          num_iter=num_iter,
          name='sdpa attention')


benchmarks()
```
#2333 (comment) + ckpt:
Are there any comparison results on an A100 with sdpa enabled?
```diff
@@ -297,6 +300,8 @@ def attention_beam_search(
             # 2.1 Forward decoder step
             hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
                 running_size, 1, 1).to(device)  # (B*N, i, i)
+            if model.decoder.use_sdpa:
+                hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
```
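For reference, `mask_to_bias` used above turns a bool keep-mask into the additive float bias expected by `scaled_dot_product_attention`; a sketch of what it amounts to (not necessarily the exact wenet implementation):

```python
import torch


def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a bool keep-mask (True = attend) into an additive bias:
    0.0 where attention is allowed, a very large negative value elsewhere."""
    assert mask.dtype == torch.bool
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min
```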
Noting this down, it may be useful in the future: pytorch/pytorch#110681
No, I don't have an A100 or a 3090 at hand.
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
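A minimal usage sketch of the linked API (illustrative; the backend-selection context manager shown is the torch 2.x `torch.backends.cuda.sdp_kernel`, which may differ across versions):

```python
import torch
import torch.nn.functional as F

# (batch, heads, time, head_dim); half precision on GPU so the flash kernel is eligible.
q = torch.rand(2, 4, 1000, 128, dtype=torch.float16, device='cuda')
k = torch.rand(2, 4, 1000, 128, dtype=torch.float16, device='cuda')
v = torch.rand(2, 4, 1000, 128, dtype=torch.float16, device='cuda')

# With no mask, torch picks the fastest available backend automatically
# (flash attention, memory-efficient attention, or the math fallback).
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

# Force the flash kernel only, to verify it is actually being used.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)
```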
TODO