[low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 256 to match new bnb v0.44 (#927)
gau-nernst authored Sep 24, 2024
1 parent 26e790d commit 728d629
Showing 5 changed files with 12 additions and 8 deletions.
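Note (illustrative, not part of the commit): a back-of-the-envelope sketch of what the new default costs. Per-block scales are stored in fp32 while 8-bit codes take one byte per element (see the subclass diffs below), so shrinking the block size from 2048 to 256 grows the scale overhead roughly 8x while each scale covers 8x fewer elements:

```python
n = 1 << 20  # a 1M-element optimizer state tensor
for block_size in (2048, 256):
    codes_bytes = n                      # one uint8 code per element
    scale_bytes = (n // block_size) * 4  # one fp32 scale per block
    print(block_size, scale_bytes, f"{scale_bytes / codes_bytes:.2%}")
# 2048 -> 2 KiB of scales (~0.20% overhead); 256 -> 16 KiB (~1.56%),
# matching the finer quantization granularity of bnb v0.44.
```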
6 changes: 5 additions & 1 deletion test/prototype/test_low_bit_optim.py
@@ -3,6 +3,7 @@

 import pytest
 import torch
+from packaging.version import Version
 from torch import nn
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name):
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
         model2 = copy.deepcopy(model1)

+        # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
+        block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
+
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)

         for _ in range(2):
             x = torch.randn(4, 32, device=device)
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/README.md
@@ -19,7 +19,7 @@ model = ...
 optim = Adam8bit(model.parameters())
 ```

-To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
+To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.

 **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

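As a usage sketch of the README paragraph above (assuming `Adam8bit` and `AdamFp8` are importable from `torchao.prototype.low_bit_optim`, as the README and tests suggest; the override value is illustrative):

```python
import torch.nn as nn
from torchao.prototype.low_bit_optim import Adam8bit, AdamFp8

model = nn.Sequential(nn.Linear(32, 256), nn.ReLU(), nn.Linear(256, 32))

optim = Adam8bit(model.parameters())                         # new default: block_size=256
optim_coarse = AdamFp8(model.parameters(), block_size=2048)  # restore the old granularity
```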
8 changes: 4 additions & 4 deletions torchao/prototype/low_bit_optim/adam.py
@@ -161,7 +161,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

@@ -199,7 +199,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)

@@ -218,7 +218,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)

@@ -256,7 +256,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)

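One detail worth noting in the signatures above: the bare `*` makes `block_size` keyword-only, so callers must spell it out. A minimal sketch:

```python
import torch.nn as nn
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(16, 16)
optim = AdamW8bit(model.parameters(), block_size=256)  # OK: keyword argument
# AdamW8bit(model.parameters(), 1e-3, (0.9, 0.999), 1e-8, 1e-2, False, 256)
# would raise TypeError, since block_size sits after the bare `*`.
```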
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None):
         return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)

     @classmethod
-    def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None):
+    def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
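For intuition, a standalone sketch of the state layout that `zeros` above allocates (plain tensors here, not the tensor subclass itself; shapes follow the diff):

```python
import torch

shape, block_size = (1024, 1024), 256
codes = torch.zeros(shape, dtype=torch.uint8)     # one 8-bit code per element
scale = torch.zeros(codes.numel() // block_size)  # one fp32 scale per 256-element block
print(codes.numel(), scale.numel())               # 1048576 4096
# Note: numel() must be divisible by block_size, as the integer division implies.
```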
2 changes: 1 addition & 1 deletion torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None):
         return float_data.view(self.codes.shape).to(dtype)

     @classmethod
-    def zeros(cls, shape, block_size: int = 2048, device=None):
+    def zeros(cls, shape, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=DTYPE, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         return cls(codes, scale)
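And a hedged sketch of the blockwise dequantization these shapes imply (assuming `DTYPE` is an FP8 dtype such as `torch.float8_e4m3fn` and that each block of `block_size` codes shares one scale; this mirrors the structure, not the exact torchao kernel):

```python
import torch

block_size = 256
codes = torch.randn(1024).to(torch.float8_e4m3fn)  # stand-in for stored FP8 codes
scale = torch.rand(codes.numel() // block_size)    # one scale per block
deq = (codes.float().view(-1, block_size) * scale.view(-1, 1)).view(codes.shape)
print(deq.shape)  # torch.Size([1024])
```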
