-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[low-bit optim] Add coat for float8 optimizer #1231
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1231
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
I was thinking you can just add a flag to the current |
i have added the flag for optimstatefp8. could you verify its right? |
I think this requires a bit more work. You need to verify that you can create an optimizer with this (add test to https://github.com/pytorch/ao/blob/main/test/prototype/test_low_bit_optim.py) as well do some short training runs for sanity checks (using https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py). I think for merging the PR, we should wait for the official code release to check numeric against them. If you don't mind, we can discuss more details in GPU-MODE discord group https://discord.gg/gpumode. Just create a thread under torchao and tag me in (@gau.nernst) |
I understand the situation for merging the PR. Will be glad to work on working on this issue. creating thread in gpumode |
4c45349
to
7be5a6b
Compare
…plying condition on k
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update. The PR is coming out nicely. There are some failing CI tests. Can you fix them, including the ruff linter?
Some extra items once that is finished:
- Update doc (link to the paper + usage)
- Run benchmark for sanity check https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py. I'm thinking comparing between BF16 baseline, FP8 optimizer, and FP8 COAT optimizer. Feel free to select a benchmark config suitable for you. And add the benchmark results in this PR description. Ideally, it should show that FP8 COAT is better than FP8 (though we might not observe it)
…skip marker to within the function.
This is a Work in Progress PR for #1190.
As a draft PR, I have followed the first piece of advice by @gau-nernst of "extending OptimStateFp8". Have created a separate Dynamic Range Function Instead of creating a different quantize_fp8 method as it will be applied before quantization to achieve larger representation range of float8 datatypes and the class will be storing value k to inverse the it after dequantization.
Requirements:
TBA
Additional Code/logic Added:
TBA
Logic/Code changes to existing codebase:
TBA
Outcome:
TBA
Scope of Usage:
TBA
Example
TBA
Changes:
Benchmarks
Parameters
lr
)amp
)optim
)Results