Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention #3253

Closed
wants to merge 1 commit

Conversation

@nateanl (Member) commented on Apr 7, 2023

Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.
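
For context, here is a minimal sketch (not the actual torchaudio module code) of the kind of change this PR describes: the explicit `softmax(QK^T / sqrt(d)) V` computation inside the self-attention forward is handed off to the fused `torch.nn.functional.scaled_dot_product_attention` kernel. Tensor shapes and the additive `attention_mask` handling below are illustrative assumptions.

```python
# Illustrative sketch only; shapes and mask convention are assumptions,
# not the torchaudio implementation.
import math
import torch
import torch.nn.functional as F


def manual_attention(q, k, v, attention_mask=None, dropout_p=0.0):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if attention_mask is not None:
        weights = weights + attention_mask  # additive mask, -inf at padded positions
    weights = torch.softmax(weights, dim=-1)
    weights = F.dropout(weights, p=dropout_p)
    return torch.matmul(weights, v)


def fused_attention(q, k, v, attention_mask=None, dropout_p=0.0):
    # Same computation, delegated to the fused kernel.
    return F.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask, dropout_p=dropout_p
    )
```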

@nateanl (Member, Author) commented on Apr 7, 2023

Here are the benchmark results with the new changes. The benchmark script can be found at https://gist.github.com/nateanl/97b2f9adb39c05a4e854fbd924de01f6; a rough sketch of this kind of measurement setup follows the tables below.

SelfAttention (before this PR)
[--------------- Wav2Vec2 benchmark ----------------]
                                 |    CPU    |  CUDA
1 threads: -----------------------------------------
      wav2vec2_base 5 seconds    |   1508.9  |   9.2
      wav2vec2_base 10 seconds   |   3198.6  |  19.9
      wav2vec2_base 15 seconds   |   5323.6  |  28.0
      wav2vec2_base 20 seconds   |   7605.2  |  35.6
      wav2vec2_large 5 seconds   |   4247.2  |  20.3
      wav2vec2_large 10 seconds  |   9066.1  |  43.3
      wav2vec2_large 15 seconds  |  14862.5  |  62.6
      wav2vec2_large 20 seconds  |  21632.6  |  82.5
Times are in milliseconds (ms).


scaled_dot_product_attention (with this PR)
[--------------- Wav2Vec2 benchmark ----------------]
                                 |    CPU    |  CUDA
1 threads: -----------------------------------------
      wav2vec2_base 5 seconds    |   1430.9  |   8.3
      wav2vec2_base 10 seconds   |   3047.9  |  18.8
      wav2vec2_base 15 seconds   |   5097.1  |  25.8
      wav2vec2_base 20 seconds   |   7121.4  |  32.3
      wav2vec2_large 5 seconds   |   4091.0  |  19.2
      wav2vec2_large 10 seconds  |   8651.0  |  40.3
      wav2vec2_large 15 seconds  |  14594.4  |  57.2
      wav2vec2_large 20 seconds  |  20186.2  |  72.5
Times are in milliseconds (ms).
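
Not the gist script itself, but a rough sketch of how such per-model timings can be collected with `torch.utils.benchmark` (the pipeline bundles, durations, and CPU-only setup here are illustrative assumptions):

```python
# Sketch under assumptions: WAV2VEC2_BASE/LARGE bundles, CPU timing only.
import torch
import torchaudio
from torch.utils import benchmark

torch.set_grad_enabled(False)

results = []
for name, bundle in [
    ("wav2vec2_base", torchaudio.pipelines.WAV2VEC2_BASE),
    ("wav2vec2_large", torchaudio.pipelines.WAV2VEC2_LARGE),
]:
    model = bundle.get_model().eval()
    for seconds in (5, 10, 15, 20):
        waveform = torch.randn(1, int(seconds * bundle.sample_rate))
        timer = benchmark.Timer(
            stmt="model.extract_features(waveform)",
            globals={"model": model, "waveform": waveform},
            label="Wav2Vec2 benchmark",
            sub_label=f"{name} {seconds} seconds",
            description="CPU",
        )
        results.append(timer.blocked_autorange(min_run_time=1))

benchmark.Compare(results).print()
```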
"""

@facebook-github-bot (Contributor)

@nateanl has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@nateanl (Member, Author) commented on Apr 8, 2023

Here is the script showing that `scaled_dot_product_attention` produces outputs identical to those of SelfAttention.
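
That script is not reproduced on this page; the following is a minimal sketch of such a parity check (shapes and tolerances are illustrative), comparing an explicit attention computation against the fused kernel with `torch.testing.assert_close`:

```python
# Illustrative parity check, not the original comparison script.
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, seq_len, head_dim = 2, 12, 99, 64
q, k, v = (torch.randn(batch, num_heads, seq_len, head_dim) for _ in range(3))

# Reference: explicit softmax(QK^T / sqrt(d)) V.
weights = torch.softmax(
    torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim), dim=-1
)
reference = torch.matmul(weights, v)

# Fused kernel.
fused = F.scaled_dot_product_attention(q, k, v)

torch.testing.assert_close(fused, reference, atol=1e-5, rtol=1e-5)
print("outputs match")
```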

@mthrok (Collaborator) commented on Apr 10, 2023

Does this work with quantization? One of the reasons I did not use PyTorch's native MHA here is that it does not support quantization.

@nateanl (Member, Author) commented on Apr 10, 2023

> Does this work with quantization?

I think so. There is a quantization unit test, and it passes.
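
For reference, a hedged sketch (not the torchaudio test suite itself) of the kind of check this refers to: dynamically quantize the linear layers of a Wav2Vec2 model and confirm the quantized model still runs end to end. The bundle choice and input length are illustrative assumptions.

```python
# Sketch under assumptions; not the actual torchaudio quantization unit test.
import torch
import torchaudio

model = torchaudio.pipelines.WAV2VEC2_BASE.get_model().eval()

# Dynamic quantization of the Linear layers; the attention math itself stays
# in float, so F.scaled_dot_product_attention operates on regular float tensors.
quantized = torch.quantization.quantize_dynamic(
    model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8
)

waveform = torch.randn(1, 16000)  # one second of audio at 16 kHz
with torch.inference_mode():
    features, _ = quantized.extract_features(waveform)
print(len(features), features[0].shape)
```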

@facebook-github-bot (Contributor)

@nateanl merged this pull request in 94cc4bd.

@github-actions

Hey @nateanl.
You merged this PR, but labels were not properly added. Please add a primary and secondary label (See https://github.com/pytorch/audio/blob/main/.github/process_commit.py)

nateanl added a commit to nateanl/audio that referenced this pull request Apr 10, 2023
Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention (pytorch#3253)

Summary:
Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.

Pull Request resolved: pytorch#3253

Reviewed By: mthrok

Differential Revision: D44800353

Pulled By: nateanl

fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
nateanl added a commit that referenced this pull request Apr 11, 2023
Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention (#3253) (#3261)

Summary:
Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.

Pull Request resolved: #3253

Reviewed By: mthrok

Differential Revision: D44800353

Pulled By: nateanl

fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8