-
Notifications
You must be signed in to change notification settings - Fork 185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bitpacking #291
Merged
Merged
Bitpacking #291
Changes from 16 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
37e4d2d
init additions
vayuda b37b529
extended pack/unpack
vayuda f834885
Merge branch 'pytorch:main' into uint4-improvements
vayuda 33845fe
pack/unpack from n to m dtypes
vayuda 3e7ca9b
works with torch.compile, but not optimized
vayuda 88fe113
works on gpu
vayuda 3e1813a
Merge branch 'pytorch:main' into uint4-improvements
vayuda 80b9a41
added row-wise bitpack
vayuda 036334c
Merge branch 'pytorch:main' into uint4-improvements
vayuda 6a02cc1
Merge branch 'uint4-improvements' of https://github.com/JayakumarPawa…
vayuda 61d3666
Merge branch 'pytorch:main' into uint4-improvements
vayuda 47d9c92
restructured into prototype/
vayuda 5bdef89
Merge branch 'uint4-improvements' of https://github.com/vayuda/ao int…
vayuda 8d1ea34
revert nuclear fix
vayuda 6e1a7d6
removed temp log
vayuda 46e39fd
removed trinary stuff from this branch
vayuda dcada3e
moved tests to tests/
vayuda a5dd25d
updated tests to skip if cuda DNE
vayuda 5ec3deb
added full_graph=True to compile
vayuda 6e9a738
Apply suggestions from code review
msaroufim 189677d
fixed test skipping
vayuda 4e5b8a5
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda 98cf146
Merge branch 'main' into bitpacking
vayuda c8055b2
added import for has_triton
vayuda 45b990c
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda 74f03e2
Merge branch 'pytorch:main' into bitpacking
vayuda ae91c6c
added support for any device type
vayuda e1253b5
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda f6758e9
Merge branch 'pytorch:main' into bitpacking
vayuda 68179e8
fix gpu tests
vayuda 4ce1295
fix gpu tests
vayuda b01550f
Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac…
vayuda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import torch | ||
from functools import reduce | ||
|
||
|
||
|
||
def unpack(data, data_size, by_rows = True): | ||
""" | ||
Unpacks small dtype elements from a larger dtype. | ||
|
||
Inputs: | ||
data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype. | ||
data_size: int - the size of the small dtype in bits. | ||
|
||
optional: | ||
by_rows: bool - specifies whether to unpack... | ||
by rows: tensor(n,m) -> tensor(n*scale, m) | ||
or by columns: tensor(n,m) -> tensor(n,m*scale) | ||
|
||
defaults to rows because quantization is typically done by rows | ||
but choose the version which matches how you quantize as this improves memory accesses/performance | ||
|
||
Returns: torch.Tensor - a tensor of the unpacked elements. | ||
""" | ||
if by_rows: | ||
return _unpack_by_rows(data, data_size) | ||
else: | ||
return _unpack_by_cols(data, data_size) | ||
|
||
def pack(data, container_size, data_size, by_rows = True): | ||
""" | ||
Packs small dtype elements into a larger dtype. | ||
Pads rows to be divisible by the scale. | ||
|
||
Inputs: | ||
data: torch.Tensor - a tensor of unpacked elements of a small dtype. | ||
container_size: int - the size of the large dtype in bits. | ||
data_size: int - the size of the small dtype in bits. | ||
|
||
optional: | ||
by_rows: bool - specifies whether to pack values... | ||
by rows: tensor(n,m) -> tensor(n//scale, m) | ||
or by columns: tensor(n,m) -> tensor(n,m//scale) | ||
|
||
defaults to rows because quantization is typically done by rows | ||
but choose the version which matches how you quantize as this improves memory accesses/performance | ||
|
||
Returns: torch.Tensor - a tensor of packed elements. | ||
""" | ||
if by_rows: | ||
return _pack_by_rows(data, container_size, data_size) | ||
else: | ||
return _pack_by_cols(data, container_size, data_size) | ||
|
||
def _unpack_by_rows(data, data_size) -> torch.Tensor: | ||
shape = data.shape | ||
scale = data.element_size() * 8 // data_size | ||
|
||
unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).cuda() | ||
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits | ||
for i in range(scale): | ||
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint | ||
unpacked_data[i::scale] = ((data >> shift_amt) & (nbits)) | ||
return unpacked_data | ||
|
||
def _unpack_by_cols(data, data_size) -> torch.Tensor: | ||
shape = data.shape | ||
scale = data.element_size() * 8 // data_size | ||
unpacked_data = [] | ||
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits | ||
for i in range(scale): | ||
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint | ||
unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype)) | ||
return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape | ||
|
||
def _pack_by_rows(data, container_size, data_size) -> torch.Tensor: | ||
|
||
scale = container_size // data_size | ||
assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})" | ||
assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})" | ||
# pad the data to be divisible by scale | ||
if data.shape[0] % scale != 0: | ||
padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).cuda() | ||
data = torch.cat([data, padding], dim=0).cuda() | ||
|
||
shape = data.shape | ||
ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)]).cuda() | ||
return ret.view(shape[0] // scale, *shape[1:]) | ||
|
||
def _pack_by_cols(data, container_size, data_size) -> torch.Tensor: | ||
scale = container_size // data_size | ||
assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})" | ||
# pad the data to be divisible by scale | ||
if data.shape[-1] % scale != 0: | ||
padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda() | ||
data = torch.cat([data, padding], dim=-1).cuda() | ||
|
||
shape = data.shape | ||
data = data.contiguous().view(-1) | ||
#shift the data to the different indexes within the larger dtype and then union them together | ||
ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda() | ||
return ret.view(*shape[:-1],shape[-1] // scale) | ||
|
||
if __name__ == '__main__': | ||
vayuda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#debug | ||
# import lovely_tensors | ||
# lovely_tensors.monkey_patch() | ||
|
||
torch._dynamo.config.specialize_int = True | ||
pack = torch.compile(pack) | ||
unpack = torch.compile(unpack) | ||
|
||
test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda() | ||
packed = pack(test_tensor, 8, 4) | ||
unpacked = unpack(packed, 4) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
|
||
test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda() | ||
packed = pack(test_tensor,16, 3) | ||
unpacked = unpack(packed, 3) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
test_tensor = torch.randint(0, 15, (3, 9), dtype=torch.int32).cuda() | ||
packed = pack(test_tensor,32, 16) | ||
unpacked = unpack(packed,16) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
|
||
test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda() | ||
packed = pack(test_tensor, 8, 2, False) | ||
unpacked = unpack(packed,2, False) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda() | ||
packed = pack(test_tensor, 8, 4, False) | ||
unpacked = unpack(packed, 4, False) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
|
||
test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda() | ||
packed = pack(test_tensor,16, 3, False) | ||
unpacked = unpack(packed, 3, False) | ||
unpadded = unpacked[:test_tensor.shape[0], ...] | ||
assert(unpadded.allclose(test_tensor)) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious.
container_size
can be determined fromdata.dtype
right? e.g. uint8 -> 8, uint16 -> 16. (there is also this - https://pytorch.org/docs/stable/type_info.html#torch.torch.iinfo).Also, is it assumed that
data.dtype
hascontainer_size
number of bits? What ifdata
use larger or smaller bit-width thancontainer_size
? e.g. store int4 in int32, then request to pack to int8. Depending on what are your assumptions to the inputs, perhaps some kind of type checking and/or type casting is good.