Bitpackingv2 #307

Merged Jun 10, 2024 (16 commits)
Changes from 8 commits
122 changes: 68 additions & 54 deletions test/prototype/test_bitpacking.py
@@ -1,70 +1,84 @@
import torch
from torchao.prototype.common.bitpacking import pack, unpack
from torchao.prototype.common.bitpacking import pack, unpack, dtype_to_bits
import pytest
from torch.utils._triton import has_triton
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

def test_uint4_to_uint8_CPU():
test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
packed = pack(test_tensor, 8, 4, device='cpu')
unpacked = unpack(packed, 4, device='cpu')
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))
dtypes = (torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, "trinary")
expected_pack_size = {torch.uint2: 1, torch.uint3: 2, torch.uint4: 2, torch.uint5: 4, torch.uint6: 4, torch.uint7: 4, "trinary": 1}
dimensions = (0, 1, 2)
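For reference, the expected sizes in the table above follow from ceiling division by scale = container_bits // element_bits, the number of elements per container. A quick sanity check (my own sketch, not part of the diff):

import math

for bits, expected in [(2, 1), (3, 2), (4, 2), (5, 4), (6, 4), (7, 4)]:
    scale = 8 // bits                        # elements per uint8 container
    assert math.ceil(4 / scale) == expected  # these tests use 4 elements along each dimension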

def test_uint3_to_int16_col_wise_cpu():
test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16)
packed = pack(test_tensor,16, 3, False, device='cpu')
unpacked = unpack(packed, 3, False, device='cpu')
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("dim", dimensions)
def test_CPU(dtype, dim):
shape = [4, 4, 4]
if dtype == "trinary":
test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8, device='cpu') # upper bound is exclusive, so 2 includes the +1 value
else:
test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8, device='cpu')

packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8, device='cpu')
assert(packed.shape[dim] == expected_pack_size[dtype])
unpacked = unpack(packed, dtype, dimension = dim, device='cpu')
assert(unpacked.allclose(test_tensor))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint4_to_uint8():
test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
packed = pack(test_tensor, 8, 4)
unpacked = unpack(packed, 4)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("dim", dimensions)
def test_GPU(dtype, dim):
shape = [4, 4, 4]
if dtype == "trinary":
test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda() # upper bound is exclusive, so 2 includes the +1 value
else:
test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda()

packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size[dtype])
unpacked = unpack(packed, dtype, dimension = dim)
assert(unpacked.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint4_to_uint8_compile():
torch._dynamo.config.specialize_int = True
pack_compiled = torch.compile(pack, fullgraph=True)
unpack_compiled = torch.compile(unpack, fullgraph=True)
test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda()
packed = pack_compiled(test_tensor, 8, 4)
unpacked = unpack_compiled(packed, 4)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("dim", dimensions)
def test_compile(dtype, dim):
pack_compile = torch.compile(pack, fullgraph=True)
unpack_compile = torch.compile(unpack, fullgraph=True)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16():
test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
packed = pack(test_tensor,16, 3)
unpacked = unpack(packed, 3)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))
shape = [4, 4, 4]
if dtype == "trinary":
test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda()
else:
test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda()

packed = pack_compile(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size[dtype])
unpacked = unpack_compile(packed, dtype, dimension = dim)
assert(unpacked.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint2_to_uint8_col_wise_compile():
torch._dynamo.config.specialize_int = True
pack_compiled = torch.compile(pack, fullgraph=True)
unpack_compiled = torch.compile(unpack, fullgraph=True)
test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda()
packed = pack_compiled(test_tensor, 8, 2, False)
unpacked = unpack_compiled(packed,2, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16_col_wise():
test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda()
packed = pack(test_tensor,16, 3, False)
unpacked = unpack(packed, 3, False)
unpadded = unpacked[:test_tensor.shape[0], ...]
assert(unpadded.allclose(test_tensor))
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("dim", dimensions)
def test_padding(dtype, dim):
pack_compile = torch.compile(pack, fullgraph=True)
unpack_compile = torch.compile(unpack, fullgraph=True)

shape =[4, 4, 4]
shape[dim] = 5

if dtype == "trinary":
test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda()
else:
test_tensor = torch.randint(0, 2**dtype_to_bits(dtype), shape, dtype=torch.uint8).cuda()

packed = pack(test_tensor, dtype, dimension = dim, container_dtype = torch.uint8)
assert(packed.shape[dim] == expected_pack_size[dtype]+1) # +1 for this scenario
unpacked = unpack(packed, dtype, dimension = dim)
slices = [slice(None)] * packed.ndim
slices[dim] = slice(None, 5)
assert(unpacked[slices].allclose(test_tensor))
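The +1 in the assert above is the extra container produced by padding. A small worked check for the uint4 case (my own arithmetic, not from the PR):

import math

dim_size, element_bits = 5, 4             # test_padding sets the packed dimension to 5
scale = 8 // element_bits                 # 2 uint4 elements per uint8 container
assert math.ceil(dim_size / scale) == 3   # expected_pack_size[torch.uint4] + 1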
180 changes: 109 additions & 71 deletions torchao/prototype/common/bitpacking.py
@@ -1,101 +1,139 @@
import torch
from functools import reduce
from typing import Optional, Union

def mod_shape(shape, mod, dim):
"""changes a select dimension of the input shape to mod"""
return (*shape[:dim], mod, *shape[dim+1:])
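A one-line illustration of what mod_shape does (my own example, not from the PR):

assert mod_shape((4, 4, 4), 2, 1) == (4, 2, 4)  # dimension 1 replaced by 2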


def unpack(data, data_size, by_rows = True, device="cuda"):
def dtype_to_bits(dtype):
'''returns the number of bits in a dtype'''
if dtype in {torch.uint2, 'trinary'}:
return 2
elif dtype == torch.uint3:
return 3
elif dtype == torch.uint4:
return 4
elif dtype == torch.uint5:
return 5
elif dtype == torch.uint6:
return 6
elif dtype == torch.uint7:
return 7
elif dtype in {torch.uint8, torch.int8}:
return 8
elif dtype in {torch.uint16, torch.int16, torch.float16}:
return 16
elif dtype in {torch.uint32, torch.int32, torch.float32}:
return 32
elif dtype in {torch.uint64, torch.int64, torch.float64}:
return 64
else:
raise ValueError(f"dtype {dtype} not supported (yet)")

def unpack(data: torch.Tensor,
element_dtype: Union[torch.dtype, str], # accepting strings for trinary until that's added to torch
dimension: int = 0,
device: str = "cuda") -> torch.Tensor:
"""
Unpacks small dtype elements from a larger dtype.

Inputs:
data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.
data_size: int - the size of the small dtype in bits.
data: a tensor of packed elements
element_dtype: the dtype of the elements to unpack

optional:
by_rows: bool - specifies whether to unpack...
by rows: tensor(n,m) -> tensor(n*scale, m)
or by columns: tensor(n,m) -> tensor(n,m*scale)

defaults to rows because quantization is typically done by rows
but choose the version which matches how you quantize as this improves memory accesses/performance
dimension: the dimension to unpack along


Returns: torch.Tensor - a tensor of the unpacked elements.
"""
if by_rows:
return _unpack_by_rows(data, data_size, device)
else:
return _unpack_by_cols(data, data_size)
container_size = dtype_to_bits(data.dtype)
element_size = dtype_to_bits(element_dtype)
scale = container_size // element_size

def pack(data, container_size, data_size, by_rows = True, device="cuda"):
unpacked = _unpack(data, element_size, container_size, scale, dimension, device)

if element_dtype == "trinary":
unpacked = unpacked.to(torch.int8) - 1
return unpacked
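The trinary branch works because {-1, 0, 1} is stored offset by one as the 2-bit codes {0, 1, 2}; unpack undoes the offset. A minimal check of that inverse mapping (my own sketch):

import torch

stored = torch.tensor([0, 1, 2], dtype=torch.uint8)        # codes as they live in the container
assert (stored.to(torch.int8) - 1).tolist() == [-1, 0, 1]  # back to trinary values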

def _unpack(data, element_size, container_size, scale, dim, device):
shape = data.shape
unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device)
nbits = (1 << element_size) - 1 # mask that keeps the low element_size bits
for i in range(scale):
shift_amt = container_size - element_size * (i + 1)
slices = [slice(None)] * unpacked_data.ndim
slices[dim] = slice(i, None, scale)
unpacked_data[slices] = ((data >> shift_amt) & (nbits)).to(data.dtype)

# unpacked_data is already allocated in the target shape; the view keeps that explicit
return unpacked_data.view(mod_shape(shape, scale*shape[dim], dim))
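As a scalar sketch of what the shift-and-mask loop recovers (my own example, assuming two uint4 elements packed high-bits-first into one uint8):

container, element_bits, container_bits = 0xAB, 4, 8
mask = (1 << element_bits) - 1                                      # 0b1111
first = (container >> (container_bits - element_bits * 1)) & mask   # 0xA
second = (container >> (container_bits - element_bits * 2)) & mask  # 0xB
assert (first, second) == (0xA, 0xB)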


def pack(data: torch.Tensor,
element_dtype: Union[torch.dtype, str], # accepting strings for trinary until that's added to torch
dimension: int = 0,
container_dtype: Optional[torch.dtype] = None,
device: str = "cuda") -> torch.Tensor:
"""
Packs small dtype elements into a larger dtype.
Pads rows to be divisible by the scale.
Packs small dtype elements into a container of a larger dtype.
**Pads the packing dimension so its length is divisible by the scale**
TODO: support dense packing, e.g. 8 uint3s into 3 uint8s

Inputs:
data: torch.Tensor - a tensor of unpacked elements of a small dtype.
container_size: int - the size of the large dtype in bits.
data_size: int - the size of the small dtype in bits.
data: a tensor of unpacked elements of a small dtype. Unless container_dtype is given, the tensor's own dtype is used as the container.
element_dtype: the dtype of the elements to pack
dimension: the dimension to pack along

optional:
by_rows: bool - specifies whether to pack values...
by rows: tensor(n,m) -> tensor(n//scale, m)
or by columns: tensor(n,m) -> tensor(n,m//scale)
container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype


defaults to rows because quantization is typically done by rows
but choose the version which matches how you quantize as this improves memory accesses/performance

Returns: torch.Tensor - a tensor of packed elements.
"""
if by_rows:
return _pack_by_rows(data, container_size, data_size, device)
else:
return _pack_by_cols(data, container_size, data_size, device)

def _unpack_by_rows(data, data_size, device) -> torch.Tensor:
shape = data.shape
scale = data.element_size() * 8 // data_size
if element_dtype == "trinary":
data = data + 1

if container_dtype is not None:
data = data.to(container_dtype)

unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).to(device)
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits
for i in range(scale):
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
unpacked_data[i::scale] = ((data >> shift_amt) & (nbits))
return unpacked_data
container_size = dtype_to_bits(data.dtype)
element_size = dtype_to_bits(element_dtype)
scale = container_size // element_size

assert data.shape[dimension] >= scale, f"not enough values to pack along dimension {dimension}: size {data.shape[dimension]} < scale {scale}"
return _pack(data, container_size, element_size, scale, dimension, device)
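A round-trip usage sketch of the new API (my own example, not from the PR; it assumes a PyTorch build recent enough to expose torch.uint4 and passes device='cpu' to override the CUDA default):

import torch
from torchao.prototype.common.bitpacking import pack, unpack

data = torch.randint(0, 16, (8, 2), dtype=torch.uint8)  # values fit in 4 bits
packed = pack(data, torch.uint4, dimension=0, container_dtype=torch.uint8, device='cpu')
assert packed.shape == (4, 2)                            # scale = 8 // 4 = 2, no padding needed
restored = unpack(packed, torch.uint4, dimension=0, device='cpu')
assert restored.equal(data)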

def _unpack_by_cols(data, data_size) -> torch.Tensor:
shape = data.shape
scale = data.element_size() * 8 // data_size
unpacked_data = []
nbits = (1 << data_size) - 1 # mask for the last dtype_size bits


def _pack(data, container_size, element_size, scale, dim, device) -> torch.Tensor:
# pad the packing dimension to be divisible by scale
if data.shape[dim] % scale != 0:
padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device)
data = torch.cat([data, padding], dim=dim).to(device)

packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device)
for i in range(scale):
shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))
return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape
slices = [slice(None)] * packed.ndim
slices[dim] = slice(i, None, scale)
packed |= data[slices] << container_size - element_size * (i + 1)
return packed
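The packing loop is the mirror image of _unpack: each element is shifted into its slot and OR-ed into the container. A scalar sketch (my own illustration):

values, element_bits, container_bits = [0xA, 0xB], 4, 8
packed = 0
for i, v in enumerate(values):
    packed |= v << (container_bits - element_bits * (i + 1))
assert packed == 0xAB  # the first element lands in the high bits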

def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor:

scale = container_size // data_size
assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})"
assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})"
# pad the data to be divisible by scale
if data.shape[0] % scale != 0:
padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).to(device)
data = torch.cat([data, padding], dim=0).cuda()

shape = data.shape
ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)])
return ret.view(shape[0] // scale, *shape[1:]).to(device)

def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor:
scale = container_size // data_size
assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})"
# pad the data to be divisible by scale
if data.shape[-1] % scale != 0:
padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device)
data = torch.cat([data, padding], dim=-1).cuda()

shape = data.shape
data = data.contiguous().view(-1)
#shift the data to the different indexes within the larger dtype and then union them together
ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)])
return ret.view(*shape[:-1],shape[-1] // scale).to(device)