
[low-bit optim] Upcast everything to FP32 for internal calculations #1068

Merged
2 commits merged into pytorch:main from fix_optim
Oct 14, 2024

Conversation

gau-nernst (Collaborator) commented Oct 14, 2024

Fixes #1067.

Previously, it seems torch.compile did not check for dtype mismatches when a tensor subclass was involved (e.g. tensor_subclass_fp32.lerp(plain_tensor_bf16, weight)). Now it does, and raises an error. To fix it, I simply cast everything to FP32.

The dtype mismatch comes from the fact that my tensor subclasses for optim state have always used an FP32 appearance dtype, even when the param is BF16. This results in FP32 calculations, which is correct, though not originally intentional. I have now made it explicit and intentional. This also means that the BF16 param + BF16 optim state combination is now more accurate.
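For illustration, here is a minimal sketch of what "upcast everything to FP32 for internal calculations" might look like for an Adam-style update. This is not the actual torchao implementation; the names (adam_step_fp32, p, grad, exp_avg, exp_avg_sq) are hypothetical, and it assumes a low-bit state subclass re-quantizes inside copy_():

```python
import torch

def adam_step_fp32(p, grad, exp_avg, exp_avg_sq, step, lr,
                   beta1=0.9, beta2=0.999, eps=1e-8):
    # Upcast everything to FP32 so the math runs in full precision,
    # regardless of whether the param / optimizer state are BF16 or
    # low-bit tensor subclasses.
    grad_f32 = grad.float()
    exp_avg_f32 = exp_avg.float().lerp_(grad_f32, 1 - beta1)
    exp_avg_sq_f32 = exp_avg_sq.float().lerp_(grad_f32.square(), 1 - beta2)

    # Write the FP32 results back into the optimizer state; a quantized
    # subclass is assumed to re-quantize inside copy_().
    exp_avg.copy_(exp_avg_f32)
    exp_avg_sq.copy_(exp_avg_sq_f32)

    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq_f32 / bias_corr2).sqrt_().add_(eps)
    p_f32 = p.float() - lr * (exp_avg_f32 / bias_corr1) / denom

    # Only cast back to the param's dtype (e.g. BF16) at the very end.
    p.copy_(p_f32)
```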

If I have time, I will re-run some of the benchmarks to make sure things are alright.

pytorch-bot bot commented Oct 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1068

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 72ca834 with merge base e7b33bc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on Oct 14, 2024
gau-nernst (Collaborator, Author)

@msaroufim The failing low-bit optim tests pass now, but now CI times out 🤣

msaroufim (Member) commented Oct 14, 2024

You can extend it to 2h

-timeout: 60
+timeout: 120

In https://github.com/pytorch/ao/blob/main/.github/workflows/regression_test.yml

EDIT: I just made the change myself

HDCharles (Contributor) left a comment

Seems fine, though CI is still broken; see Mark's comment. Will merge once CI is passing.

msaroufim merged commit afc0a02 into pytorch:main on Oct 14, 2024
17 checks passed
gau-nernst deleted the fix_optim branch on October 14, 2024 at 22:39
jainapurva pushed a commit that referenced this pull request Oct 15, 2024
…1068)

* fix dtype

* Update regression_test.yml

---------

Co-authored-by: Mark Saroufim <[email protected]>
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
This PR makes torchchat support multi-modality model definition and construction. To showcase this multi-modality capability, we integrate the flamingo component into our system.
Note that this is only bare-minimum support for model definition. Please check the openai_api_multimodal branch for e2e, and pytorch#1123 (comment) for better structure and llama3.1 support.
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Successfully merging this pull request may close these issues:
PT nightly optim failing both on cpu and gpu

4 participants