Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bitpacking #291

Merged
merged 32 commits into from
May 30, 2024
Merged
Changes from 16 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
37e4d2d
init additions
vayuda May 23, 2024
b37b529
extended pack/unpack
vayuda May 25, 2024
f834885
Merge branch 'pytorch:main' into uint4-improvements
vayuda May 25, 2024
33845fe
pack/unpack from n to m dtypes
vayuda May 26, 2024
3e7ca9b
works with torch.compile, but not optimized
vayuda May 26, 2024
88fe113
works on gpu
vayuda May 26, 2024
3e1813a
Merge branch 'pytorch:main' into uint4-improvements
vayuda May 26, 2024
80b9a41
added row-wise bitpack
vayuda May 28, 2024
036334c
Merge branch 'pytorch:main' into uint4-improvements
vayuda May 28, 2024
6a02cc1
Merge branch 'uint4-improvements' of https://github.com/JayakumarPawa…
vayuda May 28, 2024
61d3666
Merge branch 'pytorch:main' into uint4-improvements
vayuda May 29, 2024
47d9c92
restructured into prototype/
vayuda May 29, 2024
5bdef89
Merge branch 'uint4-improvements' of https://github.com/vayuda/ao int…
vayuda May 29, 2024
8d1ea34
revert nuclear fix
vayuda May 29, 2024
6e1a7d6
removed temp log
vayuda May 29, 2024
46e39fd
removed trinary stuff from this branch
vayuda May 29, 2024
dcada3e
moved tests to tests/
vayuda May 29, 2024
a5dd25d
updated tests to skip if cuda DNE
vayuda May 29, 2024
5ec3deb
added full_graph=True to compile
vayuda May 29, 2024
6e9a738
Apply suggestions from code review
msaroufim May 29, 2024
189677d
fixed test skipping
vayuda May 29, 2024
4e5b8a5
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda May 29, 2024
98cf146
Merge branch 'main' into bitpacking
vayuda May 29, 2024
c8055b2
added import for has_triton
vayuda May 29, 2024
45b990c
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda May 29, 2024
74f03e2
Merge branch 'pytorch:main' into bitpacking
vayuda May 29, 2024
ae91c6c
added support for any device type
vayuda May 29, 2024
e1253b5
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda May 29, 2024
f6758e9
Merge branch 'pytorch:main' into bitpacking
vayuda May 29, 2024
68179e8
fix gpu tests
vayuda May 29, 2024
4ce1295
fix gpu tests
vayuda May 29, 2024
b01550f
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda May 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions torchao/prototype/common/bitpacking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import torch
from functools import reduce



def unpack(data, data_size, by_rows = True):
"""
Unpacks small dtype elements from a larger dtype.

Inputs:
data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.
data_size: int - the size of the small dtype in bits.

optional:
by_rows: bool - specifies whether to unpack...
by rows: tensor(n,m) -> tensor(n*scale, m)
or by columns: tensor(n,m) -> tensor(n,m*scale)

defaults to rows because quantization is typically done by rows
but choose the version which matches how you quantize as this improves memory accesses/performance

Returns: torch.Tensor - a tensor of the unpacked elements.
"""
if by_rows:
return _unpack_by_rows(data, data_size)
else:
return _unpack_by_cols(data, data_size)

def pack(data, container_size, data_size, by_rows = True):
"""
Packs small dtype elements into a larger dtype.
Pads rows to be divisible by the scale.

Inputs:
data: torch.Tensor - a tensor of unpacked elements of a small dtype.
container_size: int - the size of the large dtype in bits.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious. container_size can be determined from data.dtype right? e.g. uint8 -> 8, uint16 -> 16. (there is also this - https://pytorch.org/docs/stable/type_info.html#torch.torch.iinfo).
Also, is it assumed that data.dtype has container_size number of bits? What if data use larger or smaller bit-width than container_size? e.g. store int4 in int32, then request to pack to int8. Depending on what are your assumptions to the inputs, perhaps some kind of type checking and/or type casting is good.

data_size: int - the size of the small dtype in bits.

optional:
by_rows: bool - specifies whether to pack values...
by rows: tensor(n,m) -> tensor(n//scale, m)
or by columns: tensor(n,m) -> tensor(n,m//scale)

defaults to rows because quantization is typically done by rows
but choose the version which matches how you quantize as this improves memory accesses/performance

Returns: torch.Tensor - a tensor of packed elements.
"""
if by_rows:
return _pack_by_rows(data, container_size, data_size)
else:
return _pack_by_cols(data, container_size, data_size)

def _unpack_by_rows(data, data_size) -> torch.Tensor:
shape = data.shape
scale = data.element_size() * 8 // data_size

unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda()
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits
for i in range(scale):
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
unpacked_data[i::scale] = ((data >> shift_amt) & (nbits))
return unpacked_data

def _unpack_by_cols(data, data_size) -> torch.Tensor:
shape = data.shape
scale = data.element_size() * 8 // data_size
unpacked_data = []
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits
for i in range(scale):
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))
return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape

def _pack_by_rows(data, container_size, data_size) -> torch.Tensor:

scale = container_size // data_size
assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})"
assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})"
# pad the data to be divisible by scale
if data.shape[0] % scale != 0:
padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda()
data = torch.cat([data, padding], dim=0).cuda()

shape = data.shape
ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]).cuda()
return ret.view(shape[0] // scale, *shape[1:])

def _pack_by_cols(data, container_size, data_size) -> torch.Tensor:
scale = container_size // data_size
assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})"
# pad the data to be divisible by scale
if data.shape[-1] % scale != 0:
padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda()
data = torch.cat([data, padding], dim=-1).cuda()

shape = data.shape
data = data.contiguous().view(-1)
#shift the data to the different indexes within the larger dtype and then union them together
ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda()
return ret.view(*shape[:-1],shape[-1] // scale)

if __name__ == '__main__':
vayuda marked this conversation as resolved.
Show resolved Hide resolved
#debug
# import lovely_tensors
# lovely_tensors.monkey_patch()

torch._dynamo.config.specialize_int = True
pack = torch.compile(pack)
unpack = torch.compile(unpack)

test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda()
packed = pack(test_tensor, 8, 4)
unpacked = unpack(packed, 4)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))


test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
packed = pack(test_tensor,16, 3)
unpacked = unpack(packed, 3)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

test_tensor = torch.randint(0, 15, (3, 9), dtype=torch.int32).cuda()
packed = pack(test_tensor,32, 16)
unpacked = unpack(packed,16)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))


test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda()
packed = pack(test_tensor, 8, 2, False)
unpacked = unpack(packed,2, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
packed = pack(test_tensor, 8, 4, False)
unpacked = unpack(packed, 4, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))


test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda()
packed = pack(test_tensor,16, 3, False)
unpacked = unpack(packed, 3, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))