-
Notifications
You must be signed in to change notification settings - Fork 31.8k
fix(Wav2Vec2ForCTC): torch export #34023
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Resolves the issue described in huggingface#34022 by implementing the masking of the hidden states using an elementwise multiplication rather than indexing with assignment. The torch.export functionality seems to mark the tensor as frozen even though the update is legal. This change is a workaround for now to allow the export of the model as a FxGraph. Further investigation is required to find the real solution in pytorch.
2f745d4 to
fc9d68b
Compare
ylacombe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM, but I'd like another opinion for a torch compile expert as well!
Maybe @gante or @zucchini-nlp ? Thanks for your help
|
Seems okey to me, compile stuff usually complain on in-place modification to tensor |
|
Thanks @zucchini-nlp, requesting a core maintainer review then ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's run slow tests just in case. You can do it by pushing an empty commit like this: git commit --allow-empty -m "[run-slow] hubert, unispeech, unispeech_sat, wav2vec2"
|
@ylacombe @zucchini-nlp these errors seem unrelated to the slow tests. Is this expected? |
LysandreJik
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, thanks! @ylacombe feel free to merge once you're satisfied with the tests
|
LGTM 👍 (I've rerun CI in case it was a transient error) |
* fix(Wav2Vec2ForCTC): torch export Resolves the issue described in huggingface#34022 by implementing the masking of the hidden states using an elementwise multiplication rather than indexing with assignment. The torch.export functionality seems to mark the tensor as frozen even though the update is legal. This change is a workaround for now to allow the export of the model as a FxGraph. Further investigation is required to find the real solution in pytorch. * [run-slow] hubert, unispeech, unispeech_sat, wav2vec2
What does this PR do?
Fixes #34022 by implementing the masking of the hidden states using an elementwise multiplication rather than indexing with assignment.
The torch.export functionality seems to mark the tensor as frozen even though the update is legal.
This change is a workaround for now to allow the export of the model as a FxGraph. Further investigation is required to find the real solution in pytorch.
Tagging:
@ylacombe, @eustlb
Please let me know if someone else is more appropriate to review this PR.