
Add 4-bit Adam #478

Merged: 22 commits merged into pytorch:main on Jul 6, 2024

Conversation

@gau-nernst (Collaborator) commented Jul 5, 2024

From https://github.com/thu-ml/low-bit-optimizers

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:

| Adam impl  | max memory (GB) | time taken | accuracy |
|------------|-----------------|------------|----------|
| PyTorch    | 12.98           | 10m 08s    | 87.70    |
| bnb 8-bit  | 8.31            | 8m 38s     | 86.22    |
| ao 8-bit   | 8.32            | 10m 54s    | 86.67    |
| lpmm 4-bit | 7.72            | 7m 48s     | 84.70    |
| ao 4-bit   | 7.72            | 9m 17s     | 85.60    |

Sanity check:

  • Model size (FP32) = 0.630 x 4 = 2.52 GB
  • Optim state in FP32 = 2.52 x 2 = 5.04 GB
  • Optim state in 8-bit = 5.04 / 4 = 1.26 GB -> 3.78 GB savings
  • Optim state in 4-bit = 5.04 / 8 = 0.63 GB -> 4.41 GB savings
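
The same arithmetic in a few lines of Python, just for illustration (the 0.630 B parameter count comes from the table above; 4 bytes per FP32 element and two Adam states per parameter are the usual assumptions):

```python
params_billions = 0.630               # ViT-H parameter count
model_fp32_gb = params_billions * 4   # 4 bytes per FP32 param -> 2.52 GB

# Adam keeps two FP32 states per parameter (exp_avg and exp_avg_sq).
optim_fp32_gb = model_fp32_gb * 2     # 5.04 GB
optim_8bit_gb = optim_fp32_gb / 4     # 1.26 GB -> 3.78 GB saved
optim_4bit_gb = optim_fp32_gb / 8     # 0.63 GB -> 4.41 GB saved
```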

pytorch-bot (bot) commented Jul 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/478

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit cb6176e with merge base a2e8e2a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jul 5, 2024
@msaroufim requested review from msaroufim and janeyx99 on July 5, 2024 02:50
@gau-nernst marked this pull request as ready for review on July 5, 2024 12:20

NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing the 2nd moment, as originally done in the paper.
@msaroufim (Member) commented Jul 5, 2024

Remind me how come? You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR

Also did you have some theory as to the delta with lpmm

@gau-nernst (Collaborator, Author) replied

> Remind me how come?

You mean why I didn't implement rank-1 normalization? I did it previously, and then removed it because (1) by default the lpmm 4-bit optimizer doesn't use rank-1 normalization (just group-wise scaling as usual), and they don't have a fused kernel for rank-1 normalization, and (2) to keep the code simpler. Adding rank-1 normalization is a bit "hacky" and requires quite a big chunk of code (for more details, you can trace the quant logic here: https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/functional.py#L190-L203). And the preliminary results I got showed that rank-1 normalization was not better, so I removed it.
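
For context, group-wise (block-wise) absmax scaling is just per-block normalization before mapping onto the codebook; a minimal sketch (the block size of 128 and the function name are illustrative, not the actual torchao/lpmm code):

```python
import torch

def groupwise_absmax_scale(x: torch.Tensor, block_size: int = 128):
    # Assumes x.numel() is divisible by block_size; real code would pad.
    blocks = x.reshape(-1, block_size)
    # One scale per block: the largest absolute value in that block.
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # Values normalized to [-1, 1], ready to be matched against a 4-bit codebook.
    return blocks / scale, scale
```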

> You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR

You mean the wandb charts? Yeah, I didn't do wandb logging when I ran the benchmarks this time. You can help me run them to produce the charts if you want: just add `--project something --run_name something` to the benchmark script and it will log to wandb.

> Also did you have some theory as to the delta with lpmm

Which delta are you referring to, speed or accuracy? In terms of accuracy, at least in the benchmark run I did, ao 4-bit's accuracy was better. In terms of speed, I haven't looked into it much.

| Adam impl  | max memory (GB) | time taken | accuracy |
|------------|-----------------|------------|----------|
| lpmm 4-bit | 7.72            | 7m 48s     | 84.70    |
| ao 4-bit   | 7.72            | 9m 17s     | 85.60    |

NOTE: time taken includes validation time, and compile time for torchao optimizers.
@msaroufim (Member) replied Jul 5, 2024

Oh, I didn't notice you're including compile times; it's customary to exclude those since they get amortized over more steps.

@gau-nernst (Collaborator, Author) replied

I think the optimizer compile time is a few seconds or slightly more at most, so it should already be amortized over the 10-minute duration. The per-step speed after compile is indeed slower than bnb and lpmm (which use custom CUDA kernels) for larger models. Haven't dug into this yet.

I'm too lazy to re-run the benchmarks since it takes some time; you can help me run them if you want. I used this command: `python benchmarks/benchmark_low_bit_adam.py --model timm/vit_huge_patch14_224.orig_in21k --amp bf16 --optim Adam4bitAo --compile --batch_size 8 --n_epochs 1 --lr 1e-5`. You can set `--n_epochs 2` and change the timer to measure training time for the 2nd epoch only (to exclude optimizer compile time).

from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE

from .subclass_8bit import create_dynamic_map
Reviewer (Member) commented

wanna pull this out into its own file?

@gau-nernst (Collaborator, Author) replied

It's only 1 function, so I personally think keeping it like this is fine. But if you prefer moving it to a separate file (to avoid dependencies between "same-level" files), I'm OK with that too.


# GPU-friendly binary search
# https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/
elif implementation == 1:
Reviewer (Member) commented

this still blows my mind :)

gap = 256 - len(data)
for i in range(gap):
    data.append(0)
# gap = 256 - len(data)
Reviewer (Member) commented

delete?

@msaroufim (Member) commented Jul 5, 2024

The branchless binary search example kinda blew my mind, so I'm writing a short explanation for others.

So let's say we have a tensor we'd like to cast to 4 bits, and we want to search over a codebook that's not uniform, with nf4 (normalized float4) being the most popular example. In the case of a 4-bit optimizer, qmap will contain 2^4 = 16 elements, and with 0-indexing the largest index is 15:

nf4_qmap = [-1.0, -0.6961928009986877, -0.5250730514526367,
-0.39491748809814453, -0.28444138169288635, -0.18477343022823334,
-0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
0.24611230194568634, 0.33791524171829224, 0.44070982933044434,
0.5626170039176941, 0.7229568362236023, 1.0]

In this case, when quantizing a new parameter, instead of going through the entire codebook linearly you can binary search over it to find the codebook element closest to the value in the source dtype. The codebook (qmap) has 16 elements, so you'll need at most 4 binary search steps.

The key intuition is listed in this blog https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/

With the full algorithm being

input = input.view(-1)
codes = torch.where(input >= qmap[8], 8, 0)
codes += torch.where(input >= qmap[codes + 4], 4, 0)
codes += torch.where(input >= qmap[codes + 2], 2, 0)
codes += torch.where(input >= qmap[codes + 1], 1, 0)

The first step, `input = input.view(-1)`, is there because we're casting the elements of the tensor one by one, so we can treat the whole tensor as a flat vector.

Now for the steps that follow: typically in binary search the index can go up or down, but here we rewrite the search so that the index only ever increases, moving monotonically from left to right.

The second step, which is more interesting, does `codes = torch.where(input >= qmap[8], 8, 0)`. This checks whether each input value is at least the midpoint of the qmap (remember it has 16 elements); if it is, we know the index we're searching for is at least 8, so we set codes to 8. Keep in mind that this is also a vectorized operation over the whole input (otherwise we would have called it code, singular).

In the third step we continue the binary search from the previously found index 8 and try adding 4, so we're now testing index 12. In this case let's assume the input is not larger than qmap[8 + 4 = 12]; we add 0 to the index, so our codes index is still 8.

In the fourth step we search at index 8 + 0 + 2, and let's assume again that the input is too small, so we add 0 once more. The final step then checks index 8 + 0 + 0 + 1, and that's where our search ends.

Because the codebooks we binary search over are typically small (we usually quantize to 8 bits, 2^8 entries, or 4 bits, 2^4 entries), we can manually unroll the binary search into 8 or 4 steps respectively. Another way to think about this is as a binary search where the index only increases monotonically in one direction, because at every step you ask yourself whether to add 8 or 0 to the index, then 4 or 0, then 2 or 0, and finally 1 or 0. If you add them all up, 8 + 4 + 2 + 1 = 15, which is the max index.
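
Putting the pieces together, here's a small self-contained sketch of the unrolled search against the nf4 codebook above; it's illustrative (plain PyTorch, not the actual torchao kernel), and the `quantize_4bit` name is just for this example:

```python
import torch

# nf4 codebook from above, sorted ascending; 16 entries -> 4 unrolled steps.
qmap = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def quantize_4bit(x: torch.Tensor) -> torch.Tensor:
    x = x.view(-1)
    # Each step decides whether to add 8, 4, 2, then 1 to the index,
    # so the index only ever moves to the right.
    codes = torch.where(x >= qmap[8], 8, 0)
    codes += torch.where(x >= qmap[codes + 4], 4, 0)
    codes += torch.where(x >= qmap[codes + 2], 2, 0)
    codes += torch.where(x >= qmap[codes + 1], 1, 0)
    # codes is the index of the largest codebook entry <= x (0 if x < qmap[0]);
    # a real quantizer would typically add a rounding step to pick the nearer
    # of qmap[codes] and qmap[codes + 1].
    return codes

# Sanity check against torch.searchsorted, which computes the same lower bound.
x = torch.rand(1000) * 2 - 1
expected = (torch.searchsorted(qmap, x, right=True) - 1).clamp(min=0)
assert torch.equal(quantize_4bit(x), expected)
```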

@gau-nernst (Collaborator, Author) commented

Adding on about the binary search: I haven't verified it myself, but I think apart from the time-complexity improvement, it is also an improvement in memory usage, which can help with kernel fusion since less shared memory is needed. The naive approach `codes = (qmap.view(1, -1) - input.view(-1, 1)).abs().argmin(-1)` would require (input_size, codebook_size) memory, or, if we tile the input, (tile_size, codebook_size) of shared memory, while the binary search approach only needs about 2-5x tile_size of memory (including the rounding logic).
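
To make the memory point concrete, a tiny illustrative snippet (the shapes and the stand-in codebook are arbitrary, not taken from the actual kernel):

```python
import torch

x = torch.rand(4096)              # stand-in for one tile of the input
qmap = torch.linspace(-1, 1, 16)  # stand-in 4-bit codebook

# The naive nearest-entry search materializes an (input_size, codebook_size)
# matrix of differences before the argmin:
diff = (qmap.view(1, -1) - x.view(-1, 1)).abs()
print(diff.shape, diff.numel() * diff.element_size(), "bytes")  # (4096, 16), 262144 bytes
codes_naive = diff.argmin(-1)
```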

@msaroufim self-requested a review July 6, 2024 00:13
@msaroufim merged commit 34fedff into pytorch:main Jul 6, 2024
13 checks passed
@gau-nernst deleted the 4bit_adam branch July 6, 2024 00:56
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
* add 4bit

* rename

* simplify 4bit

* add rank1 scaling

* add lpmm to benchmark

* remove rank-1 scaling

* update

* clean

* rename

* update test

* fix

* fix

* update adam

* add AdamW 4bit

* update

* remove lpmm from dev cuz CI can't compile

* fix test

* update README

* Update README.md

* update readme. small fixes

* remove zero padding
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024