[Low-bit optim] Improve compile time + Fix PyTorch 2.3 support for 4-bit optim #812
Static-shape compile the optim step for a single parameter + disable the cache size limit.

Since `single_param_adam()` is compiled for one parameter at a time, the shape of each compiled graph is fixed -> safe to disable the cache size limit without the risk of always re-compiling.
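As a rough illustration, here is a minimal sketch of this compile strategy: a plain Adam step standing in for torchao's actual quantized `single_param_adam()`, compiled with static shapes, after which Dynamo's cache size limit can be lifted. The function body and the limit value are illustrative assumptions, not the PR's code.

```python
import torch

def single_param_adam(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps):
    # Plain-FP32 Adam math on one parameter; in torchao the optimizer
    # states would be 4-bit quantized tensor subclasses.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)

# dynamic=False -> static-shape graphs: one graph per distinct parameter
# shape, each reused on every later step, so no recompile churn.
compiled_step = torch.compile(single_param_adam, fullgraph=True, dynamic=False)

# Because shapes are static, Dynamo's default cache size limit (8 graphs)
# can be raised safely; a very large value effectively disables it.
torch._dynamo.config.cache_size_limit = 10_000

# One compiled graph per parameter shape:
p = torch.randn(128, 64)
compiled_step(p, torch.randn_like(p), torch.zeros_like(p), torch.zeros_like(p),
              step=1, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
```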
Benefits:
- Improved compile time for the low-bit optim step.
- PyTorch 2.3 support for the 4-bit optim is fixed.

Others:
TODO: (see `benchmarks/benchmark_low_bit_adam.py`)

Llama2-7B benchmarks
Fine-tune Llama2-7B on the Alpaca dataset. PyTorch 2.4, full BF16, 1 epoch, A100, fixed random seed. Benchmarks are done with torchtune commit 52d1b838.
Eval metric: truthfulqa_mc2 acc.

NOTE: lpmm's 4-bit AdamW does not support BF16 weights.
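For completeness, a hedged usage sketch of the 4-bit optimizer under the same full-BF16 setting as the benchmark above. The import path `torchao.prototype.low_bit_optim.AdamW4bit` reflects torchao's prototype layout around this PR and is an assumption; the toy model stands in for Llama2-7B.

```python
import torch
from torchao.prototype.low_bit_optim import AdamW4bit  # assumed import path

# Toy BF16 model standing in for Llama2-7B.
model = torch.nn.Linear(4096, 4096, dtype=torch.bfloat16, device="cuda")
optim = AdamW4bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
loss = model(x).float().mean()
loss.backward()
optim.step()       # first step compiles the per-parameter update
optim.zero_grad()
```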