Add BF16 stochastic rounding option for optimizers #1124
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1124
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 935d198 with merge base 85ec209.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
lgtm, wanna start moving low-bit optim out of prototype? As far as I can tell you've been keeping BC guarantees. We might just need to forward-fix with huggingface transformers, or keep the calling functions in prototype but make them throw a warning and call the non-prototype code.
Once CI is green, can I merge? I just changed the llm.c ref to a permalink instead of pointing to the main branch.
Sounds good. We can re-import the optim under the prototype namespace to keep BC. Should we also take this chance to rename the folder to just `optim`?
Stochastic rounding for BF16 weight
BF16 only has around 3 decimal digits of precision. This means that if the weight update is smaller than roughly 1e-3 of the weight magnitude, there will be no change to the weight (with round-to-nearest). This is highly problematic for full BF16 training, where we don't keep an FP32 copy of the model weights.
Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) to the existing BF16 weight. One way to combat this problem is to perform stochastic rounding when casting FP32->BF16.
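To make the underflow concrete, here is a small illustrative snippet (not from this PR) that mimics a full-BF16 weight update under round-to-nearest:

```python
import torch

# Illustrative example: a BF16 weight of 1.0 receives an update of 1e-3.
# BF16 spacing around 1.0 is 2**-7 = 0.0078125, so with round-to-nearest
# the updated value rounds straight back to 1.0 and the step is lost.
w_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3, dtype=torch.float32)

# optimizer math in FP32, then copy back to the BF16 weight
new_w = (w_bf16.float() + update).bfloat16()
print(new_w)  # tensor(1., dtype=torch.bfloat16): the update is lost
```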
With stochastic rounding, we round up with probability `(x - round_down(x)) / (round_up(x) - round_down(x))`, and round down otherwise.
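As a rough illustration (a minimal sketch, not the exact kernel added in this PR), the cast can be implemented by adding uniform 16-bit noise to the low bits of the FP32 bit pattern before truncating to the upper 16 bits, which rounds up with exactly the probability above:

```python
import torch

def fp32_to_bf16_stochastic(x: torch.Tensor) -> torch.Tensor:
    """Sketch of a stochastic FP32 -> BF16 cast via bit manipulation.

    BF16 keeps the top 16 bits of the FP32 pattern. Adding uniform noise in
    [0, 2**16) to the low 16 bits and then truncating carries into the upper
    half with probability (x - round_down(x)) / (round_up(x) - round_down(x)).
    """
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    bits = (bits + noise) & -65536  # -65536 == 0xFFFF0000: keep upper 16 bits
    return bits.view(torch.float32).to(torch.bfloat16)
```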
All of our low-bit optimizers mentioned above also support the `bf16_stochastic_round` flag.
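For example (a hypothetical usage sketch; the `AdamW8bit` class name and the `torchao.prototype.low_bit_optim` import path are assumptions that may differ across torchao versions):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed import path

# Full BF16 weights, no FP32 master copy.
model = torch.nn.Linear(1024, 1024).bfloat16()

# bf16_stochastic_round only affects BF16 params; optimizer math stays in FP32.
optimizer = AdamW8bit(model.parameters(), lr=1e-5, bf16_stochastic_round=True)
```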
Note that this flag only applies to BF16 weights.

Experimental results
I purposely use a small LR (1e-5) to exaggerate the problem.
BF16 with stochastic rounding matches the BF16 AMP loss curve, while having the same memory footprint and speed as plain full BF16 (BF16 AMP is slower due to AMP overhead).