torchaudio.pipelines.WAVLM encountered error: Mask shape should match input #3219
Comments
The issue could be related to |
nateanl added a commit to nateanl/audio that referenced this issue on Apr 11, 2023:
Summary: Fix pytorch#3219. `torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.
Pull Request resolved: pytorch#3252
Reviewed By: mthrok
Differential Revision: D44798634
Pulled By: nateanl
fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
nateanl added a commit to nateanl/audio that referenced this issue on Apr 12, 2023:
Summary: Fix pytorch#3219. `torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.
Pull Request resolved: pytorch#3252
Reviewed By: mthrok
Differential Revision: D44798634
Pulled By: nateanl
fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
nateanl added a commit that referenced this issue on Apr 12, 2023:
Use scaled_dot_product_attention in WavLM attention (#3252, #3265) (#3264)
* Use scaled_dot_product_attention in WavLM attention (#3252)
Summary: Fix #3219. `torch.nn.MultiheadAttention` will throw an error if `torch.no_grad()` and mask are both given. The pull request fixes it by replacing the forward method with `torch.nn.functional.scaled_dot_product_attention`.
Pull Request resolved: #3252
Reviewed By: mthrok
Differential Revision: D44798634
Pulled By: nateanl
fbshipit-source-id: abfa7fb84b7bd71848a92ab26da5a5f0f095c665
* Merge key_padding_mask into attn_mask_rel_pos in WavLM (#3265)
Summary: When `key_padding_mask` is not `None`, it needs to be combined with `attn_mask_rel_pos` as one mask for the `scaled_dot_product_attention` function.
Pull Request resolved: #3265
Reviewed By: hwangjeff
Differential Revision: D44901093
Pulled By: nateanl
fbshipit-source-id: 73ca7af48faf7f4eb36b35b603187a11e5582c70
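For context, here is a minimal sketch of the approach those commits describe: computing attention with `torch.nn.functional.scaled_dot_product_attention` and folding the key padding mask into the relative-position attention mask. The function, tensor names, and shapes below are illustrative placeholders, not the actual torchaudio implementation:

```python
import torch
import torch.nn.functional as F

def wavlm_style_attention(q, k, v, attn_mask_rel_pos=None, key_padding_mask=None):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    # attn_mask_rel_pos: additive bias of shape (batch, num_heads, seq_len, seq_len) or None
    # key_padding_mask: bool tensor of shape (batch, seq_len), True marks padded positions
    batch, num_heads, seq_len, _ = q.shape
    attn_mask = attn_mask_rel_pos
    if key_padding_mask is not None:
        # Merge the padding mask into the additive attention mask, as described in #3265.
        padding = key_padding_mask.view(batch, 1, 1, seq_len).to(q.dtype) * torch.finfo(q.dtype).min
        attn_mask = padding if attn_mask is None else attn_mask + padding
    # One fused call instead of nn.MultiheadAttention's forward (#3252), so the
    # no_grad fast path and its mask-shape check are never reached.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Example with the shapes from the report: batch 32, 12 heads, 199 frames, head_dim 64.
q = k = v = torch.randn(32, 12, 199, 64)
rel_pos_bias = torch.randn(32, 12, 199, 199)
padding_mask = torch.zeros(32, 199, dtype=torch.bool)  # no padded frames in this toy example
out = wavlm_style_attention(q, k, v, rel_pos_bias, padding_mask)
print(out.shape)  # torch.Size([32, 12, 199, 64])
```

Because both mask terms are additive float biases, the padding and relative-position contributions can simply be summed into one mask before the fused attention call.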
SevKod referenced this issue in pyannote/pyannote-audio on May 15, 2023.
🐛 Describe the bug
Bug: I tried to use "torchaudio.pipelines.WAVLM_BASE_PLUS", but it fails with: "RuntimeError: Mask shape should match input. mask: [384, 199, 199] input: [32, 12, 199, 199]". This only happens when running under "with torch.no_grad():"; it seems no_grad enables the attention fast path, which then crashes.
Sample Code:
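The original sample code is not reproduced above; the following is a minimal sketch reconstructed from the traceback below (the batch size, waveform length, and use of random data are assumptions):

```python
import torch
import torchaudio

# Load the WavLM Base+ pipeline and its model.
bundle = torchaudio.pipelines.WAVLM_BASE_PLUS
ssl_model = bundle.get_model().cuda()
ssl_model.eval()

# Assumed input: 32 random waveforms of ~4 s at 16 kHz (about 199 frames, matching the error).
waveform = torch.randn(32, 64000)

# The error is only reported under torch.no_grad(), which enables the MultiheadAttention fast path.
with torch.no_grad():
    features, _ = ssl_model.extract_features(waveform.cuda())
```

Per the report, the same call succeeds when the `torch.no_grad()` context is removed.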
Error:
Traceback (most recent call last):
File "/data1/test_wavlm.py", line 9, in
features, _ = ssl_model.extract_features(waveform.cuda())
File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/model.py", line 84, in extract_features
x = self.encoder.extract_features(x, lengths, num_layers)
File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 525, in extract_features
return self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers)
File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 474, in get_intermediate_outputs
x, _ = layer(x, attention_mask) # Ignore position_bias
File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/components.py", line 405, in forward
x, position_bias = self.attention(
File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/data1//myve_torch2/lib/python3.9/site-packages/torchaudio/models/wav2vec2/wavlm_attention.py", line 185, in forward
attn_output, _ = self.attention(
File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/data1//myve_torch2/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1144, in forward
return torch._native_multi_head_attention(
RuntimeError: Mask shape should match input. mask: [384, 199, 199] input: [32, 12, 199, 199]
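For reference, the two shapes in the error are consistent with each other: 384 = 32 × 12, i.e. the mask is laid out with batch and heads flattened together, while the fast-path input keeps batch and heads as separate dimensions. A small illustrative snippet (the names are not torchaudio internals):

```python
import torch

batch_size, num_heads, seq_len = 32, 12, 199

# Mask shape reported in the error: batch and heads flattened together.
mask = torch.zeros(batch_size * num_heads, seq_len, seq_len)        # [384, 199, 199]

# Input shape reported for the fast path: batch and heads kept separate.
attn_input = torch.zeros(batch_size, num_heads, seq_len, seq_len)   # [32, 12, 199, 199]

# Same number of elements, different layout, hence the shape-check failure.
assert mask.numel() == attn_input.numel()
print(mask.shape, attn_input.shape)
```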
Versions
Versions of relevant libraries:
[pip3] numpy==1.24.1
[pip3] pytorch-triton==2.1.0+e650d3708b
[pip3] torch==2.0.0
[pip3] torchaudio==2.0.1
[pip3] torchvision==0.15.1+cu118
[pip3] triton==2.0.0