Use scaled_dot_product_attention in Wav2vec2/HuBERT's SelfAttention #3253
Conversation
Here are the benchmark results with the new changes. The benchmark script can be found at https://gist.github.com/nateanl/97b2f9adb39c05a4e854fbd924de01f6.
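For context, here is a minimal timing sketch of the kind of comparison such a benchmark runs. This is not the linked gist; the model factory, batch size, audio length, and iteration count are arbitrary choices made for illustration.

```python
# Timing sketch (not the gist from the comment above): time repeated forward
# passes of torchaudio's wav2vec2_base on random audio. Run it once on a build
# with the manual attention and once with the SDPA-based attention to compare.
import time
import torch
import torchaudio

model = torchaudio.models.wav2vec2_base().eval()
waveform = torch.randn(4, 16000 * 2)  # 4 utterances of ~2 s at 16 kHz (assumed shapes)

with torch.inference_mode():
    for _ in range(3):          # warm-up iterations
        model(waveform)
    start = time.perf_counter()
    for _ in range(20):         # timed iterations
        model(waveform)
    elapsed = time.perf_counter() - start

print(f"avg forward time: {elapsed / 20 * 1000:.1f} ms")
```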
@nateanl has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Here is the script which shows the …
Does this work with quantization? One of the reasons I did not use PyTorch's native MHA here is that it does not support quantization.
I think so. There is a quantization unit test and it passes.
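As a rough illustration of the kind of check such a test performs (this is not the actual torchaudio unit test, and the model factory and input shapes are assumptions): dynamically quantize the Linear layers of a wav2vec2 model and confirm the forward pass still runs with the SDPA-based attention.

```python
# Quantization smoke-test sketch (hypothetical, not torchaudio's test suite):
# dynamic quantization only replaces nn.Linear modules, so the float attention
# math, whether manual or via scaled_dot_product_attention, should be unaffected.
import torch
import torchaudio

model = torchaudio.models.wav2vec2_base().eval()
quantized = torch.ao.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

waveform = torch.randn(1, 16000)
with torch.inference_mode():
    features, _ = quantized(waveform)
print(features.shape)
```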
Hey @nateanl.
…ytorch#3253) Summary: Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency. Pull Request resolved: pytorch#3253 Reviewed By: mthrok Differential Revision: D44800353 Pulled By: nateanl fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
…3253) (#3261) Summary: Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency. Pull Request resolved: #3253 Reviewed By: mthrok Differential Revision: D44800353 Pulled By: nateanl fbshipit-source-id: 41550d868c809099aadbe812b0ebe2c38121efb8
Replace the attention computation with `torch.nn.functional.scaled_dot_product_attention` to improve running efficiency.
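A hedged sketch of the idea behind the change (simplified, not the exact torchaudio `SelfAttention` code): the explicit softmax(QK^T / sqrt(d))V computation is replaced by a single call to `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to fused kernels (e.g. FlashAttention or memory-efficient attention) when available.

```python
# Minimal before/after comparison of the two attention formulations.
# Tensor shapes follow the usual (batch, heads, time, head_dim) convention.
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, attn_mask=None):
    # original style: explicit matmul + softmax + matmul
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        scores = scores + attn_mask
    weights = torch.softmax(scores, dim=-1)
    return weights @ v

def sdpa_attention(q, k, v, attn_mask=None):
    # new style: one fused call
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

q = k = v = torch.randn(2, 12, 100, 64)
torch.testing.assert_close(
    manual_attention(q, k, v), sdpa_attention(q, k, v), rtol=1e-4, atol=1e-4
)
```

Beyond the speed gain from fused kernels, the fused call also avoids materializing the full (time x time) attention-weight matrix in eager mode, which reduces peak memory for long sequences.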