Add BF16 stochastic rounding option for optimizers #1124
Merged
Stochastic rounding for BF16 weight
BF16 only has around 3 decimal digits of precision. This means that if the weight update is smaller than roughly 1e-3 of the weight magnitude, the weight will not change at all (with round-to-nearest). This is highly problematic for full BF16 training, where we don't keep an FP32 copy of the model weights.
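For instance, here is a tiny standalone illustration (not code from this PR) of an update being lost when the FP32 result is rounded back to BF16 with round-to-nearest:

```python
import torch

w = torch.tensor(1.0, dtype=torch.bfloat16)  # BF16 weight
update = 1e-4                                # well below BF16's spacing (~0.0078) around 1.0
new_w = (w.float() + update).bfloat16()      # do the math in FP32, round-to-nearest back to BF16
print(new_w == w)                            # tensor(True): the update vanished
```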
Note that our optimizer step calculations are always done in FP32 to ensure accurate results. The "underflow" only happens when we copy the new weight value (in FP32) back into the existing BF16 weight. One way to combat this problem is to perform stochastic rounding when casting FP32->BF16.
Specifically, we round up with probability (x - round_down(x)) / (round_up(x) - round_down(x)), and round down otherwise. This makes the rounded value equal to x in expectation, so small updates still move the weight on average.
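As a rough sketch of how such a cast can be implemented (an illustrative bit-trick version, not necessarily the kernel used in this PR; the helper name below is made up):

```python
import torch

def fp32_to_bf16_stochastic(x: torch.Tensor) -> torch.Tensor:
    """Cast FP32 -> BF16 with stochastic rounding (illustrative sketch)."""
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)  # reinterpret the FP32 bits as int32
    # The low 16 bits are exactly what a truncating BF16 cast would discard.
    # Adding a uniform random value in [0, 2^16) and then truncating bumps the
    # retained bits with probability equal to the discarded fraction.
    rand = torch.randint_like(bits, 0, 1 << 16)
    bits = (bits + rand) & -65536  # -65536 == 0xFFFF0000: clear the low 16 bits
    return bits.view(torch.float32).to(torch.bfloat16)  # low bits are zero, so this cast is exact
```

In the optimizer step, the FP32 result would then be written back with something like `p.copy_(fp32_to_bf16_stochastic(p_new_fp32))`.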
All of our low-bit optimizers mentioned above also support the `bf16_stochastic_round` flag. Note that this flag only applies to BF16 weights.
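For example, usage might look like the following (the import path and optimizer name are assumptions based on torchao's low-bit optimizer prototype and may differ across versions):

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # path assumed; check your torchao version

# Full BF16 weights, no FP32 master copy
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-5, bf16_stochastic_round=True)

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optim.step()
optim.zero_grad()
```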
Experimental results

I purposely use a small LR (1e-5) to exaggerate the problem.
BF16 with stochastic rounding matches the BF16 amp loss curve, while having the same memory footprint and speed as plain full BF16 training (BF16 amp is slower due to amp overhead).