[low-bit optim] Upcast everything to FP32 for internal calculations #1068
Fixes #1067.
Previously, it seems that torch.compile did not check for dtype mismatches when a tensor subclass was involved (e.g. `tensor_subclass_fp32.lerp(plain_tensor_bf16, weight)`). Now it does and raises an error. To fix it, I simply cast everything to FP32.

The dtype mismatch comes from the fact that my tensor subclasses for the optim state have always used an FP32 appearance dtype, even when the param is BF16. This results in FP32 calculations, which is correct, though not originally intentional. Now I have made it explicit and intentional. This also means the BF16 param + BF16 optim state combination is now more accurate.
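To illustrate the intent, here is a minimal sketch (not the PR's exact code) of what "upcast everything to FP32 for internal calculations" looks like for a single-tensor Adam step. The function name and hyperparameter arguments (`lr`, `beta1`, `beta2`, `eps`) are assumptions for illustration, not this repo's actual signature.

```python
import torch

def single_param_adam_step_sketch(p, grad, exp_avg, exp_avg_sq, step,
                                  lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Upcast everything to FP32 before the lerp/EMA math. This avoids the
    # dtype mismatch torch.compile now flags when the optim-state subclass
    # exposes FP32 while the param/grad are BF16, and it makes the FP32
    # internal arithmetic explicit.
    grad_f32 = grad.float()
    new_exp_avg = exp_avg.float().lerp(grad_f32, 1 - beta1)
    new_exp_avg_sq = exp_avg_sq.float().lerp(grad_f32.square(), 1 - beta2)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (new_exp_avg_sq / bias_correction2).sqrt() + eps
    update = (new_exp_avg / bias_correction1) / denom

    # Write back: the optim-state subclass handles (re-)quantization in
    # copy_(), while the param update is cast back to the param's own dtype
    # (e.g. BF16) only at the very end.
    exp_avg.copy_(new_exp_avg)
    exp_avg_sq.copy_(new_exp_avg_sq)
    p.add_(update.to(p.dtype), alpha=-lr)
```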
If I have time, I will re-run some of the benchmarks to make sure things are alright.