Fix failing tests on main due to torch 2.1
#26607
Conversation
For Hubert and Wav2Vec2, fx tracing starts to fail with torch 2.1.
Despite the changes in the modeling files, this part still fails.
ArthurZucker left a comment
Thanks, let's try to use unittest skip and put the torch fx proxy check in the model forward.
Instead of commenting the test out, we can add something like
self.skipTest("Skipping until we fix it")
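As a minimal sketch of this suggestion (the test class and reason string here are illustrative, not the actual Wav2Vec2/Hubert tests in the PR):

```python
import unittest

class DummyFxTest(unittest.TestCase):
    # Illustrative test case, not the real one from the PR.
    def test_torch_fx(self):
        # Skip with an explicit reason instead of commenting the test out,
        # so the skip (and why) shows up in the test report.
        self.skipTest("Skipping until we fix it")

suite = unittest.TestLoader().loadTestsFromTestCase(DummyFxTest)
result = unittest.TestResult()
suite.run(result)
```

Unlike a commented-out test, the skip and its reason are recorded in `result.skipped` and surface in CI output, so the disabled test stays visible.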
Same here
Let's do the check a lot earlier: we can check this in the model's forward only once!
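A minimal sketch of that idea, using a hypothetical toy module (the real change would live in the Wav2Vec2/Hubert forward, and Transformers has its own proxy-check utility; this only illustrates hoisting the check to the top of forward):

```python
import torch
import torch.fx
from torch import nn

class TinyEncoder(nn.Module):
    # Hypothetical toy module: does the fx-proxy check once, at the top of
    # forward, instead of repeating it inside every inner block.
    def forward(self, hidden_states, attention_mask=None):
        is_tracing = isinstance(hidden_states, torch.fx.Proxy)
        if attention_mask is not None and not is_tracing:
            # Python-level shape check: must be skipped under symbolic
            # tracing, where bool() on a proxy comparison would raise.
            if attention_mask.shape != hidden_states.shape:
                raise ValueError("attention_mask has the wrong shape")
        if attention_mask is not None:
            hidden_states = hidden_states + attention_mask
        return hidden_states

traced = torch.fx.symbolic_trace(TinyEncoder())  # tracing succeeds
out = traced(torch.ones(2, 3), torch.zeros(2, 3))
```

In eager mode the shape check still runs and raises on a mismatched mask; under tracing it is skipped once, up front, rather than guarding each block separately.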
I am not sure about this place. This block contains

attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

and it should run whether or not we are tracing with an fx proxy.
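To illustrate the point, here is a standalone sketch of that block (the helper name and shapes are made up): the view + mask addition is real computation, so it must execute in both eager mode and under fx tracing; only Python-level shape checks belong behind a proxy guard.

```python
import torch

def apply_attention_mask(attn_weights, attention_mask,
                         bsz, num_heads, tgt_len, src_len):
    # Hypothetical standalone version of the block quoted above: these two
    # lines compute the masked attention weights, so they run
    # unconditionally (eager mode and fx tracing alike).
    attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len) + attention_mask
    attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)
    return attn_weights

weights = torch.zeros(2 * 4, 5, 5)           # (bsz * num_heads, tgt_len, src_len)
mask = torch.full((2, 1, 5, 5), -10000.0)    # broadcastable additive mask
out = apply_attention_mask(weights, mask, bsz=2, num_heads=4, tgt_len=5, src_len=5)
```

The first view reshapes to (bsz, num_heads, tgt_len, src_len) so the additive mask broadcasts over heads; the second view flattens back to the (bsz * num_heads, tgt_len, src_len) layout the rest of the attention code expects.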
LysandreJik left a comment
LGTM, let's make sure we check for any performance issues and fix the TODOs quickly.
Thanks for fixing these tests so quickly!
The documentation is not available anymore as the PR was closed or merged.
ArthurZucker left a comment
Thanks
@michaelbenayoun Could you help us with the torch fx tests for wav2vec2/hubert with torch 2.1? In short, it can't, as the corresponding proxy object has no …
What does this PR do?
Fix failing tests on main due to torch 2.1.