
Add BF16 stochastic rounding option for optimizers #1124

Merged

gau-nernst merged 7 commits into pytorch:main from bf16_optim_sr on Oct 23, 2024

Conversation

gau-nernst (Collaborator) commented:

Stochastic rounding for BF16 weight

BF16 only has around 3 decimal digits of precision. This means that if a weight update is smaller than roughly 1e-3 of the weight's magnitude, the weight will not change at all (under round-to-nearest). This is highly problematic for full BF16 training, where we don't keep an FP32 copy of the model weights.

Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) into the existing BF16 weight. One way to combat this problem is to perform stochastic rounding when casting FP32 -> BF16.
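A quick way to see the underflow (plain PyTorch, nothing specific to this PR):

```python
import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)
update = torch.tensor(1e-3, dtype=torch.bfloat16)
print(w + update)        # tensor(1., dtype=torch.bfloat16) -- the update is rounded away
print(w.float() + 1e-3)  # tensor(1.0010) -- an FP32 master weight would keep it
```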

  • In stochastic rounding, we round up with probability (x - round_down(x)) / (round_up(x) - round_down(x)), and round down otherwise.
  • It follows that successive weight updates with stochastic rounding correctly approximate high-precision weight updates in expectation.
  • Since BF16 is simply a truncation of FP32, there is an efficient implementation of FP32->BF16 stochastic rounding (the same is not true for FP32->FP16); see the sketch after this list.
  • A more detailed discussion can be found at https://arxiv.org/abs/2010.06192. llm.c also implements this approach.
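For intuition, here is a minimal sketch of the bit trick (illustrative only, not the exact kernel used in this PR): because the upper 16 bits of an FP32 value are exactly its BF16 representation, adding a uniformly random 16-bit integer to the low bits and then truncating reproduces the rounding probabilities described above.

```python
import torch

def fp32_to_bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
    # illustrative sketch: stochastic FP32 -> BF16 rounding via the truncation trick
    bits = x.view(torch.int32)  # reinterpret the FP32 bit pattern as int32
    # adding uniform noise in [0, 2^16) carries into the upper 16 bits with
    # probability (x - round_down(x)) / (round_up(x) - round_down(x))
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    bits = (bits + noise) & -65536  # truncate: zero the low 16 bits (mask 0xFFFF0000)
    return bits.view(torch.float32).to(torch.bfloat16)
```

In the prototype, this is exposed through the bf16_stochastic_round flag on the optimizer: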
```python
# a clone of torch.optim.AdamW with extra features
from torchao.prototype.low_bit_optim import _AdamW

model = ...
model_bf16 = model.bfloat16()
optim = _AdamW(model_bf16.parameters(), bf16_stochastic_round=True)
```

All of the low-bit optimizers mentioned above also support the bf16_stochastic_round flag. Note that this flag only applies to BF16 weights.
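For example, with the 8-bit optimizer (assuming AdamW8bit from the same prototype module; the flag is passed the same way):

```python
from torchao.prototype.low_bit_optim import AdamW8bit

# 8-bit optimizer states + stochastic rounding for the BF16 weights
optim = AdamW8bit(model_bf16.parameters(), bf16_stochastic_round=True)
```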

Experimental results

I purposely use a small LR (1e-5) to exaggerate the problem.

```shell
python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --lr 1e-5  # full BF16
python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_model --compile --lr 1e-5 --optim_kwargs '{"bf16_stochastic_round":true}'  # with stochastic rounding
python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --bf16_amp --compile --lr 1e-5  # BF16 AMP (FP32 weight)
```

[image: training loss curves for full BF16, BF16 with stochastic rounding, and BF16 AMP]

BF16 with stochastic rounding matches the BF16 AMP loss curve, while having the same memory footprint and speed as plain full-BF16 training (BF16 AMP is slower due to AMP overhead).


pytorch-bot bot commented Oct 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1124

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 935d198 with merge base 85ec209:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Oct 21, 2024
@msaroufim (Member) left a comment:

LGTM. Wanna start moving low-bit optim out of prototype? As far as I can tell you've been keeping BC guarantees. Might just need to forward-fix with Hugging Face transformers, or keep the calling functions in prototype but make them throw a warning and call the non-prototype code.

@gau-nernst (Collaborator, Author) commented:

Once CI is green, can I merge? I just changed the llm.c ref to a permalink instead of pointing to the main branch.

> wanna start moving low bit optim out of prototype

Sounds good. We can re-import the optim under torchao.prototype.low_bit_optim (and raise a warning) for BC like you suggested.

Should we also take this chance to rename the folder to just optim, since I added some non-low-bit features: CPU offload, support for tensor subclass params (for quantized training), and this BF16 stochastic rounding? 😄

@gau-nernst gau-nernst merged commit a31e15d into pytorch:main Oct 23, 2024
17 checks passed
@gau-nernst gau-nernst deleted the bf16_optim_sr branch October 23, 2024 01:26
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request on Dec 9, 2024:
add SingletonLogger with custom formatting options including file name: line_number (pytorch#1124)

* add SingletonLogger with custom formatting options
* ruff formatting
Labels: CLA Signed
3 participants