
[Low-bit optim] Improve compile time + Fix PyTorch 2.3 support for 4-bit optim #812

Merged
merged 7 commits into pytorch:main from low_bit_optim_fix on Sep 5, 2024

Conversation

@gau-nernst (Collaborator) commented on Sep 5, 2024

Compile the optimizer step for a single parameter with static shapes, and disable the compile cache size limit.

  • For a given model, the number of distinct argument combinations passed to single_param_adam() is fixed, so it is safe to disable the cache limit without risking endless re-compilation (see the sketch below).
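
Below is a minimal sketch of the idea, not the exact PR code: the body of single_param_adam is simplified (the real one also handles the quantized 8-bit/4-bit/FP8 optimizer state), and the cache_size_limit value is a placeholder. `torch.compile(..., dynamic=False)` and `torch._dynamo.config.cache_size_limit` are standard PyTorch knobs.

```python
import torch

# Simplified per-parameter AdamW-style step. Illustrative only: the real
# single_param_adam also dequantizes/requantizes the low-bit optimizer state.
def single_param_adam(p, grad, step, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)             # m = beta1 * m + (1 - beta1) * g
    exp_avg_sq.lerp_(grad * grad, 1 - beta2)   # v = beta2 * v + (1 - beta2) * g^2
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)

# Compile once with static shapes: each distinct parameter shape/dtype gets its own
# specialized graph instead of one dynamic-shape graph over all parameters.
single_param_adam = torch.compile(single_param_adam, fullgraph=True, dynamic=False)

# The set of shapes seen is fixed by the model, so lifting dynamo's recompile limit
# cannot cause unbounded recompilation (the exact value here is arbitrary).
torch._dynamo.config.cache_size_limit = 10_000
```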

Benefits

  • Much faster compile for the 8-bit and FP8 optimizers, since the optim step is no longer compiled over all parameters at once: compile time is now negligible 🤯
  • Faster 4-bit optimizer thanks to static shapes: it is now on par with the 8-bit optimizer
  • Fixes PyTorch 2.3 support for the 4-bit optimizer (broken by Move more utils to TorchAOBaseTensor #784 (comment))
  • (Unintended) Fixes the unusually high memory usage of the FP8 optimizer: it now has the same memory footprint as the 8-bit optimizer


Llama2-7B benchmarks

Fine-tuning Llama2-7B on the Alpaca dataset: PyTorch 2.4, full BF16, 1 epoch, A100, fixed random seed. Benchmarks were run with torchtune 52d1b838.

| AdamW impl      | Peak memory allocated (GB) | toks/s | truthfulqa_mc2 acc |
|-----------------|----------------------------|--------|--------------------|
| Not fine-tuned  | -                          | -      | 38.95              |
| PyTorch (fused) | 51.6                       | 3200   | 42.61              |
| bnb 8-bit       | 39.3                       | 3000   | 42.75              |
| ao 8-bit        | 39.1                       | 2900   | 41.50              |
| ao 4-bit        | 33.2                       | 2900   | 42.27              |

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
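
For reference, a hedged sketch of how the ao optimizers in the table plug into a training loop. The import path follows torchao's prototype low-bit optim module around the time of this PR (it may differ in other versions), and the tiny model and hyperparameters are placeholders rather than the torchtune Llama2-7B recipe used above.

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # or AdamW4bit / AdamWFp8

# Toy stand-in for the fine-tuned model; the numbers above come from Llama2-7B via torchtune.
model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
optim = AdamW8bit(model.parameters(), lr=1e-5)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()   # placeholder loss
loss.backward()
optim.step()        # optimizer state is kept in 8 bits; the per-parameter step runs compiled
optim.zero_grad()
```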



pytorch-bot (bot) commented on Sep 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/812

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d144f42 with merge base 599319f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 5, 2024
@gau-nernst marked this pull request as draft September 5, 2024 02:43
@gau-nernst marked this pull request as ready for review September 5, 2024 09:56
@gau-nernst requested a review from msaroufim September 5, 2024 11:35
@msaroufim merged commit 1e7f132 into pytorch:main Sep 5, 2024
17 checks passed
@gau-nernst deleted the low_bit_optim_fix branch September 5, 2024 16:02
HDCharles pushed a commit that referenced this pull request Sep 9, 2024
…bit optim (#812)

* disable recompile limit

* remove _prepare_param_groups()

* re-enable FSDP test. update ViT benchmarks

* update

* update

* update readme
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* clean up unused files

* fix tests: HF TOKEN not available on-pr, add evaluation.md to tests

* markup docs

* fix evaluations.md

* add markup to native execution md

* install wget for gguf.md testing, prevent evaluation.md failures

* remove secrets from yml files

* update

* remove copy pasta from macos and macos-mps tests

* typo

* format