Fix dtype casting in swinv2 and swinv2sr to allow non-FP32 inference #31589
Conversation
amyeroberts
left a comment
Thanks for fixing!
Could you also add some small tests like this one for vit?
Co-authored-by: amyeroberts <[email protected]>
Added them, they pass locally

@aliencaocao Great! Could you push an empty commit with the message:
amyeroberts
left a comment
Thanks for fixing this!
@amyeroberts need your approval for slow tests

ugh it seems the specific GPU used indeed has numerical differences...

@aliencaocao As the change is simple and looks OK, updating the tests to use the results from the CI runs should be OK

How do I get the CI outputs?

@aliencaocao Good question! Indeed, they're not part of the console output. Let me see if I can ssh in

@aliencaocao Running on the runners, I get the following logits (swin2sr, swinv2)

Thanks, triggered again

@aliencaocao Thanks! All looks good - we can merge 🤗
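
The exchange above about hardware-dependent numerical differences can be sketched as follows. This is only an illustration of tolerance-based comparison with `torch.allclose`; the logits values, the size of the drift, and the helper name `logits_match` are invented for the example and are not the values obtained on the CI runners.

```python
import torch


def logits_match(observed: torch.Tensor, expected: torch.Tensor, atol: float = 1e-4) -> bool:
    """Hypothetical helper: compare a slice of logits against stored expected values."""
    return torch.allclose(observed, expected, atol=atol)


# Made-up expected values, as if recorded on a local GPU.
local_expected = torch.tensor([0.10, -0.20, 0.30])

# Made-up output from a different GPU: same model, small numerical drift.
ci_observed = local_expected + 5e-3

# The stale expectation fails because the drift exceeds the tolerance.
stale_ok = logits_match(ci_observed, local_expected)

# Refreshing the expectation from the CI output makes the check pass there.
refreshed_expected = ci_observed.clone()
refreshed_ok = logits_match(ci_observed, refreshed_expected)

print(stale_ok, refreshed_ok)
```

Recording expected values on the hardware that runs the slow tests keeps the tolerance tight; the alternative of loosening `atol` would hide the drift but also weaken the check.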
What does this PR do?
The current implementation uses `.float()` in transformers/src/transformers/models/swin2sr/modeling_swin2sr.py (lines 286 to 287 in 0f67ba1), which forces `relative_coords_table` to always be in `torch.float32`, regardless of the precision of the other weights, e.g. `torch.float16`.

This PR adds a cast to the same `dtype` as the `continuous_position_bias_mlp` layer, since `relative_coords_table` is passed directly into that layer at transformers/src/transformers/models/swin2sr/modeling_swin2sr.py (line 349 in 0f67ba1).

The same issue and fix apply to swinv2.
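
As a rough illustration of the fix described above, the sketch below reproduces the pattern outside the library. It is not the code from `modeling_swin2sr.py`: the helper names `build_relative_coords_table` and `position_bias`, the layer sizes, and the window size are all made up for the example, and `torch.bfloat16` stands in for `torch.float16` so that the snippet can also run on CPU.

```python
import torch
from torch import nn


def build_relative_coords_table(window_size: int) -> torch.Tensor:
    """Hypothetical helper: build a coordinate table that is forced to float32."""
    coords = torch.arange(-(window_size - 1), window_size, dtype=torch.int64).float()
    grid = torch.meshgrid(coords, coords, indexing="ij")
    return torch.stack(grid, dim=-1)  # shape (2w-1, 2w-1, 2), dtype float32


def position_bias(mlp: nn.Module, table: torch.Tensor) -> torch.Tensor:
    """Cast the table to the MLP's parameter dtype before the forward pass."""
    param_dtype = next(mlp.parameters()).dtype
    return mlp(table.to(param_dtype))


# Toy stand-in for a position bias MLP, converted to reduced precision.
mlp = nn.Sequential(
    nn.Linear(2, 16),
    nn.ReLU(),
    nn.Linear(16, 4, bias=False),
).to(torch.bfloat16)

table = build_relative_coords_table(window_size=4)  # stays float32
bias = position_bias(mlp, table)                    # follows the MLP's dtype

print(table.dtype, bias.dtype, tuple(bias.shape))
```

Reading the dtype from the layer's parameters, rather than hard-coding one, lets the same code path serve full-precision and reduced-precision weights. Without the cast, feeding the float32 table into the reduced-precision MLP would be expected to fail with a dtype mismatch error.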
This is a prerequisite for the image-to-image pipeline FP16 test in #31342 to pass.
Who can review?
@amyeroberts