Add 4-bit Adam #478
@@ -0,0 +1,148 @@
import copy
from functools import partial

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None

try:
    import lpmm
except ImportError:
    lpmm = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestQuantize(TestCase):
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

        actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

        compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True)
        actual_codes, actual_scale = compiled_f(x, qmap, 256)
        expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

        actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

        compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True)
        actual_codes, actual_scale = compiled_f(x, qmap, 256)
        expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

class TestOptim(TestCase):
    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works on CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
    def test_optim_8bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works on CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
    def test_optim_4bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        # lpmm doesn't have Adam. Use AdamW with no weight decay instead.
        if optim_name == "Adam4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0)
        elif optim_name == "AdamW4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters())
        else:
            raise ValueError(f"Unsupported {optim_name} optimizer for lpmm")
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)


if __name__ == "__main__":
    run_tests()
@@ -0,0 +1,48 @@
# Low-bit optimizers

This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507

The implementation is fully done in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.
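To illustrate the idea, here is a simplified sketch of a single-parameter Adam step that `torch.compile()` can fuse into one kernel. The function name and plain-tensor state are assumptions for illustration only; in the real optimizers, `exp_avg` and `exp_avg_sq` are quantized tensor subclasses, so their dequantize/update/re-quantize chain gets fused as well.

```python
import torch

# Illustrative sketch only, not the torchao kernels: a compiled single-tensor Adam step.
@torch.compile(fullgraph=True)
def single_param_adam(p, grad, step, exp_avg, exp_avg_sq, lr, beta1, beta2, eps):
    exp_avg.lerp_(grad, 1 - beta1)                                # m = b1*m + (1-b1)*g
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v = b2*v + (1-b2)*g*g
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```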

## Usage

This is a drop-in replacement for `torch.optim.Adam`

```python
from torchao.prototype.low_bit_optim import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change the quantization block size by passing `block_size=value` to the optimizer. By default, the block size is 2048 for 8-bit optimizers and 128 for 4-bit optimizers.
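For example (the `block_size` value shown here is just the default, written out to make the knob explicit):

```python
from torchao.prototype.low_bit_optim import Adam4bit

model = ...
optim = Adam4bit(model.parameters(), block_size=128)  # 128 is already the default for 4-bit
```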

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing the 2nd moment, as done in the original paper.
- **Known issue**: When the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever the value changes; see the sketch below.
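The toy compiled update below illustrates the issue (hypothetical code, not the optimizer internals): a Python float learning rate would be baked into the compiled graph, while a 0-dim tensor is treated as a runtime input.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
param = torch.zeros(16, device=device)
grad = torch.ones(16, device=device)
lr = torch.tensor(1e-3, device=device)  # keep lr in a tensor so its value is not a compile-time constant

@torch.compile
def sgd_step(param, grad, lr):
    param.sub_(grad * lr)

for step in range(3):
    lr.fill_(1e-3 * 0.5 ** step)  # scheduler-style update; the compiled graph is reused
    sgd_step(param, grad, lr)     # passing a fresh Python float here instead would trigger recompiles
```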

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on a 4070Ti SUPER:

Adam impl  | max memory (GB) | time taken | accuracy
-----------|-----------------|------------|---------
PyTorch    | 12.98           | 10m 08s    | 87.70
bnb 8-bit  | 8.31            | 8m 38s     | 86.22
ao 8-bit   | 8.32            | 10m 54s    | 86.67
lpmm 4-bit | 7.72            | 7m 48s     | 84.70
ao 4-bit   | 7.72            | 9m 17s     | 85.60

NOTE: time taken includes validation time, and compile time for torchao optimizers.

Reviewer: Oh, I didn't notice you're including compile times. It's customary to exclude that, since it will be amortized over more steps.

Author: I think the optimizer compile time is a few seconds or slightly more at most, so it should already be amortized over the 10-minute duration. The per-step speed after compile is indeed slower than bnb and lpmm (which use custom CUDA kernels) for larger models. I haven't dug into this yet. I'm lazy to re-run the benchmarks since it takes some time; you can help me run them if you want. I used this command.

## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and to the [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
@@ -0,0 +1,2 @@
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
Reviewer: Remind me, how come? You had some nice convergence-test charts for the 8-bit optimizer; I was hoping to see something similar for this PR. Also, did you have a theory as to the delta with lpmm?
Author: You mean why I didn't implement rank-1 normalization? I did it previously, and then removed it because (1) by default, the lpmm 4-bit optimizer doesn't use rank-1 normalization (just the usual group-wise scaling) and they don't have a fused kernel for rank-1 normalization, and (2) to keep the code simpler. Adding rank-1 normalization is a bit "hacky" and requires quite a big chunk of code (for more details, you can trace the quant logic here: https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/functional.py#L190-L203). The preliminary results I got also showed that rank-1 normalization was not better, so I removed it.

You mean the wandb charts? Yeah, I didn't do wandb logging when I ran the benchmarks this time. You can help me run them to produce the charts if you want. Just add `--project something --run_name something` to the benchmark script and it will log to wandb.

Which delta are you referring to? Speed or accuracy? In terms of accuracy, at least in the benchmark run I did, it was better. In terms of speed, I haven't looked into it much.
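For readers unfamiliar with the group-wise scaling mentioned above, here is a minimal sketch of block-wise absmax normalization. The helper name and details are made up for illustration; this is not the lpmm or torchao code.

```python
import torch

def blockwise_normalize(x: torch.Tensor, block_size: int = 128):
    # Assumes x.numel() is a multiple of block_size (padding is omitted for brevity).
    blocks = x.detach().reshape(-1, block_size)
    scale = blocks.abs().amax(dim=-1).clamp(min=1e-12)  # one absmax scale per block
    # Values now lie in [-1, 1]; each element is then mapped to the nearest entry of a signed
    # quantization map (qmap), and dequantization is simply qmap[code] * scale.
    return blocks / scale.unsqueeze(-1), scale
```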