-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add 4-bit Adam #478
Add 4-bit Adam #478
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/478
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit cb6176e with merge base a2e8e2a (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
||
NOTE: | ||
- The low-bit optimizers require PyTorch >= 2.3 | ||
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remind me how come? You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR
Also did you have some theory as to the delta with lpmm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remind me how come?
You mean why I didn't implement rank-1 normalization? I did it previously, and then removed it because (1) by default, lpmm 4-bit optimizer doesn't use rank-1 normalization (just group-wise scaling as usual) and they don't have fused kernel for rank-1 normalization (2) to keep the code simpler. Adding rank-1 normalization is a bit "hacky" and requires quite a big chunk of code (for more details, you can trace the quant logic here: https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/functional.py#L190-L203). And prelim results I did showed that rank-1 normalization was not better, so I removed it.
You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR
You mean the wandb charts? Ya I didn't do wandb logging when I ran the benchmarks this time. You can help me run to produce the charts if you want. Just add --project something --run_name something
to the benchmark script and it will log to wandb.
Also did you have some theory as to the delta with lpmm
Which delta are you referring to? Speed or accuracy? In terms of accuracy, at least in the benchmark run I did, accuracy was better. In terms of speed, I haven't looked into it much.
lpmm 4-bit | 7.72 | 7m 48s | 84.70 | ||
ao 4-bit | 7.72 | 9m 17s | 85.60 | ||
|
||
NOTE: time taken includes validation time, and compile time for torchao optimizers. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh didn't notice you're including compile times, it's customary to exclude that out cause it will be amortized over more steps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the optimizer compile time is a few seconds or slightly more at most, so should already be amortized over the 10min duration. The per-step speed after compile is indeed slower than bnb and lpmm (which used custom CUDA kernels) for larger models. Haven't digged into this yet.
I'm lazy to re-run the benchmarks since it takes some time. You can help me run if you want. I used this command python benchmarks/benchmark_low_bit_adam.py --model timm/vit_huge_patch14_224.orig_in21k --amp bf16 --optim Adam4bitAo --compile --batch_size 8 --n_epochs 1 --lr 1e-5
. You can use epochs=2 and change the timer to measure training time for 2nd epoch only (to remove optimizer compile time).
from torch import Tensor | ||
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE | ||
|
||
from .subclass_8bit import create_dynamic_map |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wanna pull this out into its own file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's only 1 function so I personally think keeping it like this is fine. But if you prefer moving it to a separate file (to avoid dependency of "same-level" files), I'm ok too.
|
||
# GPU-friendly binary search | ||
# https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ | ||
elif implementation == 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this still blows my mind :)
gap = 256 - len(data) | ||
for i in range(gap): | ||
data.append(0) | ||
# gap = 256 - len(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete?
The branchless binary search example kinda blew my mind so writing a short explanation for others So let's say we have a tensor and we'd like to cast it to 4 bit and want to search over a codebook that's not uniform with nf4 (normalized float4) being the most popular example. qmap will in the case of a 4 bit optimizer contain 2^4 = 16 elements and if you 0 index the largest index is 15 nf4_qmap = [-1.0, -0.6961928009986877, -0.5250730514526367,
-0.39491748809814453, -0.28444138169288635, -0.18477343022823334,
-0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
0.24611230194568634, 0.33791524171829224, 0.44070982933044434,
0.5626170039176941, 0.7229568362236023, 1.0] In this case when quantizing a new parameter instead of going through the entire codebook linearly you can binary search over it to find the closest element from the source dtype. The codebook or qmap has 16 elements and so you'll need at most 4 binary search steps The key intuition is listed in this blog https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ With the full algorithm being input = input.view(-1)
codes = torch.where(input >= qmap[8], 8, 0)
codes += torch.where(input >= qmap[codes + 4], 4, 0)
codes += torch.where(input >= qmap[codes + 2], 2, 0)
codes += torch.where(input >= qmap[codes + 1], 1, 0) The first step Now for the steps that follow, typically in binary search the index value can go up or down but in this case we will rewrite binary search such that the index increases monotonically in 1 direction from left to right. The second step which is more interesting does The third step we continue the binary search starting from the previously found index But now we go to the third step and and now we search at index When we're doing a binary search over codebooks that are typically small because we usually quantize to 8 bit (2^8) or 4 bit (2^4) then we can manually unroll the binary search algorithm into 8 or 4 steps respectively. Another way to think of about this is this a binary search where the value of the index increases monotonically in 1 direction because at every step you ask yourself do I add 8 or 0 to the index and then 4 or 0 and then 2 or 0 and finally 1 or 0. If you add them all up |
Adding on about the binary search. I haven't verified myself, but I think apart from time complexity improvement, it is also an improvement in memory usage, which can help with kernel fusion as less shared memory is needed. The naive approach |
* add 4bit * rename * simplify 4bit * add rank1 scaling * add lpmm to benchmark * remove rank-1 scaling * update * clean * rename * update test * fix * fix * update adam * add AdamW 4bit * update * remove lpmm from dev cuz CI can't compile * fix test * update README * Update README.md * update readme. small fixes * remove zero padding
From https://github.com/thu-ml/low-bit-optimizers
Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:
Sanity check: