Added first bits of Uint2Tensor and BitnetTensor #282

Merged

21 commits merged into pytorch:main from uint2_bitnet on Jun 18, 2024
Commits (21)
6927d28
Added first bits of Uint2Tensor and BitnetTensor
andreaskoepf May 26, 2024
0d85b06
add conversion to standard signed and unsigned dtypes
andreaskoepf May 26, 2024
f64e457
added triton kernel for pack and unpack
melvinebenezer May 28, 2024
8b14c17
fix: test cases and device allocation for triton kernels
melvinebenezer May 28, 2024
fcd7c08
fix: moved uint2 to prototype folder
melvinebenezer Jun 1, 2024
30d95a1
Add packing and unpacking functions for uint{2,3,4,5,6,7}.
andreaskoepf Jun 1, 2024
a071487
housekeeping: renamed uint_small to uintgen and simple comments
melvinebenezer Jun 3, 2024
f0d5982
Update uint2.py
CoffeeVampir3 Jun 4, 2024
13fa9d8
added pytest ,compile tests and some cleanup
melvinebenezer Jun 5, 2024
326f552
Merge branch 'pytorch:main' into uint2_bitnet
vayuda Jun 13, 2024
6d6b9dc
Merge branch 'pytorch:main' into uint2_bitnet
CoffeeVampir3 Jun 15, 2024
1fdeb91
fix: implements pattern for uint2 and BitnetTensor
melvinebenezer Jun 16, 2024
4eb5679
fix: torch.uint2 available after torch 2.3
melvinebenezer Jun 16, 2024
4ec77e4
Merge branch 'pytorch:main' into uint2_bitnet
vayuda Jun 16, 2024
666a724
fix: test cases for BitnetTensor, UInt2Tensor and bitpacking gen
melvinebenezer Jun 17, 2024
a2a4359
fix: removed torch.uint2
melvinebenezer Jun 17, 2024
60970c5
fix: wrap detach in UIntTensor, torch.compile test
melvinebenezer Jun 17, 2024
c9e9583
fix: CI errors on compile tests
melvinebenezer Jun 18, 2024
7041216
fix: skip tests less than torch 2.4
melvinebenezer Jun 18, 2024
5ef3f6b
Added pytest fixture
msaroufim Jun 18, 2024
ae4ead1
remove tensor core flag
msaroufim Jun 18, 2024
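Note: the diffs below lean on pack_uint2/unpack_uint2 from torchao/prototype/dtypes/uint2.py (and Triton kernels for the same, per the commits above), but that file is not part of the excerpt shown here. For orientation, here is a minimal CPU-only sketch of the idea, four 2-bit codes stored per uint8 byte; the names pack_uint2_sketch/unpack_uint2_sketch and the exact bit layout are illustrative assumptions, not the PR's implementation.

import torch

def pack_uint2_sketch(x: torch.Tensor) -> torch.Tensor:
    # Pack four 2-bit values (0..3) into each output byte.
    # Assumes x.numel() is a multiple of 4 and every value fits in 2 bits.
    x = x.contiguous().view(-1, 4).to(torch.uint8)
    return x[:, 0] | (x[:, 1] << 2) | (x[:, 2] << 4) | (x[:, 3] << 6)

def unpack_uint2_sketch(packed: torch.Tensor) -> torch.Tensor:
    # Inverse: expand each byte back into four 2-bit values.
    return torch.stack(
        [(packed >> shift) & 0b11 for shift in (0, 2, 4, 6)], dim=-1
    ).view(-1)

x = torch.randint(0, 4, (16,), dtype=torch.uint8)
assert torch.equal(unpack_uint2_sketch(pack_uint2_sketch(x)), x)  # lossless round trip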
75 changes: 75 additions & 0 deletions test/dtypes/test_bitnet.py
@@ -0,0 +1,75 @@
import pytest
import torch
import torch.nn as nn
from torchao.prototype.dtypes import BitnetTensor
from torchao.prototype.dtypes.uint2 import unpack_uint2
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@pytest.fixture(autouse=True)
def run_before_and_after_tests():
    # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501

    # setup (currently do nothing)

    # tests will run here
    yield

    # teardown
    # avoid dynamo cache limit issues
    torch._dynamo.reset()

@pytest.fixture
def bitnet_tensor():
    input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
    return BitnetTensor.from_unpacked(input_tensor)

def test_copy(bitnet_tensor):
    copied_tensor = bitnet_tensor.clone()
    assert torch.equal(bitnet_tensor.elem, copied_tensor.elem)

def test_transpose(bitnet_tensor):
    transposed_tensor = bitnet_tensor.t()
    expected_tensor = unpack_uint2(bitnet_tensor.elem).t()
    assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor)

def test_multiply(bitnet_tensor):
    w_t = torch.randint(0, 15, (4, 16), dtype=torch.uint8)
    w = BitnetTensor.from_unpacked(w_t)
    y = torch.addmm(torch.Tensor([1]), bitnet_tensor, w)

@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64])
def test_conversion(bitnet_tensor, dtype):
    converted_tensor = bitnet_tensor.to(dtype)
    expected_tensor = unpack_uint2(bitnet_tensor.elem).to(dtype)
    assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5)

def _apply_weight_only_uint2_quant(model):
    def fn(mod):
        mod.weight = torch.nn.Parameter(BitnetTensor.from_float(mod.weight), requires_grad=False)
        return mod

    _replace_with_custom_fn_if_matches_filter(
        model,
        lambda mod: fn(mod),
        lambda mod, fqn: isinstance(mod, torch.nn.Linear),
    )

@pytest.mark.parametrize("input_shape", [[2, 4], [5, 5, 5, 4], [1, 4, 4]])
def test_uint2_quant(input_shape):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(*input_shape).to(device)
    m = nn.Sequential(nn.Linear(4, 16)).to(device)
    y_ref = m(x)
    _apply_weight_only_uint2_quant(m)
    y_wo = m(x)
    assert y_ref.shape == y_wo.shape
    y_compiled = torch.compile(m, fullgraph=True)(x)


if __name__ == "__main__":
    pytest.main([__file__])
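As a rough illustration of what _apply_weight_only_uint2_quant above does to each nn.Linear (a sketch only, mirroring the test; the module and shapes are the test's own):

import torch
import torch.nn as nn
from torchao.prototype.dtypes import BitnetTensor

lin = nn.Linear(4, 16)
# Replace the float weight with a packed BitnetTensor; forward passes then route
# through the aten.addmm/aten.mm overrides registered in bitnet.py further below.
lin.weight = nn.Parameter(BitnetTensor.from_float(lin.weight), requires_grad=False)
y = lin(torch.randn(2, 4))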

33 changes: 33 additions & 0 deletions test/dtypes/test_uint2.py
@@ -0,0 +1,33 @@
import pytest
import torch
import torch.nn as nn
from torchao.prototype.dtypes import UInt2Tensor
from torchao.prototype.dtypes.uint2 import unpack_uint2
from torchao.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

@pytest.fixture
def uint2_tensor():
    input_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
    return UInt2Tensor(input_tensor)

def test_copy(uint2_tensor):
    copied_tensor = uint2_tensor.clone()
    assert torch.equal(uint2_tensor.elem, copied_tensor.elem)

def test_transpose(uint2_tensor):
    transposed_tensor = uint2_tensor.t()
    expected_tensor = unpack_uint2(uint2_tensor.elem).t()
    assert torch.equal(unpack_uint2(transposed_tensor.elem), expected_tensor)

@pytest.mark.parametrize("dtype", [torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64])
def test_conversion(uint2_tensor, dtype):
    converted_tensor = uint2_tensor.to(dtype)
    expected_tensor = unpack_uint2(uint2_tensor.elem).to(dtype)
    assert torch.allclose(converted_tensor, expected_tensor, atol=1e-5)

if __name__ == '__main__':
    pytest.main([__file__])

26 changes: 26 additions & 0 deletions test/prototype/test_bitpacking_gen.py
@@ -0,0 +1,26 @@
import pytest
import torch

from torchao.prototype.dtypes.uintgen import (
    pack_uint2, unpack_uint2, pack_uint3, unpack_uint3, pack_uint4, unpack_uint4,
    pack_uint5, unpack_uint5, pack_uint6, unpack_uint6, pack_uint7, unpack_uint7
)

@pytest.mark.parametrize("pack_fn, unpack_fn, bit_count", [
    (pack_uint2, unpack_uint2, 2),
    (pack_uint3, unpack_uint3, 3),
    (pack_uint4, unpack_uint4, 4),
    (pack_uint5, unpack_uint5, 5),
    (pack_uint6, unpack_uint6, 6),
    (pack_uint7, unpack_uint7, 7),
])
def test_uint_packing(pack_fn, unpack_fn, bit_count):
    x = torch.arange(0, 256, dtype=torch.uint8)
    y = pack_fn(x)
    z = unpack_fn(y)
    k = z.view(-1, 2 ** bit_count)
    check = torch.arange(0, 2 ** bit_count, dtype=torch.uint8).repeat(k.size(0), 1)
    assert torch.all(k == check), f"Failed for {bit_count}-bit packing"

if __name__ == "__main__":
    pytest.main([__file__])
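Why the check above works: arange(0, 256) reduced to its low bit_count bits cycles through 0 .. 2**bit_count - 1, so after a round trip through pack_fn/unpack_fn (which the expected check tensor assumes keeps exactly those low bits), viewing the result in rows of length 2**bit_count should reproduce that ramp in every row. A small standalone illustration for the 2-bit case, independent of the actual pack/unpack implementation:

import torch

x = torch.arange(0, 256, dtype=torch.uint8)
low2 = x & 0b11          # keep only the low 2 bits, as a 2-bit pack/unpack round trip would
rows = low2.view(-1, 4)  # 256 values -> 64 rows, each equal to [0, 1, 2, 3]
assert torch.all(rows == torch.arange(0, 4, dtype=torch.uint8))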
1 change: 1 addition & 0 deletions torchao/dtypes/__init__.py
@@ -1,4 +1,5 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .aqt import AffineQuantizedTensor, to_affine_quantized

9 changes: 9 additions & 0 deletions torchao/prototype/dtypes/__init__.py
@@ -0,0 +1,9 @@

from .uint2 import UInt2Tensor
from .bitnet import BitnetTensor

__all__ = [
    "BitnetTensor",
    "UInt2Tensor",
]

161 changes: 161 additions & 0 deletions torchao/prototype/dtypes/bitnet.py
@@ -0,0 +1,161 @@
import torch
from torchao.prototype.dtypes.uint2 import UInt2Tensor, unpack_uint2, pack_uint2

BITNET_OPS_TABLE = {}

def implements(aten_ops):
    def decorator(fn):
        for op in aten_ops:
            BITNET_OPS_TABLE[op] = fn
        return fn
    return decorator

def _quantize_int2(x: torch.Tensor) -> torch.Tensor:
    # Quantize the input tensor to int2
    quant = x.sign() + 1
    quant = BitnetTensor.from_unpacked(quant.to(torch.uint8))
    return quant

class BitnetTensor(UInt2Tensor):
    def __new__(cls, input_tensor: torch.Tensor, **kwargs):
        return super(BitnetTensor, cls).__new__(cls, input_tensor, **kwargs)

    def __init__(self, input_tensor: torch.Tensor, **kwargs):
        super(BitnetTensor, self).__init__(input_tensor, **kwargs)

    @staticmethod
    def __tensor_unflatten__(flattened, *meta):
        # TODO - meta is not None, is it ok?
        elem = flattened["elem"]
        return BitnetTensor(elem)

    @classmethod
    def from_unpacked(cls, unpacked: torch.Tensor) -> "BitnetTensor":
        return cls(pack_uint2(unpacked))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        def allowed_subclasses(type):
            return (
                issubclass(cls, type) or
                issubclass(torch._subclasses.fake_tensor.FakeTensor, type) or
                issubclass(torch._subclasses.functional_tensor.FunctionalTensor, type)
            )

        if not all(allowed_subclasses(t) for t in types):
            # NotImplemented is a sentinel, not a callable; return it so the next
            # subclass in the dispatch chain can handle the op.
            return NotImplemented

        if func in BITNET_OPS_TABLE:
            return BITNET_OPS_TABLE[func](func, args, kwargs)
        raise NotImplementedError(f"Bitnet dispatch: attempting to run {func}, this is not supported")

    @classmethod
    def from_float(cls, w: torch.Tensor):
        w_intq = _quantize_int2(w)
        w_int2 = w_intq.to(device=w.device)
        return w_int2

    def clone(self):
        return BitnetTensor(self.elem.clone())

    def copy_(self, src):
        self.elem.copy_(src.elem)
        return self

    def tolist(self):
        data = unpack_uint2(self.elem).tolist()
        return data

    def __repr__(self):
        try:
            data = unpack_uint2(self.elem).tolist()
        except AssertionError:
            data = f"Tensor of shape {self.shape} and dtype {self.elem.dtype}"
        return f"BitnetTensor({data}, dtype={self.elem.dtype})"

    def to(self, *args, **kwargs):
        if len(args) == 1 and isinstance(args[0], torch.dtype):
            dtype = args[0]
            if dtype == torch.int8:
                return unpack_uint2(self.elem).view(self.shape).view(torch.int8)
            elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64):
                return unpack_uint2(self.elem).to(torch.int8).to(dtype)
            elif dtype == torch.uint8:
                return unpack_uint2(self.elem).view(torch.uint8)
            elif isinstance(self, BitnetTensor):
                return self
        if 'device' in kwargs:
            device = kwargs['device']
            return BitnetTensor(self.elem.to(device=device))

        return super().to(*args, **kwargs)

@implements([torch.ops.aten.mm.default])
def mm(func, args, kwargs):
    x, weight = args
    if isinstance(x, BitnetTensor):
        x = unpack_uint2(x.elem).to(torch.float32)
    if isinstance(weight, BitnetTensor):
        weight = unpack_uint2(weight.elem).to(torch.float32)
    y = torch.mm(x, weight)
    return y

@implements([torch.ops.aten.addmm.default])
def addmm(func, args, kwargs):
    bias, x, weight = args
    if isinstance(x, BitnetTensor):
        x = unpack_uint2(x.elem).to(torch.float32)
    if isinstance(weight, BitnetTensor):
        weight = unpack_uint2(weight.elem).to(torch.float32)
    if bias is not None:
        bias = bias.to(torch.float32)
    y = torch.addmm(bias, x, weight)
    return y

@implements([torch.ops.aten.t.default])
def t(func, args, kwargs):
    (tensor,) = args
    unpacked = unpack_uint2(tensor.elem).to(tensor.device)
    transposed = unpacked.t()
    return BitnetTensor(pack_uint2(transposed))

@implements([torch.ops.aten.detach.default])
def detach(func, args, kwargs):
    (tensor,) = args
    return tensor

@implements([torch.ops.aten.to.dtype])
def to_dtype(func, args, kwargs):
    (tensor, dtype) = args
    if dtype == torch.int8:
        return unpack_uint2(tensor.elem).view(torch.uint8) - 1
    elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64):
        return unpack_uint2(tensor.elem).to(torch.int8).to(dtype)
    elif dtype == torch.uint8:
        return unpack_uint2(tensor.elem).view(torch.uint8)
    elif isinstance(tensor, BitnetTensor):
        return tensor.elem
    raise NotImplementedError(f"to {dtype} not supported")

@implements([torch.ops.aten._to_copy.default])
def _to_copy(func, args, kwargs):
    (tensor,) = args
    dtype = kwargs["dtype"]
    if dtype == torch.int8:
        return BitnetTensor(unpack_uint2(tensor).view(tensor.shape).view(torch.int8) - 1)
    elif dtype in (torch.float, torch.float16, torch.bfloat16, torch.int16, torch.int32, torch.int64):
        return BitnetTensor(tensor.to(torch.int8).to(dtype))
    elif isinstance(tensor, BitnetTensor):
        return BitnetTensor(tensor)
    raise NotImplementedError(f"to {dtype} not supported")

@implements([torch.ops.aten.clone.default])
def clone(func, args, kwargs):
    (tensor,) = args
    return tensor.clone()

@implements([torch.ops.aten.allclose.default])
def allclose(func, args, kwargs):
    (a, b) = args
    return torch.allclose(a.elem, b.elem, **kwargs)
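A quick sketch of the quantization scheme wired up above: from_float goes through _quantize_int2, which maps each weight to sign(w) + 1, i.e. negative -> 0, zero -> 1, positive -> 2, and stores the result packed via pack_uint2 (presumably four codes per byte; uint2.py is not shown in this diff). The values below are illustrative only, and since the output shape of unpack_uint2 depends on uint2.py, the example flattens before printing.

import torch
from torchao.prototype.dtypes import BitnetTensor
from torchao.prototype.dtypes.uint2 import unpack_uint2

w = torch.tensor([[-0.8, 0.0, 0.5, 1.2]])
w_q = BitnetTensor.from_float(w)   # sign() + 1 -> codes in {0, 1, 2}, then packed into uint8 storage
codes = unpack_uint2(w_q.elem)     # expand the packed storage back to one code per element
print(codes.flatten().tolist())    # [0, 1, 2, 2]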
