Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 4-bit Adam #478

Merged
merged 22 commits into from
Jul 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# pip install timm wandb tqdm datasets
# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core
#
# python benchmarks_adam_8bit.py \
#
# python benchmark_low_bit_adam.py \
# --model "timm/vit_base_patch16_224.augreg_in21k" \
# --amp bf16 \
# --optim Adam
#
# To use bnb 8-bit optimizer, set --optim Adam8bitBnb. To use 8-bit optimizer implemented in torchao, set --optim Adam8bitAo
#
# See OPTIM_MAP for the available optimizer options
# To profile and export chrome trace, set --profile
# To enable cosine learning rate scheduler, set --cosine_lr_scheduler

import argparse
import datetime
import math
from contextlib import nullcontext
from functools import partial
from pathlib import Path

import bitsandbytes as bnb
import datasets
import lpmm
import timm
import torch
import torch.nn.functional as F
Expand All @@ -25,7 +28,16 @@
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import Adam8bit
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

# Lookup table from optimizer name to the callable that constructs it.
# lpmm provides AdamW only (no plain Adam), so its entry is AdamW with
# weight_decay pre-bound to 0 and the fused implementation enabled.
OPTIM_MAP = {
    "Adam": torch.optim.Adam,
    "Adam8bitBnb": bnb.optim.Adam8bit,
    "Adam8bitAo": Adam8bit,
    "Adam4bitLpmm": partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
    "Adam4bitAo": Adam4bit,
}


class CosineSchedule:
Expand Down Expand Up @@ -72,7 +84,7 @@ def get_parser():
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--n_workers", type=int, default=4)

parser.add_argument("--optim", default="Adam")
parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--cosine_lr_scheduler", action="store_true")
Expand Down Expand Up @@ -159,16 +171,12 @@ def evaluate_model(model, args):
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=Adam8bit,
)
optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

start_time = datetime.datetime.now()
step = 0
for epoch_idx in range(args.n_epochs):
model.train()
Expand Down Expand Up @@ -208,4 +216,5 @@ def evaluate_model(model, args):
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
print(f"Time taken: {(datetime.datetime.now() - start_time)}")
print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
148 changes: 148 additions & 0 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import copy
from functools import partial

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
import bitsandbytes as bnb
except ImportError:
bnb = None

try:
import lpmm
except ImportError:
lpmm = None


# Devices the parametrized tests run on: CPU always, CUDA only when available.
_DEVICES = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


class TestQuantize(TestCase):
    """Checks for the 8-bit and 4-bit ``quantize_*_with_qmap`` functions.

    Each test quantizes a random (32, 1024) tensor against the signed qmap of
    the matching subclass module, passing 256 as the third positional argument
    (presumably the quantization block size -- confirm against the signature).
    """

    def _check_implementations_agree(self, quantize_fn, qmap_values, device):
        # ``implementation=1`` must return the same codes and scale as
        # ``implementation=0`` for the same input.
        data = torch.randn(32, 1024, device=device)
        lookup = torch.tensor(qmap_values, device=device)

        got_codes, got_scale = quantize_fn(data, lookup, 256, implementation=1)
        ref_codes, ref_scale = quantize_fn(data, lookup, 256, implementation=0)

        torch.testing.assert_close(got_codes, ref_codes)
        torch.testing.assert_close(got_scale, ref_scale)

    def _check_compiled_matches_eager(self, quantize_fn, qmap_values, device):
        # The function must compile with fullgraph=True and the compiled
        # version must return the same codes and scale as the eager call.
        data = torch.randn(32, 1024, device=device)
        lookup = torch.tensor(qmap_values, device=device)

        compiled_fn = torch.compile(quantize_fn, fullgraph=True)
        got_codes, got_scale = compiled_fn(data, lookup, 256)
        ref_codes, ref_scale = quantize_fn(data, lookup, 256)

        torch.testing.assert_close(got_codes, ref_codes)
        torch.testing.assert_close(got_scale, ref_scale)

    # 8-bit: both implementations agree.
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        self._check_implementations_agree(
            subclass_8bit.quantize_8bit_with_qmap,
            subclass_8bit.QMAP_SIGNED,
            device,
        )

    # 8-bit: compiled output matches eager output.
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        self._check_compiled_matches_eager(
            subclass_8bit.quantize_8bit_with_qmap,
            subclass_8bit.QMAP_SIGNED,
            device,
        )

    # 4-bit: both implementations agree.
    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_correctness(self, device):
        self._check_implementations_agree(
            subclass_4bit.quantize_4bit_with_qmap,
            subclass_4bit.QMAP_SIGNED,
            device,
        )

    # 4-bit: compiled output matches eager output.
    @parametrize("device", _DEVICES)
    def test_quantize_4bit_with_qmap_compile(self, device):
        self._check_compiled_matches_eager(
            subclass_4bit.quantize_4bit_with_qmap,
            subclass_4bit.QMAP_SIGNED,
            device,
        )


class TestOptim(TestCase):
    """Compare torchao's low-bit optimizers against reference libraries.

    The 8-bit optimizers are checked against bitsandbytes and the 4-bit
    optimizers against lpmm. Both tests need CUDA and are skipped when the
    reference library is not installed; they are expected to fail when
    ``TORCH_VERSION_AFTER_2_3`` is false.
    """

    @staticmethod
    def _make_model_pair(device):
        # Build a small MLP plus an identical deep copy so that two optimizers
        # start from exactly the same parameters.
        reference = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        return reference, copy.deepcopy(reference)

    def _assert_same_updates(self, model1, optim1, model2, optim2, device):
        # Run two optimization steps on both models with identical inputs,
        # then require the resulting parameters to match within tolerance.
        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            for model, optim in ((model1, optim1), (model2, optim2)):
                loss = model(x).sum()
                loss.backward()
                optim.step()
                optim.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
    def test_optim_8bit_correctness(self, optim_name):
        device = "cuda"
        model1, model2 = self._make_model_pair(device)

        # The same class name is looked up in both libraries.
        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        self._assert_same_updates(model1, optim1, model2, optim2, device)

    @pytest.mark.skipif(lpmm is None, reason="lpmm is not available")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
    def test_optim_4bit_correctness(self, optim_name):
        device = "cuda"
        model1, model2 = self._make_model_pair(device)

        # lpmm doesn't have Adam. use AdamW with no weight decay instead.
        if optim_name == "Adam4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0)
        elif optim_name == "AdamW4bit":
            optim1 = lpmm.optim.AdamW(model1.parameters())
        else:
            raise ValueError(f"Unsupported {optim_name} optimizer for lpmm")
        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

        self._assert_same_updates(model1, optim1, model2, optim2, device)


# Register the parametrized variants of every suite, in definition order.
for _suite in (TestQuantize, TestOptim):
    instantiate_parametrized_tests(_suite)
# Drop the loop variable so the module holds no extra alias of a TestCase
# class (a module-level alias could be picked up a second time by a loader
# that scans module attributes).
del _suite


if __name__ == "__main__":
    run_tests()
84 changes: 0 additions & 84 deletions test/prototype/test_optim_8bit.py

This file was deleted.

2 changes: 1 addition & 1 deletion torchao/prototype/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).

#### Roadmap

Expand Down
48 changes: 48 additions & 0 deletions torchao/prototype/low_bit_optim/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Low-bit optimizers

This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507

The implementation is done entirely in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.

## Usage

This is a drop-in replacement for `torch.optim.Adam`

```python
from torchao.prototype.low_bit_optim import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change the quantization block size by passing `block_size=value` to the optimizer. By default, the block size is 2048 for 8-bit optimizers and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
Copy link
Member

@msaroufim msaroufim Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me how come? You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR

Also did you have some theory as to the delta with lpmm

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me how come?

You mean why I didn't implement rank-1 normalization? I did it previously, and then removed it because (1) by default, lpmm 4-bit optimizer doesn't use rank-1 normalization (just group-wise scaling as usual) and they don't have fused kernel for rank-1 normalization (2) to keep the code simpler. Adding rank-1 normalization is a bit "hacky" and requires quite a big chunk of code (for more details, you can trace the quant logic here: https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/functional.py#L190-L203). And prelim results I did showed that rank-1 normalization was not better, so I removed it.

You had some nice charts for the 8bit optimizer for convergence tests, was hoping to see something similar for this PR

You mean the wandb charts? Ya I didn't do wandb logging when I ran the benchmarks this time. You can help me run to produce the charts if you want. Just add --project something --run_name something to the benchmark script and it will log to wandb.

Also did you have some theory as to the delta with lpmm

Which delta are you referring to? Speed or accuracy? In terms of accuracy, at least in the benchmark run I did, accuracy was better. In terms of speed, I haven't looked into it much.

- **Known issue**: When the learning rate is updated every step (e.g. when using a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever its value changes.

## Benchmarks

Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:

Adam impl | max memory (GB) | time taken | accuracy
-----------|-----------------|------------|----------
PyTorch | 12.98 | 10m 08s | 87.70
bnb 8-bit | 8.31 | 8m 38s | 86.22
ao 8-bit | 8.32 | 10m 54s | 86.67
lpmm 4-bit | 7.72 | 7m 48s | 84.70
ao 4-bit | 7.72 | 9m 17s | 85.60

NOTE: time taken includes validation time, and compile time for torchao optimizers.
Copy link
Member

@msaroufim msaroufim Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh didn't notice you're including compile times, it's customary to exclude that out cause it will be amortized over more steps

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the optimizer compile time is a few seconds or slightly more at most, so should already be amortized over the 10min duration. The per-step speed after compile is indeed slower than bnb and lpmm (which used custom CUDA kernels) for larger models. Haven't digged into this yet.

I'm lazy to re-run the benchmarks since it takes some time. You can help me run if you want. I used this command python benchmarks/benchmark_low_bit_adam.py --model timm/vit_huge_patch14_224.orig_in21k --amp bf16 --optim Adam4bitAo --compile --batch_size 8 --n_epochs 1 --lr 1e-5. You can use epochs=2 and change the timer to measure training time for 2nd epoch only (to remove optimizer compile time).


## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
2 changes: 2 additions & 0 deletions torchao/prototype/low_bit_optim/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
Loading
Loading