Skip to content

Commit

Permalink
Add 4-bit Adam (#478)
Browse files Browse the repository at this point in the history
* add 4bit

* rename

* simplify 4bit

* add rank1 scaling

* add lpmm to benchmark

* remove rank-1 scaling

* update

* clean

* rename

* update test

* fix

* fix

* update adam

* add AdamW 4bit

* update

* remove lpmm from dev cuz CI can't compile

* fix test

* update README

* Update README.md

* update readme. small fixes

* remove zero padding
  • Loading branch information
gau-nernst authored Jul 6, 2024
1 parent 9f85488 commit 34fedff
Show file tree
Hide file tree
Showing 12 changed files with 492 additions and 198 deletions.
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# pip install timm wandb tqdm datasets
# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git
# To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core
#
# python benchmarks_adam_8bit.py \
#
# python benchmark_low_bit_adam.py \
# --model "timm/vit_base_patch16_224.augreg_in21k" \
# --amp bf16 \
# --optim Adam
#
# To use bnb 8-bit optimizer, set --optim Adam8bitBnb. To use 8-bit optimizer implemented in torchao, set --optim Adam8bitAo
#
# See OPTIM_MAP for the available optimizer options
# To profile and export chrome trace, set --profile
# To enable cosine learning rate scheduler, set --cosine_lr_scheduler

import argparse
import datetime
import math
from contextlib import nullcontext
from functools import partial
from pathlib import Path

import bitsandbytes as bnb
import datasets
import lpmm
import timm
import torch
import torch.nn.functional as F
Expand All @@ -25,7 +28,16 @@
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import Adam8bit
from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

# lpmm doesn't have Adam, only AdamW
OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=Adam8bit,
Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
Adam4bitAo=Adam4bit,
)


class CosineSchedule:
Expand Down Expand Up @@ -72,7 +84,7 @@ def get_parser():
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--n_workers", type=int, default=4)

parser.add_argument("--optim", default="Adam")
parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys())
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--cosine_lr_scheduler", action="store_true")
Expand Down Expand Up @@ -159,16 +171,12 @@ def evaluate_model(model, args):
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=Adam8bit,
)
optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

start_time = datetime.datetime.now()
step = 0
for epoch_idx in range(args.n_epochs):
model.train()
Expand Down Expand Up @@ -208,4 +216,5 @@ def evaluate_model(model, args):
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
print(f"Time taken: {(datetime.datetime.now() - start_time)}")
print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
148 changes: 148 additions & 0 deletions test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import copy
from functools import partial

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
import bitsandbytes as bnb
except ImportError:
bnb = None

try:
import lpmm
except ImportError:
lpmm = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestQuantize(TestCase):
@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_correctness(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)

@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_compile(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device)

compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True)
actual_codes, actual_scale = compiled_f(x, qmap, 256)
expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)

@parametrize("device", _DEVICES)
def test_quantize_4bit_with_qmap_correctness(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1)
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)

@parametrize("device", _DEVICES)
def test_quantize_4bit_with_qmap_compile(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device)

compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True)
actual_codes, actual_scale = compiled_f(x, qmap, 256)
expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)


class TestOptim(TestCase):
@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_optim_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1)

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

for _ in range(2):
x = torch.randn(4, 32, device=device)

loss1 = model1(x).sum()
loss1.backward()
optim1.step()
optim1.zero_grad()

loss2 = model2(x).sum()
loss2.backward()
optim2.step()
optim2.zero_grad()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
def test_optim_4bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1)

# lpmm doesn't have Adam. use AdamW with no weight decay instead.
if optim_name == "Adam4bit":
optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0)
elif optim_name == "AdamW4bit":
optim1 = lpmm.optim.AdamW(model1.parameters())
else:
raise ValueError(f"Unsupported {optim_name} optimizer for lpmm")
optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())

for _ in range(2):
x = torch.randn(4, 32, device=device)

loss1 = model1(x).sum()
loss1.backward()
optim1.step()
optim1.zero_grad()

loss2 = model2(x).sum()
loss2.backward()
optim2.step()
optim2.zero_grad()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)


if __name__ == "__main__":
run_tests()
84 changes: 0 additions & 84 deletions test/prototype/test_optim_8bit.py

This file was deleted.

2 changes: 1 addition & 1 deletion torchao/prototype/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers).

#### Roadmap

Expand Down
48 changes: 48 additions & 0 deletions torchao/prototype/low_bit_optim/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Low-bit optimizers

This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507

The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel.

## Usage

This is a drop-in replacement for `torch.optim.Adam`

```python
from torchao.prototype.low_bit_optim import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers.

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed.

## Benchmarks

Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py).

Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER:

Adam impl | max memory (GB) | time taken | accuracy
-----------|-----------------|------------|----------
PyTorch | 12.98 | 10m 08s | 87.70
bnb 8-bit | 8.31 | 8m 38s | 86.22
ao 8-bit | 8.32 | 10m 54s | 86.67
lpmm 4-bit | 7.72 | 7m 48s | 84.70
ao 4-bit | 7.72 | 9m 17s | 85.60

NOTE: time taken includes validation time, and compile time for torchao optimizers.

## Credits

Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers.
2 changes: 2 additions & 0 deletions torchao/prototype/low_bit_optim/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
Loading

0 comments on commit 34fedff

Please sign in to comment.