diff --git a/benchmarks/benchmark_adam_8bit.py b/benchmarks/benchmark_low_bit_adam.py similarity index 89% rename from benchmarks/benchmark_adam_8bit.py rename to benchmarks/benchmark_low_bit_adam.py index 0eee275f0c..63fdeff3dd 100644 --- a/benchmarks/benchmark_adam_8bit.py +++ b/benchmarks/benchmark_low_bit_adam.py @@ -1,22 +1,25 @@ -# pip install timm wandb tqdm datasets +# pip install timm wandb tqdm datasets yacs bitsandbytes git+https://github.com/thu-ml/low-bit-optimizers.git # To fine-tune a pre-trained ViT-Base on resisc45 dataset with BF16 AMP, using default Adam optimizer from PyTorch core -# -# python benchmarks_adam_8bit.py \ +# +# python benchmark_low_bit_adam.py \ # --model "timm/vit_base_patch16_224.augreg_in21k" \ # --amp bf16 \ # --optim Adam -# -# To use bnb 8-bit optimizer, set --optim Adam8bitBnb. To use 8-bit optimizer implemented in torchao, set --optim Adam8bitAo +# +# See OPTIM_MAP for the available optimizer options # To profile and export chrome trace, set --profile # To enable cosine learning rate scheduler, set --cosine_lr_scheduler import argparse +import datetime import math from contextlib import nullcontext +from functools import partial from pathlib import Path import bitsandbytes as bnb import datasets +import lpmm import timm import torch import torch.nn.functional as F @@ -25,7 +28,16 @@ from torchvision.transforms import v2 from tqdm import tqdm -from torchao.prototype.optim_8bit import Adam8bit +from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit + +# lpmm doesn't have Adam, only AdamW +OPTIM_MAP = dict( + Adam=torch.optim.Adam, + Adam8bitBnb=bnb.optim.Adam8bit, + Adam8bitAo=Adam8bit, + Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True), + Adam4bitAo=Adam4bit, +) class CosineSchedule: @@ -72,7 +84,7 @@ def get_parser(): parser.add_argument("--batch_size", type=int, default=64) parser.add_argument("--n_workers", type=int, default=4) - parser.add_argument("--optim", default="Adam") + parser.add_argument("--optim", default="Adam", choices=OPTIM_MAP.keys()) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--cosine_lr_scheduler", action="store_true") @@ -159,16 +171,12 @@ def evaluate_model(model, args): model.compile(fullgraph=True) print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") - OPTIM_MAP = dict( - Adam=torch.optim.Adam, - Adam8bitBnb=bnb.optim.Adam8bit, - Adam8bitAo=Adam8bit, - ) optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay) lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs) grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16") + start_time = datetime.datetime.now() step = 0 for epoch_idx in range(args.n_epochs): model.train() @@ -208,4 +216,5 @@ def evaluate_model(model, args): print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}") logger.log(dict(val_acc=val_acc), step=step) - print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB") + print(f"Time taken: {(datetime.datetime.now() - start_time)}") + print(f"Max used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB") diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py new file mode 100644 index 0000000000..c995d83c85 --- /dev/null +++ b/test/prototype/test_low_bit_optim.py @@ -0,0 +1,148 @@ +import copy +from functools import partial + +import pytest +import torch +from torch import nn 
+from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torchao.prototype import low_bit_optim +from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit +from torchao.utils import TORCH_VERSION_AFTER_2_3 + +try: + import bitsandbytes as bnb +except ImportError: + bnb = None + +try: + import lpmm +except ImportError: + lpmm = None + + +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestQuantize(TestCase): + @parametrize("device", _DEVICES) + def test_quantize_8bit_with_qmap_correctness(self, device): + x = torch.randn(32, 1024, device=device) + qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) + + actual_codes, actual_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=1) + expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256, implementation=0) + + torch.testing.assert_close(actual_codes, expected_codes) + torch.testing.assert_close(actual_scale, expected_scale) + + @parametrize("device", _DEVICES) + def test_quantize_8bit_with_qmap_compile(self, device): + x = torch.randn(32, 1024, device=device) + qmap = torch.tensor(subclass_8bit.QMAP_SIGNED, device=device) + + compiled_f = torch.compile(subclass_8bit.quantize_8bit_with_qmap, fullgraph=True) + actual_codes, actual_scale = compiled_f(x, qmap, 256) + expected_codes, expected_scale = subclass_8bit.quantize_8bit_with_qmap(x, qmap, 256) + + torch.testing.assert_close(actual_codes, expected_codes) + torch.testing.assert_close(actual_scale, expected_scale) + + @parametrize("device", _DEVICES) + def test_quantize_4bit_with_qmap_correctness(self, device): + x = torch.randn(32, 1024, device=device) + qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) + + actual_codes, actual_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=1) + expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256, implementation=0) + + torch.testing.assert_close(actual_codes, expected_codes) + torch.testing.assert_close(actual_scale, expected_scale) + + @parametrize("device", _DEVICES) + def test_quantize_4bit_with_qmap_compile(self, device): + x = torch.randn(32, 1024, device=device) + qmap = torch.tensor(subclass_4bit.QMAP_SIGNED, device=device) + + compiled_f = torch.compile(subclass_4bit.quantize_4bit_with_qmap, fullgraph=True) + actual_codes, actual_scale = compiled_f(x, qmap, 256) + expected_codes, expected_scale = subclass_4bit.quantize_4bit_with_qmap(x, qmap, 256) + + torch.testing.assert_close(actual_codes, expected_codes) + torch.testing.assert_close(actual_scale, expected_scale) + + +class TestOptim(TestCase): + @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) + def test_optim_8bit_correctness(self, optim_name): + device = "cuda" + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model2 = copy.deepcopy(model1) + + optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) + + for _ in range(2): + x = torch.randn(4, 32, device=device) + + loss1 = model1(x).sum() + loss1.backward() + optim1.step() + optim1.zero_grad() + + 
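+            # repeat the same update with the torchao low-bit optimizer; its parameters
+            # are compared against the bitsandbytes reference below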
loss2 = model2(x).sum() + loss2.backward() + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA") + @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") + @parametrize("optim_name", ["Adam4bit", "AdamW4bit"]) + def test_optim_4bit_correctness(self, optim_name): + device = "cuda" + model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) + model2 = copy.deepcopy(model1) + + # lpmm doesn't have Adam. use AdamW with no weight decay instead. + if optim_name == "Adam4bit": + optim1 = lpmm.optim.AdamW(model1.parameters(), weight_decay=0) + elif optim_name == "AdamW4bit": + optim1 = lpmm.optim.AdamW(model1.parameters()) + else: + raise ValueError(f"Unsupported {optim_name} optimizer for lpmm") + optim2 = getattr(low_bit_optim, optim_name)(model2.parameters()) + + for _ in range(2): + x = torch.randn(4, 32, device=device) + + loss1 = model1(x).sum() + loss1.backward() + optim1.step() + optim1.zero_grad() + + loss2 = model2(x).sum() + loss2.backward() + optim2.step() + optim2.zero_grad() + + for p1, p2 in zip(model1.parameters(), model2.parameters()): + torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) + + +instantiate_parametrized_tests(TestQuantize) +instantiate_parametrized_tests(TestOptim) + + +if __name__ == "__main__": + run_tests() diff --git a/test/prototype/test_optim_8bit.py b/test/prototype/test_optim_8bit.py deleted file mode 100644 index 482c5fcb42..0000000000 --- a/test/prototype/test_optim_8bit.py +++ /dev/null @@ -1,84 +0,0 @@ -import copy - -import pytest -import torch -from torch import nn -from torch.testing._internal.common_utils import ( - TestCase, - instantiate_parametrized_tests, - parametrize, - run_tests, -) -from torchao.prototype import optim_8bit -from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED -from torchao.utils import TORCH_VERSION_AFTER_2_3 - -try: - import bitsandbytes as bnb -except ImportError: - bnb = None - - -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - - -class TestDTQ8bit(TestCase): - @parametrize("device", _DEVICES) - def test_quantize_8bit_with_qmap_correctness(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(QMAP_SIGNED, device=device) - - actual_codes, actual_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=1) - expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=0) - - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) - - @parametrize("device", _DEVICES) - def test_quantize_8bit_with_qmap_compile(self, device): - x = torch.randn(32, 1024, device=device) - qmap = torch.tensor(QMAP_SIGNED, device=device) - - actual_codes, actual_scale = torch.compile(quantize_8bit_with_qmap, fullgraph=True)(x, qmap, 256) - expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256) - - torch.testing.assert_close(actual_codes, expected_codes) - torch.testing.assert_close(actual_scale, expected_scale) - - -@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA") -@pytest.mark.xfail(not 
TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3") -class TestOptim8bit(TestCase): - @parametrize("optim_name", ["Adam8bit", "AdamW8bit"]) - def test_adam_8bit_correctness(self, optim_name): - device = "cuda" - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device) - model2 = copy.deepcopy(model1) - - optim1 = getattr(bnb.optim, optim_name)(model1.parameters()) - optim2 = getattr(optim_8bit, optim_name)(model2.parameters()) - - for _ in range(2): - x = torch.randn(4, 32, device=device) - - loss1 = model1(x).sum() - loss1.backward() - optim1.step() - optim1.zero_grad() - - loss2 = model2(x).sum() - loss2.backward() - optim2.step() - optim2.zero_grad() - - for p1, p2 in zip(model1.parameters(), model2.parameters()): - torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5) - - -instantiate_parametrized_tests(TestDTQ8bit) -instantiate_parametrized_tests(TestOptim8bit) - - -if __name__ == "__main__": - run_tests() diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index b00e077457..1024d635c0 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -10,7 +10,7 @@ - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm - `galore/docs` - implementation notes and discussion of issues faced in kernel design. - [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) -- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes). +- [`low_bit_optim`](low_bit_optim) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and 4-bit optimizers from [lpmm](https://github.com/thu-ml/low-bit-optimizers). #### Roadmap diff --git a/torchao/prototype/low_bit_optim/README.md b/torchao/prototype/low_bit_optim/README.md new file mode 100644 index 0000000000..da3cc064f5 --- /dev/null +++ b/torchao/prototype/low_bit_optim/README.md @@ -0,0 +1,48 @@ +# Low-bit optimizers + +This folder implements: + +- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861 +- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507 + +The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. + +## Usage + +This is a drop-in replacement for `torch.optim.Adam` + +```python +from torchao.prototype.low_bit_optim import Adam8bit + +model = ... +optim = Adam8bit(model.parameters()) +``` + +To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers. + +**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand. + +NOTE: +- The low-bit optimizers require PyTorch >= 2.3 +- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper. +- **Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed. 
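+
+Putting the notes above together, here is a minimal training-loop sketch using the 4-bit AdamW variant with a custom `block_size` (the model, hyperparameters, and block size below are illustrative choices, not defaults):
+
+```python
+import torch
+from torchao.prototype.low_bit_optim import AdamW4bit
+
+model = torch.nn.Linear(4096, 4096).cuda()  # placeholder model for illustration
+optim = AdamW4bit(model.parameters(), lr=1e-4, block_size=256)
+
+for _ in range(10):
+    x = torch.randn(8, 4096, device="cuda")
+    loss = model(x).sum()
+    loss.backward()
+    optim.step()
+    optim.zero_grad()
+```
+
+Keeping the learning rate fixed (as above) sidesteps the per-step lr-to-CUDA transfer described in the known issue; if you update the learning rate every step with a scheduler, expect some slowdown.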
+ +## Benchmarks + +Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_low_bit_adam.py](../../../benchmarks/benchmark_low_bit_adam.py). + +Results for fine-tuning ViT-H (630M params) with BF16 AMP, batch size 4, 1 epoch, on 4070Ti SUPER: + +Adam impl | max memory (GB) | time taken | accuracy +-----------|-----------------|------------|---------- +PyTorch | 12.98 | 10m 08s | 87.70 +bnb 8-bit | 8.31 | 8m 38s | 86.22 +ao 8-bit | 8.32 | 10m 54s | 86.67 +lpmm 4-bit | 7.72 | 7m 48s | 84.70 +ao 4-bit | 7.72 | 9m 17s | 85.60 + +NOTE: time taken includes validation time, and compile time for torchao optimizers. + +## Credits + +Credits to Tim Dettmers for creating the wonderful [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library, and [lpmm](https://github.com/thu-ml/low-bit-optimizers) authors for their work on 4-bit optimizers. diff --git a/torchao/prototype/low_bit_optim/__init__.py b/torchao/prototype/low_bit_optim/__init__.py new file mode 100644 index 0000000000..ab7d8fd99b --- /dev/null +++ b/torchao/prototype/low_bit_optim/__init__.py @@ -0,0 +1,2 @@ +from .adam import Adam8bit, Adam4bit +from .adamw import AdamW8bit, AdamW4bit diff --git a/torchao/prototype/optim_8bit/adam.py b/torchao/prototype/low_bit_optim/adam.py similarity index 76% rename from torchao/prototype/optim_8bit/adam.py rename to torchao/prototype/low_bit_optim/adam.py index 461123ad20..49223b48e9 100644 --- a/torchao/prototype/optim_8bit/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -4,21 +4,12 @@ from torch import Tensor from torch.optim import Optimizer -from .subclass import maybe_new_zero_buffer +from .subclass_8bit import maybe_new_8bit_zero_buffer +from .subclass_4bit import maybe_new_4bit_zero_buffer -class Adam8bit(Optimizer): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - *, - block_size=2048, - ): +class _Adam(Optimizer): + def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -36,6 +27,10 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault("amsgrad", False) + @staticmethod + def _new_buffer(p: Tensor, signed: bool, block_size: int): + raise NotImplementedError + @torch.no_grad() def step(self, closure=None): loss = None @@ -58,19 +53,20 @@ def step(self, closure=None): # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = maybe_new_zero_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = maybe_new_zero_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = maybe_new_zero_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) state["step"] += 1 - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim # must explicitly convert lr to Tensor since torch.compile() will treat it as a constant # if it is a python float. practically, only lr is changed during training. 
# NOTE: if lr is change at every step, moving lr to CUDA will be a bottleneck. if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) + + # flatten p and grad so that torch.compile won't recompile for tensors with different ndim single_param_adam( p.view(-1), grad.view(-1), @@ -125,3 +121,37 @@ def single_param_adam( step_size = lr / bias_correction1 p.addcdiv_(new_exp_avg, denom, value=-step_size) + + +class Adam8bit(_Adam): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + *, + block_size=2048 + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + _new_buffer = staticmethod(maybe_new_8bit_zero_buffer) + + +class Adam4bit(_Adam): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + *, + block_size=128, + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) diff --git a/torchao/prototype/optim_8bit/adamw.py b/torchao/prototype/low_bit_optim/adamw.py similarity index 76% rename from torchao/prototype/optim_8bit/adamw.py rename to torchao/prototype/low_bit_optim/adamw.py index 835d1c6f04..440f75620b 100644 --- a/torchao/prototype/optim_8bit/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -4,21 +4,12 @@ from torch import Tensor from torch.optim import Optimizer -from .subclass import maybe_new_zero_buffer +from .subclass_8bit import maybe_new_8bit_zero_buffer +from .subclass_4bit import maybe_new_4bit_zero_buffer -class AdamW8bit(Optimizer): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - *, - block_size=2048, - ): +class _AdamW(Optimizer): + def __init__(self, params, lr, betas, eps, weight_decay, amsgrad, *, block_size) -> None: if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -36,6 +27,10 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault("amsgrad", False) + @staticmethod + def _new_buffer(p: Tensor, signed: bool, block_size: int): + raise NotImplementedError + @torch.no_grad() def step(self, closure=None): loss = None @@ -58,19 +53,20 @@ def step(self, closure=None): # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = maybe_new_zero_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = maybe_new_zero_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = maybe_new_zero_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) state["step"] += 1 - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim # must explicitly convert lr to Tensor since torch.compile() will treat it as a constant # if it is a python float. practically, only lr is changed during training. # NOTE: if lr is change at every step, moving lr to CUDA will be a bottleneck. 
if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) + + # flatten p and grad so that torch.compile won't recompile for tensors with different ndim single_param_adamw( p.view(-1), grad.view(-1), @@ -124,3 +120,37 @@ def single_param_adamw( step_size = lr / bias_correction1 p.addcdiv_(new_exp_avg, denom, value=-step_size) + + +class AdamW8bit(_AdamW): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=2048 + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + _new_buffer = staticmethod(maybe_new_8bit_zero_buffer) + + +class AdamW4bit(_AdamW): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + *, + block_size=128, + ) -> None: + super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size) + + _new_buffer = staticmethod(maybe_new_4bit_zero_buffer) diff --git a/torchao/prototype/low_bit_optim/subclass_4bit.py b/torchao/prototype/low_bit_optim/subclass_4bit.py new file mode 100644 index 0000000000..2b3608ce1d --- /dev/null +++ b/torchao/prototype/low_bit_optim/subclass_4bit.py @@ -0,0 +1,164 @@ +import math + +import torch +from torch import Tensor +from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE + +from .subclass_8bit import create_dynamic_map + + +aten = torch.ops.aten + + +# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml +# NOTE: power-1 is linear +QMAP_SIGNED = create_dynamic_map(True, 3, 4) +QMAP_UNSIGNED = torch.linspace(0, 1, 17)[1:].tolist() # no zero + + +def quantize_4bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implementation: int = 1): + # section 2.1 from https://arxiv.org/abs/2110.02861 + input = input.view(-1, block_size) + scale = input.abs().amax(-1).clip(1e-12) + input = input / scale.view(-1, 1) + + # reference implementation. 
equation 4 from https://arxiv.org/abs/2110.02861 + if implementation == 0: + codes = (qmap.view(1, -1) - input.view(-1, 1)).abs().argmin(-1) + codes = codes.to(torch.uint8) + + # GPU-friendly binary search + # https://blog.demofox.org/2017/06/20/simd-gpu-friendly-branchless-binary-search/ + elif implementation == 1: + input = input.view(-1) + codes = torch.where(input >= qmap[8], 8, 0) + codes += torch.where(input >= qmap[codes + 4], 4, 0) + codes += torch.where(input >= qmap[codes + 2], 2, 0) + codes += torch.where(input >= qmap[codes + 1], 1, 0) + + # rounding + codes_up = (codes + 1).clip(max=15) + val_down = qmap[codes] + val_up = qmap[codes_up] + residual = input - val_down + codes = torch.where(residual >= (val_up - val_down) * 0.5, codes_up, codes) + + codes = codes.to(torch.uint8) + + else: + raise ValueError(f"Unsupported implementation={implementation}") + + # packing + codes = (codes[::2] << 4) | codes[1::2] + return codes, scale + + +class OptimState4bit(Tensor): + implements = classmethod(_implements) + tensor_attrs = ["codes", "scale", "qmap"] + + @staticmethod + def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): + return Tensor._make_wrapper_subclass( + cls, + shape, + device=codes.device, + requires_grad=False, + ) + + def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape): + assert codes.dtype is torch.uint8 + assert codes.ndim == 1 # flattened buffer + self.codes = codes + self.scale = scale + self.qmap = qmap + self.signed = signed + self._shape = shape + + @property + def block_size(self): + return self.codes.numel() * 2 // self.scale.numel() + + def __tensor_flatten__(self): + return self.tensor_attrs, [self.signed, self._shape] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) + + def dequantize(self, output_dtype=None): + # unpack + codes = torch.stack([self.codes >> 4, self.codes & 0b1111], dim=-1) + + # torch.compile() cannot use uint8 as index + float_data = self.qmap[codes.int()] + float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1) + + dtype = output_dtype or torch.get_default_dtype() + return float_data.view(self._shape).to(dtype) + + @classmethod + def zeros(cls, shape, signed: bool = True, block_size: int = 128, device=None): + shape = (shape,) if isinstance(shape, int) else shape + n_elems = math.prod(shape) + + codes = torch.zeros(n_elems // 2, dtype=torch.uint8, device=device) + scale = torch.zeros(n_elems // block_size, device=device) + qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device) + return cls(codes, scale, qmap, signed, shape) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(signed={self.signed}, block_size={self.block_size}, " + f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") + + +@OptimState4bit.implements(aten.copy_.default) +def _(func, *args, **kwargs): + dst = args[0] + src = args[1] + + if isinstance(dst, OptimState4bit) and isinstance(src, OptimState4bit): + assert ( + dst.signed == src.signed + and dst.block_size == 
src.block_size + and dst._shape == src._shape + ) + dst.codes.copy_(src.codes) + dst.scale.copy_(src.scale) + # qmap should be the same, don't need to copy + + elif isinstance(dst, OptimState4bit): + codes, scale = quantize_4bit_with_qmap(src, dst.qmap, dst.block_size) + dst.codes.copy_(codes) + dst.scale.copy_(scale) + + else: + dst.copy_(src.dequantize()) + + return dst + + +@OptimState4bit.implements(aten.lerp.Scalar) +def _(func, *args, **kwargs): + args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args] + return func(*args, **kwargs) + + +# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/config.py#L37 +# only apply quantization for tensor with more than 4096 values +# TODO: also skip 1D tensor? e.g. biases and norm scales +def maybe_new_4bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 128): + if p.numel() >= 4096 and p.numel() % block_size == 0: + out = OptimState4bit.zeros(p.shape, signed, block_size, device=p.device) + else: + out = torch.zeros_like(p) + return out diff --git a/torchao/prototype/optim_8bit/subclass.py b/torchao/prototype/low_bit_optim/subclass_8bit.py similarity index 83% rename from torchao/prototype/optim_8bit/subclass.py rename to torchao/prototype/low_bit_optim/subclass_8bit.py index 60dbc79b85..44a3d593cf 100644 --- a/torchao/prototype/optim_8bit/subclass.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -7,6 +7,7 @@ # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391 +# NOTE: zero padding is removed so this function can work with 4-bit qmap def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): """ Creates the dynamic quantiztion map. 
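+    E.g. create_dynamic_map(True, 3, 4) builds the 16-value signed map used as the 4-bit qmap in subclass_4bit.py.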
@@ -54,10 +55,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): assert len(data) == 2**total_bits - gap = 256 - len(data) - for i in range(gap): - data.append(0) - data.sort() return data @@ -65,9 +62,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): QMAP_SIGNED = create_dynamic_map(signed=True) QMAP_UNSIGNED = create_dynamic_map(signed=False) -ZERO_CODE_SIGNED = QMAP_SIGNED.index(0) -ZERO_CODE_UNSIGNED = QMAP_UNSIGNED.index(0) - def quantize_8bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implementation: int = 1): # section 2.1 from https://arxiv.org/abs/2110.02861 @@ -111,7 +105,7 @@ def quantize_8bit_with_qmap(input: Tensor, qmap: Tensor, block_size: int, implem # dynamic tree quantization # https://arxiv.org/pdf/1511.04561 # https://arxiv.org/abs/2110.02861 -class DTQ8bit(Tensor): +class OptimState8bit(Tensor): implements = classmethod(_implements) tensor_attrs = ["codes", "scale", "qmap"] @@ -142,12 +136,6 @@ def __tensor_flatten__(self): def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes) - @classmethod - def from_float(cls, input_float: Tensor, signed: bool = True, block_size: int = 2048): - qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=input_float.device) - codes, scale = quantize_8bit_with_qmap(input_float, qmap, block_size) - return cls(codes.view(input_float.shape), scale, qmap, signed) - def dequantize(self, output_dtype=None): # torch.compile() cannot use uint8 as index float_data = self.qmap[self.codes.int()] @@ -158,10 +146,9 @@ def dequantize(self, output_dtype=None): @classmethod def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None): - shape = (shape,) if isinstance(shape, int) else shape - codes = torch.full(shape, ZERO_CODE_SIGNED if signed else ZERO_CODE_UNSIGNED, dtype=torch.uint8, device=device) + codes = torch.zeros(shape, dtype=torch.uint8, device=device) + scale = torch.zeros(codes.numel() // block_size, device=device) qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device) - scale = torch.ones(codes.numel() // block_size, device=device) return cls(codes, scale, qmap, signed) def __repr__(self): @@ -178,18 +165,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs): raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported") -@DTQ8bit.implements(aten.copy_.default) +@OptimState8bit.implements(aten.copy_.default) def _(func, *args, **kwargs): dst = args[0] src = args[1] - if isinstance(dst, DTQ8bit) and isinstance(src, DTQ8bit): - assert dst.signed == src.signed + if isinstance(dst, OptimState8bit) and isinstance(src, OptimState8bit): + assert dst.signed == src.signed and dst.block_size == src.block_size dst.codes.copy_(src.codes) dst.scale.copy_(src.scale) # qmap should be the same, don't need to copy - elif isinstance(dst, DTQ8bit): + elif isinstance(dst, OptimState8bit): codes, scale = quantize_8bit_with_qmap(src, dst.qmap, dst.block_size) dst.codes.copy_(codes) dst.scale.copy_(scale) @@ -200,18 +187,18 @@ def _(func, *args, **kwargs): return dst -@DTQ8bit.implements(aten.lerp.Scalar) +@OptimState8bit.implements(aten.lerp.Scalar) def _(func, *args, **kwargs): - args = [x.dequantize() if isinstance(x, DTQ8bit) else x for x in args] + args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args] return func(*args, **kwargs) # 
follow bitsandbytes # only apply quantization for tensor with more than 4096 values # TODO: also skip 1D tensor? e.g. biases and norm scales -def maybe_new_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): - if p.numel() >= 4096 and p.numel() %block_size == 0: - out = DTQ8bit.zeros(p.shape, signed, block_size, device=p.device) +def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): + if p.numel() >= 4096 and p.numel() % block_size == 0: + out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) else: out = torch.zeros_like(p) return out diff --git a/torchao/prototype/optim_8bit/README.md b/torchao/prototype/optim_8bit/README.md deleted file mode 100644 index 2c35924181..0000000000 --- a/torchao/prototype/optim_8bit/README.md +++ /dev/null @@ -1,38 +0,0 @@ -# 8-bit optimizers - -This folder implements 8-bit optimizers using dynamic tree quantization as outlined in https://arxiv.org/abs/2110.02861. The implementation is fully done in Python (with tensor subclass) and relies on `torch.compile()` to generate efficient fused kernel. - -## Usage - -This is a drop-in replacement for `torch.optim.Adam` - -```python -from torchao.prototype.optim_8bit import Adam8bit - -model = ... -optim = Adam8bit(model.parameters()) -``` - -You can also change quantization block size (default 2048) by passing `block_size=value` to the optimizer. - -**Other optimizers**: AdamW is also available as `AdamW8bit`. - -NOTE: this requires PyTorch >= 2.3 - -## Benchmarks - -Benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py). - -Results for fine-tuning ViT-B with BF16 AMP, on 4070Ti SUPER: - -Adam impl | max memory (GB) | training time | accuracy -----------|-----------------|---------------|---------- -PyTorch | 5.26 | 9m 11s | 93.62% -bnb 8-bit | 4.78 | 9m 10s | 93.06% -ao 8-bit | 4.78 | 9m 15s | 94.14% - -**Known issue**: When learning rate is updated every step (e.g. using cosine learning rate scheduler), training speed is slower. This is because we have to convert learning rate to a CUDA tensor (which incurs expensive memory transfer cost), since torch.compile() will treat a Python float as a constant and trigger recompile whenever the value is changed - -## Credits - -Credits to Tim Dettmers for creating the wonderful bitsandbytes library. diff --git a/torchao/prototype/optim_8bit/__init__.py b/torchao/prototype/optim_8bit/__init__.py deleted file mode 100644 index 2684331b9a..0000000000 --- a/torchao/prototype/optim_8bit/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .adam import Adam8bit -from .adamw import AdamW8bit