Bitpacking #291

Merged: 32 commits, May 30, 2024

Changes from 1 commit

Commits (32)
37e4d2d  init additions (vayuda, May 23, 2024)
b37b529  extended pack/unpack (vayuda, May 25, 2024)
f834885  Merge branch 'pytorch:main' into uint4-improvements (vayuda, May 25, 2024)
33845fe  pack/unpack from n to m dtypes (vayuda, May 26, 2024)
3e7ca9b  works with torch.compile, but not optimized (vayuda, May 26, 2024)
88fe113  works on gpu (vayuda, May 26, 2024)
3e1813a  Merge branch 'pytorch:main' into uint4-improvements (vayuda, May 26, 2024)
80b9a41  added row-wise bitpack (vayuda, May 28, 2024)
036334c  Merge branch 'pytorch:main' into uint4-improvements (vayuda, May 28, 2024)
6a02cc1  Merge branch 'uint4-improvements' of https://github.com/JayakumarPawa… (vayuda, May 28, 2024)
61d3666  Merge branch 'pytorch:main' into uint4-improvements (vayuda, May 29, 2024)
47d9c92  restructured into prototype/ (vayuda, May 29, 2024)
5bdef89  Merge branch 'uint4-improvements' of https://github.com/vayuda/ao int… (vayuda, May 29, 2024)
8d1ea34  revert nuclear fix (vayuda, May 29, 2024)
6e1a7d6  removed temp log (vayuda, May 29, 2024)
46e39fd  removed trinary stuff from this branch (vayuda, May 29, 2024)
dcada3e  moved tests to tests/ (vayuda, May 29, 2024)
a5dd25d  updated tests to skip if cuda DNE (vayuda, May 29, 2024)
5ec3deb  added full_graph=True to compile (vayuda, May 29, 2024)
6e9a738  Apply suggestions from code review (msaroufim, May 29, 2024)
189677d  fixed test skipping (vayuda, May 29, 2024)
4e5b8a5  Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac… (vayuda, May 29, 2024)
98cf146  Merge branch 'main' into bitpacking (vayuda, May 29, 2024)
c8055b2  added import for has_triton (vayuda, May 29, 2024)
45b990c  Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac… (vayuda, May 29, 2024)
74f03e2  Merge branch 'pytorch:main' into bitpacking (vayuda, May 29, 2024)
ae91c6c  added support for any device type (vayuda, May 29, 2024)
e1253b5  Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac… (vayuda, May 29, 2024)
f6758e9  Merge branch 'pytorch:main' into bitpacking (vayuda, May 29, 2024)
68179e8  fix gpu tests (vayuda, May 29, 2024)
4ce1295  fix gpu tests (vayuda, May 29, 2024)
b01550f  Merge branch 'bitpacking' of https://github.com/vayuda/ao into bitpac… (vayuda, May 29, 2024)

Commit 88fe113e9727e2ceb069c647961342ebf2817c7a: works on gpu
vayuda committed May 26, 2024

test.py (17 changes: 8 additions & 9 deletions)
@@ -1,6 +1,5 @@
 import torch
 from functools import reduce
-import os
 
 @torch.compile
 def unpack(data, data_size) -> torch.Tensor:
@@ -38,13 +37,13 @@ def pack(data, container_size, data_size) -> torch.Tensor:
     assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})"
     # pad the data to be divisible by scale
     if data.shape[-1] % scale != 0:
-        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype)
-        data = torch.cat([data, padding], dim=-1)
+        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).cuda()
+        data = torch.cat([data, padding], dim=-1).cuda()
 
     shape = data.shape
     data = data.contiguous().view(-1)
     #shift the data to the different indexes within the larger dtype and then union them together
-    ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)])
+    ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)]).cuda()
     newshape = down_size(shape, scale)
     return ret.view(newshape)
 
@@ -58,26 +57,26 @@ def up_size(size, amt):
 
 
 torch._dynamo.config.specialize_int = True
-os.environ["TORCH_LOGS"] = "output_code"
-test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8)
+
+test_tensor = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8).cuda()
 packed = pack(test_tensor, 8, 4)
 unpacked = unpack(packed, 4)
 unpadded = unpacked[..., :test_tensor.shape[-1]]
 assert(unpadded.allclose(test_tensor))
 
-test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16)
+test_tensor = torch.randint(0, 7, (5,1, 4), dtype=torch.int16).cuda()
 packed = pack(test_tensor,16, 3)
 unpacked = unpack(packed, 3)
 unpadded = unpacked[..., :test_tensor.shape[-1]]
 assert(unpadded.allclose(test_tensor))
 
-test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32)
+test_tensor = torch.randint(0, 15, (3,1, 9), dtype=torch.int32).cuda()
 packed = pack(test_tensor,32, 16)
 unpacked = unpack(packed,16)
 unpadded = unpacked[..., :test_tensor.shape[-1]]
 assert(unpadded.allclose(test_tensor))
 
-test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8)
+test_tensor = torch.randint(0, 3, (8, 8, 7), dtype=torch.uint8).cuda()
 packed = pack(test_tensor, 8, 2)
 unpacked = unpack(packed,2)
 unpadded = unpacked[..., :test_tensor.shape[-1]]
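
For readers skimming the diff: pack folds several narrow values into one wider container element by shifting each strided slice into its own bit field and OR-ing the slices together, and unpack reverses this by masking each field back out. The sketch below is a minimal, self-contained reconstruction of that round trip, not the PR's exact code: the unpack body and the down_size/up_size helpers are collapsed in this view, so their definitions here are inferred; the @torch.compile decorators are dropped so it runs anywhere; and the hard-coded .cuda() calls are replaced with device=data.device, the device-agnostic direction the later "added support for any device type" commit takes.

from functools import reduce
import torch

def down_size(size, amt):
    # packing shrinks the last dimension by the packing factor
    return (*size[:-1], size[-1] // amt)

def up_size(size, amt):
    # unpacking grows the last dimension by the packing factor
    return (*size[:-1], size[-1] * amt)

def pack(data, container_size, data_size):
    # e.g. container_size=8, data_size=4 -> scale=2 nibbles per byte
    scale = container_size // data_size
    assert scale > 1, f"container_size ({container_size}) must hold at least two data_size ({data_size}) fields"
    # pad the last dimension so it divides evenly by scale, on data's own device
    if data.shape[-1] % scale != 0:
        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale),
                              dtype=data.dtype, device=data.device)
        data = torch.cat([data, padding], dim=-1)
    shape = data.shape
    data = data.contiguous().view(-1)
    # shift slice i into bit field i (MSB-first) and OR all slices together
    ret = reduce(lambda x, y: x | y,
                 [data[i::scale] << container_size - data_size * (i + 1) for i in range(scale)])
    return ret.view(down_size(shape, scale))

def unpack(data, data_size):
    # inferred: the container width comes from the packed tensor's dtype
    container_size = data.element_size() * 8
    scale = container_size // data_size
    shape = data.shape
    flat = data.contiguous().view(-1)
    mask = (1 << data_size) - 1  # e.g. 0b1111 for 4-bit fields
    # extract field i from every element, then interleave to restore the original order
    fields = [(flat >> container_size - data_size * (i + 1)) & mask for i in range(scale)]
    return torch.stack(fields, dim=-1).view(up_size(shape, scale))

# round trip on CPU; the same code runs on CUDA if the input lives there
t = torch.randint(0, 15, (1, 1, 6), dtype=torch.uint8)
assert torch.equal(unpack(pack(t, 8, 4), 4), t)

Fields are packed MSB-first, so the i-th value of each group lands in the higher bits; when data_size does not evenly divide container_size (the 16-bit/3-bit test above), the leftover low bits simply stay zero.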