
[Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252, #3265) #3264

Merged 2 commits into pytorch:release/2.0 on Apr 12, 2023

Conversation

nateanl (Member)

@nateanl nateanl commented Apr 11, 2023

Summary:
Fix #3219.

`torch.nn.MultiheadAttention` throws an error when an attention mask is supplied inside a `torch.no_grad()` context. This pull request fixes it by replacing the module's forward computation with `torch.nn.functional.scaled_dot_product_attention`.
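A minimal sketch of the replacement path, with toy shapes invented for illustration: `F.scaled_dot_product_attention` accepts a mask and runs fine under `torch.no_grad()`, which is the combination that tripped up `nn.MultiheadAttention` on the affected versions.

```python
import torch
import torch.nn.functional as F

# Toy shapes: (batch, num_heads, seq_len, head_dim)
batch, heads, seq, head_dim = 2, 4, 6, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Boolean attention mask: True = the position may be attended to.
attn_mask = torch.ones(seq, seq, dtype=torch.bool).tril()

# Unlike the broken nn.MultiheadAttention path, this call accepts
# a mask inside torch.no_grad() without raising.
with torch.no_grad():
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

print(out.shape)  # torch.Size([2, 4, 6, 8])
```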

Pull Request resolved: #3252

Reviewed By: mthrok

Differential Revision: D44798634

Pulled By: nateanl

fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665

@nateanl nateanl requested a review from a team April 11, 2023 01:44
@nateanl nateanl force-pushed the cherrypick-wav2vec2 branch from ce8cef6 to 614cdfe Compare April 12, 2023 12:18
nateanl added 2 commits April 12, 2023 08:23
Summary:
When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` into a single mask before calling the `scaled_dot_product_attention` function.
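One way to sketch that combination (shapes and names here are illustrative, not WavLM's exact internals): the boolean `key_padding_mask` is converted to an additive `-inf` mask, broadcast over heads and query positions, and added to the float relative-position bias so a single `attn_mask` can be passed to `scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 6, 8

# Hypothetical float bias standing in for attn_mask_rel_pos,
# shaped (batch * heads, seq, seq).
attn_mask_rel_pos = torch.randn(batch * heads, seq, seq)

# key_padding_mask: True marks padded (invalid) key positions.
key_padding_mask = torch.zeros(batch, seq, dtype=torch.bool)
key_padding_mask[0, 4:] = True  # last two frames of sample 0 are padding

# Turn the padding mask into an additive -inf mask broadcastable
# over heads and query positions, then fold it into the bias.
pad = torch.zeros(batch, 1, 1, seq)
pad = pad.masked_fill(key_padding_mask.view(batch, 1, 1, seq), float("-inf"))
combined = attn_mask_rel_pos.view(batch, heads, seq, seq) + pad

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=combined)
```

Padded key positions contribute `-inf` logits, so softmax assigns them zero weight for every head and query.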

Pull Request resolved: pytorch#3265

Reviewed By: hwangjeff

Differential Revision: D44901093

Pulled By: nateanl

fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
@nateanl nateanl force-pushed the cherrypick-wav2vec2 branch from 46b7195 to 15c0b62 Compare April 12, 2023 12:23
@nateanl nateanl changed the title [Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252) [Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252, #3265) Apr 12, 2023
@nateanl nateanl merged commit 54f6c1f into pytorch:release/2.0 Apr 12, 2023