Skip to content

Conversation

@aliencaocao
Copy link
Contributor

@aliencaocao aliencaocao commented Jun 25, 2024

What does this PR do?

The current implementation uses .float() in

relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
which causes subsequent relative_coords_table to be always in torch.float32, not respecting whatever precision the other weights might be, e.g. torch.float16.

This PR adds a cast to the same dtype as the continuous_position_bias_mlp layer since relative_coords_table is being passed directly into the layer at

relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(

Same issue & fix for swinv2

Prerequisite for #31342 image to image pipeline FP16 test to pass.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts

@aliencaocao aliencaocao changed the title Fix dtype casting in modeling_swin2sr to allow non-FP32 inference Fix dtype casting in swinv2 and swinv2sr to allow non-FP32 inference Jun 25, 2024
Copy link
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing!

Could you also add some small tests like this one for vit?

@aliencaocao
Copy link
Contributor Author

Added them, they pass locally

@amyeroberts
Copy link
Contributor

@aliencaocao Great! Could you push an empty commit with the message: [run_slow] swin2sr, swinv2. I trust the tests are passing locally, but because of differences that can creep in because of hardware and env set-up, the logits can still be slightly different. So let's make sure the numbers match what's going to be running on the CI : )

Copy link
Contributor

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@aliencaocao
Copy link
Contributor Author

@amyeroberts need your approval for slow tests

@aliencaocao
Copy link
Contributor Author

ugh it seems the specific gpu used indeed has numerical difference...
I ran the tests and got logit using RTX 3080Ti, torch2.3.1+cu121, nvidia 555.95 driver, Windows 10

@amyeroberts
Copy link
Contributor

@aliencaocao As the change is simple, and look OK, updating the tests to use the results from the CI runs should be OK

@aliencaocao
Copy link
Contributor Author

aliencaocao commented Jun 26, 2024

How do I get the CI outputs?
Do I have to print in CI?

@amyeroberts
Copy link
Contributor

@aliencaocao Good question! Indeed, they're not part of the console output. Let me see if I can ssh in

@amyeroberts
Copy link
Contributor

@aliencaocao Running on the runners, I get the following logits

swin2sr

tensor([[0.5454, 0.5542, 0.5640],
        [0.5518, 0.5562, 0.5649],
        [0.5391, 0.5425, 0.5620]], device='cuda:0', dtype=torch.float16)

swinv2

tensor([-0.3938, -0.4290,  0.0020], device='cuda:0', dtype=torch.float16)

@aliencaocao
Copy link
Contributor Author

Thanks, triggered again

@amyeroberts
Copy link
Contributor

@aliencaocao Thanks! All looks good - we can merge 🤗

@amyeroberts amyeroberts merged commit 1f9f57a into huggingface:main Jun 26, 2024
@aliencaocao aliencaocao deleted the fix-swin2sr-dtype branch June 26, 2024 22:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants