Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Cherry-pick] Use scaled_dot_product_attention in WavLM attention (#3252
, #3265) (#3264) * Use scaled_dot_product_attention in WavLM attention (#3252) Summary: Fix #3219. `torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`. Pull Request resolved: #3252 Reviewed By: mthrok Differential Revision: D44798634 Pulled By: nateanl fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665 * Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265) Summary: When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` as one mask for `scaled_dot_product_attention` function. Pull Request resolved: #3265 Reviewed By: hwangjeff Differential Revision: D44901093 Pulled By: nateanl fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
- Loading branch information