
[transformer] support flash att by 'torch scaled dot attention' #2351

Merged
8 commits merged into main from Mddct-flash-att
Feb 21, 2024

Conversation

Mddct
Collaborator

@Mddct Mddct commented Feb 19, 2024

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

[Screenshot 2024-02-19 17:41:21]

TODO

  • encoder sdpa
  • decoder sdpa
  • unit test
  • refactor att mask (support float bias; see the sketch after this list)
  • benchmark
  • aishell training with sdpa
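
A minimal sketch of the "float bias" idea behind the att-mask refactor item. The helper name mask_to_bias is borrowed from the diff further down in this PR, but treat the exact signature as illustrative; the formula is the same as the att_mask_bias line used in the benchmark scripts below.

import torch

def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # True (attend) -> 0.0, False (masked) -> a very large negative bias that
    # torch.nn.functional.scaled_dot_product_attention adds to the scores.
    return (1.0 - mask.to(dtype)) * torch.finfo(dtype).min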

@Mddct
Collaborator Author

Mddct commented Feb 19, 2024

https://github.com/wenet-e2e/wenet/blob/0f967f84dfd98f560f3a9e989ac63469a6807307/test/wenet/transformer/test_attention.py#L56-#L57

When att_mask is a bool mask, torch's current implementation produces NaN, so it is changed to a float bias here. In that case the parts of the result where mask == 0 do not match, but they do match after zeroing those positions out:

https://github.com/wenet-e2e/wenet/blob/0f967f84dfd98f560f3a9e989ac63469a6807307/test/wenet/transformer/test_attention.py#L62-#L67

ref:
pytorch/pytorch#103749
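
A self-contained sketch of the comparison described above (shapes and names are illustrative, not the actual unit test): SDPA is given the float bias, a plain masked-softmax reference is computed with the same bias, and the outputs are compared only after zeroing the positions where the mask is 0.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 8, 16
q, k, v = (torch.rand(B, H, T, D) for _ in range(3))

# padding mask: the last 3 frames of the second utterance are padding
lengths = torch.tensor([T, T - 3])
valid = torch.arange(T)[None, :] < lengths[:, None]   # (B, T), True = real frame
att_mask = (valid[:, None, :] & valid[:, :, None])[:, None, :, :]  # (B, 1, T, T)

# float bias instead of a bool mask, as done in this PR
bias = (1.0 - att_mask.float()) * torch.finfo(torch.float32).min
out_sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# reference: plain softmax attention with the same additive bias
scores = q @ k.transpose(-2, -1) / D**0.5 + bias
ref = torch.softmax(scores, dim=-1) @ v

# padded (mask == 0) query rows are not guaranteed to agree, so zero them out
keep = valid[:, None, :, None]                        # (B, 1, T, 1)
assert torch.allclose(out_sdpa * keep, ref * keep, atol=1e-5)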

@Mddct
Collaborator Author

Mddct commented Feb 19, 2024

The fairseq2 dependency in test_w2vbert.py requires torch < 2.1.1, which conflicts with this PR, so that test is removed for now; the unit test will be re-added in the upcoming w2vbert2 weight-migration PR.

@Mddct Mddct force-pushed the Mddct-flash-att branch 2 times, most recently from 803416e to 896f05c on February 19, 2024 16:34
@Mddct
Collaborator Author

Mddct commented Feb 20, 2024

A100, forward pass only:

torch-2.1.1, torch.float32
custom attention: 78.41721597802022 ms
sdpa attention: 61.47173937250135 ms

torch-2.1.1, torch.float16
custom attention: 29.283620386028534 ms
sdpa attention: 16.00699585556984 ms

torch-2.2, torch.float16
custom attention: 29.052001880462903 ms
sdpa attention: 13.475210270261538 ms

torch-2.2, torch.float16 + activation checkpoint (forward pass only)
custom attention: 29.876373025612057 ms
sdpa attention: 13.504191947530318 ms

torch-2.2, torch.bfloat16
custom attention: 32.66699460367317 ms
sdpa attention: 13.980973962767248 ms

Benchmark script:
import torch.utils.benchmark as benchmark

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

import torch.utils.checkpoint as ckpt
from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask

import torch


def benchmarks():
    dtype = torch.float16
    # dtype = torch.float32
    with torch.cuda.amp.autocast(enabled=dtype is not None,
                                 dtype=dtype,
                                 cache_enabled=False):
        args = {"n_feat": 512, "n_head": 4, "dropout_rate": 0.0}
        n_blocks = 12
        torch.manual_seed(777)
        mha_layers = [
            TransformerEncoderLayer(
                args['n_feat'],
                MultiHeadedAttention(use_sdpa=False, **args),
                PositionwiseFeedForward(
                    args['n_feat'],
                    2048,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ]

        torch.manual_seed(777)
        mha_layers_with_sdpa = [
            TransformerEncoderLayer(
                args['n_feat'],
                MultiHeadedAttention(use_sdpa=True, **args),
                PositionwiseFeedForward(
                    args['n_feat'],
                    2048,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ]
        use_cuda = True
        if use_cuda:
            for i in range(n_blocks):
                mha_layers[i].cuda()
                mha_layers_with_sdpa[i].cuda()
        device = torch.device('cuda' if use_cuda else 'cpu')
        q = torch.rand(10, 1000, args['n_feat']).to(device)
        k = torch.rand(10, 1000, args['n_feat']).to(device)
        v = torch.rand(10, 1000, args['n_feat']).to(device)
        input_lens = torch.tensor([1000, 900, 800, 790, 600, 510, 400, 300, 100, 50])
        mask = make_non_pad_mask(input_lens).unsqueeze(1).to(device)
        att_mask = add_optional_chunk_mask(q,
                                           mask,
                                           use_dynamic_chunk=True,
                                           decoding_chunk_size=0,
                                           static_chunk_size=0,
                                           use_dynamic_left_chunk=True,
                                           num_decoding_left_chunks=-1)

        att_mask_bias = (1.0 - att_mask.to(dtype)) * torch.finfo(dtype).min

        # benchmark
        num_iter = 100
        timer = benchmark.Timer(
            stmt="forward(n_blocks, layers, att_mask, mask, input)",
            globals={
                'forward': forward,
                'n_blocks': n_blocks,
                'layers': mha_layers,
                'att_mask': att_mask,
                'mask': None,
                'input': q,
            })

        elapsed = timer.blocked_autorange(min_run_time=num_iter)
        print('custom attention: {} ms'.format(elapsed.mean * 1000))

        timer = benchmark.Timer(
            stmt="forward(n_blocks, layers, att_mask, mask, input)",
            globals={
                'forward': forward,
                'n_blocks': n_blocks,
                'layers': mha_layers_with_sdpa,
                'att_mask': att_mask_bias,
                'mask': None,
                'input': q,
            })

        elapsed = timer.blocked_autorange(min_run_time=num_iter)
        print('sdpa attention: {} ms'.format(elapsed.mean * 1000))

def forward(n_blocks, layers, att_mask, mask, input, enable_ckpt=False):
    if not enable_ckpt:
        for i in range(n_blocks):
            output, _, _, _ = layers[i](input, att_mask, None, mask)
            input = output
    else:
        for i in range(n_blocks):
            output, _, _, _ = ckpt.checkpoint(layers[i].__call__, input,
                                              att_mask, None, mask)
            input = output


benchmarks()
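
As a follow-up check (not part of the script above), one way to confirm that the flash kernel is actually being picked is to restrict SDPA to that backend; on hardware, dtypes, or masks where flash attention is unavailable this raises an error instead of silently falling back. The sketch below uses torch.backends.cuda.sdp_kernel, which exists in torch 2.0-2.2 (later releases deprecate it in favor of torch.nn.attention.sdpa_kernel), and assumes a CUDA device is available.

import torch
import torch.nn.functional as F

q = k = v = torch.rand(2, 4, 1000, 64, dtype=torch.float16, device='cuda')

# Allow only the flash kernel: if this raises, the current GPU/dtype/mask
# combination cannot use flash attention and SDPA would normally fall back
# to the memory-efficient or math backend.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)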

@Mddct
Collaborator Author

Mddct commented Feb 20, 2024

Training works!
[Screenshot 2024-02-21 00:35:22]

@Mddct
Collaborator Author

Mddct commented Feb 21, 2024

Forward + backward benchmark on A100:

float16
custom attention: 79.05162322356945 ms
sdpa attention: 47.92393183424359 ms

float16 + activation ckpt
custom attention: 109.72641133095908 ms
sdpa attention: 64.0403646742925 ms

float32
custom attention: 152.07037702202797 ms
sdpa attention: 129.51342441523687 ms

Benchmark script:
import torch.utils.benchmark as benchmark

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

import torch.utils.checkpoint as ckpt
from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask

import torch


class DummyTransformer(torch.nn.Module):

    def __init__(self,
                 output_size,
                 linear_unit=2048,
                 n_blocks=12,
                 use_sdpa=False,
                 use_ckpt=False) -> None:
        super().__init__()
        self.layers = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                MultiHeadedAttention(use_sdpa=use_sdpa,
                                     n_feat=512,
                                     n_head=4,
                                     dropout_rate=0.0),
                PositionwiseFeedForward(
                    output_size,
                    linear_unit,
                    0.0,
                    WENET_ACTIVATION_CLASSES['swish'](),
                ),
                0.0,
                normalize_before=True,
            ) for _ in range(n_blocks)
        ])
        self.use_ckpt = use_ckpt

    def forward(self, input, att_mask, mask):
        output: torch.Tensor
        if not self.use_ckpt:
            for layer in self.layers:
                input, _, _, _ = layer(input, att_mask, None, mask)
        else:
            for layer in self.layers:
                input, _, _, _ = ckpt.checkpoint(layer.__call__, input,
                                                  att_mask, None, mask)
        return input


def benchmarks():
    dtype = torch.float32
    output_size = 512
    torch.manual_seed(777)
    use_cuda = True
    device = torch.device('cuda' if use_cuda else 'cpu')
    q = torch.rand(10, 1000, output_size, requires_grad=True).to(device).to(dtype)
    k = torch.rand(10, 1000, output_size, requires_grad=True).to(device).to(dtype)
    v = torch.rand(10, 1000, output_size, requires_grad=True).to(device).to(dtype)
    input_lens = torch.tensor(
        [1000, 900, 800, 790, 600, 510, 400, 300, 100, 50])
    mask = make_non_pad_mask(input_lens).unsqueeze(1).to(device)
    att_mask = add_optional_chunk_mask(q,
                                       mask,
                                       use_dynamic_chunk=True,
                                       decoding_chunk_size=0,
                                       static_chunk_size=0,
                                       use_dynamic_left_chunk=True,
                                       num_decoding_left_chunks=-1)

    att_mask_bias = (1.0 - att_mask.to(dtype)) * torch.finfo(dtype).min

    custom_model = DummyTransformer(output_size, use_sdpa=False, use_ckpt=False)
    torch.manual_seed(777)
    sdpa_model = DummyTransformer(output_size, use_sdpa=True, use_ckpt=False)
    if use_cuda:
        custom_model.cuda()
        sdpa_model.cuda()

    def _time(model,
              input,
              att_mask,
              mask,
              dtype,
              num_iter=10,
              name='custom attention'):
        with torch.cuda.amp.autocast(enabled=dtype is not None,
                                     dtype=dtype,
                                     cache_enabled=False):
            timer = benchmark.Timer(stmt="loss=model(input,att_mask, mask); loss.sum().backward()",
                                    globals={
                                        'model': model,
                                        'input': input,
                                        'att_mask': att_mask,
                                        'mask': mask,
                                    })

            elapsed = timer.blocked_autorange(min_run_time=num_iter)
            print('{}: {} ms'.format(name, elapsed.mean * 1000))

    num_iter = 100
    _time(
        custom_model,
        q,
        att_mask,
        mask,
        dtype,
        num_iter=num_iter,
    )
    _time(sdpa_model,
          q,
          att_mask,
          mask,
          dtype,
          num_iter=num_iter,
          name='sdpa attention')


benchmarks()
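
One note on the checkpointed path (a sketch independent of this PR): recent PyTorch versions warn when ckpt.checkpoint is called without use_reentrant, and the docs recommend the non-reentrant variant, so the checkpoint calls above could pass it explicitly, e.g.:

import torch
import torch.utils.checkpoint as ckpt

layer = torch.nn.Linear(512, 512)
x = torch.rand(4, 512, requires_grad=True)

# Non-reentrant checkpointing, as recommended by the PyTorch docs; newer
# releases warn when use_reentrant is left unspecified.
y = ckpt.checkpoint(layer, x, use_reentrant=False)
y.sum().backward()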

@Mddct
Collaborator Author

Mddct commented Feb 21, 2024

#2333 (comment)
Same config as that run.
[Screenshot 2024-02-21 14:56:31]
Machine: V100
att: sdpa. The V100 does not support flash attention, so sdpa automatically falls back to memory-efficient attention.
Decoding with sdpa (w) and without sdpa (w/o) both give the results below.
Time: 13h46min
Performance: 5.68/5.38/5.89/5.89

+ckpt:
Time: 15h27min
Performance: 5.66/5.30/5.88/5.88

@xingchensong xingchensong merged commit 935250b into main Feb 21, 2024
6 checks passed
@xingchensong xingchensong deleted the Mddct-flash-att branch February 21, 2024 07:53
@fanlu
Member

fanlu commented Feb 22, 2024

(Quoting the V100 result comment above: same config as #2333, machine V100, sdpa falling back to memory-efficient attention; 13h46min, 5.68/5.38/5.89/5.89; +ckpt: 15h27min, 5.66/5.30/5.88/5.88.)

Are there comparison results on an A100, which does support flash attention through sdpa?

@Mddct Mddct mentioned this pull request Feb 22, 2024
@@ -297,6 +300,8 @@ def attention_beam_search(
         # 2.1 Forward decoder step
         hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
             running_size, 1, 1).to(device)  # (B*N, i, i)
+        if model.decoder.use_sdpa:
+            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
Member


If the decoder is a bitransformer, accessing the use_sdpa attribute here needs to go one module deeper: model.decoder.left_decoder.use_sdpa.
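
A small sketch of the lookup this comment points at (the helper name decoder_uses_sdpa is hypothetical; left_decoder is the attribute named in the comment):

def decoder_uses_sdpa(decoder) -> bool:
    # The flag sits either directly on the decoder or, for a bitransformer
    # decoder, one module deeper on left_decoder.
    if hasattr(decoder, 'use_sdpa'):
        return decoder.use_sdpa
    return decoder.left_decoder.use_sdpa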

@xingchensong
Member

Noting this down, it may be useful in the future: pytorch/pytorch#110681

@Mddct
Collaborator Author

Mddct commented Mar 21, 2024

(Quoting the question above about comparison results on an A100 that supports sdpa.)

No, I don't have an A100 or a 3090 at hand.

@xingchensong
Member

How to use:

[screenshot of usage instructions]
