FEAT / Trainer: Add adamw 4bit optimizer #31865
Conversation
Nice! LGTM, cc @msaroufim :)
Also cc @gau-nernst, this is very exciting!
Thanks for adding!
Nice! I'll add it in a separate PR!
This reverts commit 25278e8.
Heads up @SunMarc, we just released torchao 0.4! https://github.com/pytorch/ao/releases/tag/v0.4.0
Nice! I'll merge it as soon as we merge the torchao quantization PR in transformers, as there is some overlap!
* add 4bit optimizer
* style
* fix msg
* style
* add qgalore
* Revert "add qgalore" (this reverts commit 25278e8)
* style
* version check
Hi, is it OK if I PR to add the 8-bit counterpart? (see #34893 for details) Thanks!
What does this PR do?
This PR adds the 4-bit optimizer from the torchao library to the HF Trainer. For now, it requires the main branch of `torchao` and torch >= 2.3 (maybe we can wait a bit before merging). For those who want to try it, you can pass `optim="adamw_torch_4bit"` in `TrainingArguments`. Since we already have the 8-bit optimizer from bnb that works well, I'm not adding it.
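For reference, a minimal usage sketch: only the `optim="adamw_torch_4bit"` string comes from this PR; the checkpoint name and dataset below are placeholders, and it assumes torch >= 2.3 with a recent torchao install.

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Placeholder model checkpoint; any model supported by Trainer works here.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    optim="adamw_torch_4bit",  # the new optimizer string added by this PR
)

# `my_dataset` is a placeholder for your own tokenized dataset.
trainer = Trainer(model=model, args=args, train_dataset=my_dataset)
trainer.train()
```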
Related thread: https://x.com/marksaroufim/status/1809398186198593566
cc @muellerzr as you might be interested