Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Bitpackingv2 #307

Merged
merged 16 commits into from
Jun 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
removed conversion
vayuda committed Jun 3, 2024
commit 55d0db80640d324d508634a18c2f96854651cdfa
10 changes: 7 additions & 3 deletions test/prototype/test_bitpacking.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,7 @@ def test_to_uint8_CPU():
for i in range(len(test_tensor.shape)):
packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8, device='cpu')
unpacked = unpack(packed, dtype, dimension = i, device='cpu')
assert unpacked.to(dtype).allclose(test_tensor), f"Failed for {dtype} on dim {i}"
assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_trinary_to_uint8():
@@ -29,7 +29,6 @@ def test_trinary_to_uint8():
packed = pack(test_tensor, "trinary", dimension = i, container_dtype = torch.uint8)
unpacked = unpack(packed, "trinary", dimension = i)
assert(unpacked.to(torch.int32).allclose(test_tensor))
print('trinary passed')

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_to_uint8():
@@ -38,4 +37,9 @@ def test_to_uint8():
for i in range(len(test_tensor.shape)):
packed = pack(test_tensor, dtype, dimension = i, container_dtype = torch.uint8)
unpacked = unpack(packed, dtype, dimension = i)
assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}"
assert unpacked.allclose(test_tensor), f"Failed for {dtype} on dim {i}"

test_trinary_to_uint8_CPU()
test_to_uint8_CPU()
test_trinary_to_uint8()
test_to_uint8()