diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py
new file mode 100644
index 0000000000..e974efca58
--- /dev/null
+++ b/benchmarks/benchmark_bitpacking.py
@@ -0,0 +1,227 @@
+from math import log
+import torch
+
+from torchao.prototype.common.bitpacking import pack, unpack
+from torchao.dtypes.uint4 import unpack_uint4, pack_uint4
+
+
+def benchmark(function, num_runs, *args, **kwargs):
+    # any extra positional/keyword arguments are forwarded to `function` on every run
+    torch.cuda.synchronize()
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+
+    for _ in range(num_runs):
+        function(*args, **kwargs)
+
+    end_event.record()
+    torch.cuda.synchronize()
+    return start_event.elapsed_time(end_event) / num_runs
+
+
+def test_vs_existing():
+    def new_():
+        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
+        packed = pack(fake_tensor, 4, dim=1)
+        unpacked = unpack(packed, 4, dim=1)
+    def old_():
+        fake_tensor = torch.randint(0, 2**8, (1, 1024, 1024), dtype=torch.uint8).cuda()
+        packed = pack_uint4(fake_tensor)
+        unpacked = unpack_uint4(packed)
+    new_ = torch.compile(new_, fullgraph=True)
+    old_ = torch.compile(old_, fullgraph=True)
+    new_()
+    old_()
+    print(f"new: {benchmark(new_, 1000)} ms")
+    print(f"old: {benchmark(old_, 1000)} ms")
+
+
+def test_iso_bitpack():
+    def load4x(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (1, 4*scale, scale), dtype=torch.uint8).cuda()
+
+    def load2x(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (1, 2*scale, scale), dtype=torch.uint8).cuda()
+
+    def loadx(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
+
+    def unpack8to2(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
+        unpacked_tensor = unpack_c(fake_tensor, 2, dim=1)
+
+    def unpack8to4(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (1, scale, scale), dtype=torch.uint8).cuda()
+        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
+
+    def t8to4wmm(scale=1024):
+        fake_tensor = torch.randint(0, 2**8, (8, 1024, 1024), dtype=torch.uint8).cuda()
+        unpacked_tensor = unpack_c(fake_tensor, 4, dim=1)
+
+    torch._dynamo.config.specialize_int = True
+    # _unpack_c = torch.compile(_unpack, fullgraph=True)
+    unpack_c = torch.compile(unpack, fullgraph=True)
+
+    scale = [16, 64, 256, 1024, 4096]
+    load4x_times = []
+    unpack8to2_times = []
+    load2x_times = []
+    unpack8to4_times = []
+    for s in scale:
+        res = benchmark(load4x, 50, scale=s)
+        load4x_times.append(res)
+        print(f"load(1, {4*s},{s}) time: {res} ms")
+
+        res = benchmark(unpack8to2, 50, scale=s)
+        unpack8to2_times.append(res)
+        print(f"load(1, {s},{s}) unpack uint2 time: {res} ms")
+
+        res = benchmark(load2x, 50, scale=s)
+        load2x_times.append(res)
+        print(f"load(1, {2*s},{s}) time: {res} ms")
+
+        res = benchmark(unpack8to4, 50, scale=s)
+        unpack8to4_times.append(res)
+        print(f"load(1, {s},{s}) unpack uint4 time: {res} ms")
+        print()
+
+    # import matplotlib.pyplot as plt
+    # plt.plot(scale, load4x_times, label="load(1, 4x, x)")
+    # plt.plot(scale, unpack8to2_times, label="unpack uint8 to uint2")
+    # plt.plot(scale, load2x_times, label="load(1, 2x, x)")
+    # plt.plot(scale, unpack8to4_times, label="unpack uint8 to uint4")
+    # plt.xlabel("scale")
+    # plt.ylabel("time (ms)")
+    # plt.yscale("log")
+    # plt.legend()
+    # plt.savefig("benchmark_bitpacking.png")
+
+
+def test_vs_hqqpack():
+    # requires hqq to be installed
+    import hqq
+    import hqq.core.quantize as hqq_quantize
+    HQQLinear = hqq_quantize.HQQLinear
+    BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig
+    from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm
+
+    BASE_QUANT_CONFIG = {
+        "optimize": True,
+        "view_as_float": False,
+        "nbits": 4,
+        "bitpack": False,
+        "axis": 1,
+    }
+
+    def mixed_mm(
+        shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8, pack_fn=True
+    ):
+        qcfg = {
+            **BASE_QUANT_CONFIG,
+            **dict(group_size=group_size, axis=axis),
+        }
+        M, N, K = shape
+
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")
+
+        quant_config = BaseQuantizeConfig(
+            quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
+        )
+        quant_config.update({"weight_quant_params": qcfg})
+        hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
+        W_q, meta = hqq_linear.W_q, hqq_linear.meta
+        W_q = W_q.to(dtype=quant_dtype)
+        W_q = (
+            W_q.reshape(meta["shape"])
+            if quant_config["weight_quant_params"]["bitpack"] == False
+            else W_q
+        )
+        W_dq = hqq_linear.dequantize()
+
+        scales, zeros = meta["scale"], meta["zero"]
+        scales = scales.reshape(N, -1)
+        zeros = zeros.reshape(N, -1)
+        if pack_fn:
+            packed_w = pack_compiled(W_q.T, 4, dim=0, order=False)
+        else:
+            packed_w = pack_2xint4(W_q.T)
+
+        if transposed:
+            x = torch.randn(M, N, dtype=dtype, device="cuda")
+            hqq_out = x @ W_dq
+
+            tt_out = triton_mixed_mm(
+                x,
+                packed_w,
+                scales.T,
+                zeros.T,
+                transposed=True,
+                group_size=group_size,
+                fp8_fast_accum=False,
+                kernel_type=kernel_type,
+            )
+
+        else:
+            x = torch.randn(M, K, dtype=dtype, device="cuda")
+            hqq_out = x @ W_dq.T
+
+            tt_out = triton_mixed_mm(
+                x,
+                packed_w,
+                scales.T,
+                zeros.T,
+                transposed=False,
+                group_size=group_size,
+                fp8_fast_accum=False,
+                kernel_type=kernel_type,
+            )
+
+    shapes = [
+        [16, 128, 128],
+        [16, 4096, 4096],
+    ]
+    group_sizes = [64, 128]
+    # compile pack under a distinct name so the closure in mixed_mm uses the compiled version
+    pack_compiled = torch.compile(pack, fullgraph=True)
+    for i in range(2):
+        shape = shapes[i]
+        group_size = group_sizes[i]
+        print("linear layer size: ", shape)
+        print("group size: ", group_size)
+        # run once to compile
+        mixed_mm(
+            shape,
+            group_size,
+            1,
+            torch.float16,
+            True,
+            "compute_bound",
+            torch.uint8,
+        )
+        # shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8
+        print("pack time (ms): ", benchmark(mixed_mm, 100,
+                                            shape,
+                                            group_size,
+                                            1,
+                                            torch.float16,
+                                            True,
+                                            "compute_bound",
+                                            torch.uint8))
+
+        print("pack_2xint4 time (ms): ", benchmark(mixed_mm, 100,
+                                                   shape,
+                                                   group_size,
+                                                   1,
+                                                   torch.float16,
+                                                   True,
+                                                   "compute_bound",  # max autotune doesn't work?
+                                                   torch.uint8,
+                                                   pack_fn=False))
+        print("")
+
+
+if __name__ == "__main__":
+    test_vs_existing()
+
diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py
index d1c1d261d1..7facc97dc9 100644
--- a/test/prototype/test_bitpacking.py
+++ b/test/prototype/test_bitpacking.py
@@ -7,64 +7,139 @@
 if not TORCH_VERSION_AFTER_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
-def test_uint4_to_uint8_CPU():
-    test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
-    packed = pack(test_tensor, 8, 4, device='cpu')
-    unpacked = unpack(packed, 4, device='cpu')
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4))
+dimensions = (2, 1, 0)
+orders = (True, False)
 
-def test_uint3_to_int16_col_wise_cpu():
-    test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16)
-    packed = pack(test_tensor,16, 3, False, device='cpu')
-    unpacked = unpack(packed, 3, False, device='cpu')
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_uint4_to_uint8():
-    test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
-    packed = pack(test_tensor, 8, 4)
-    unpacked = unpack(packed, 4)
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+@pytest.fixture(autouse=True)
+def run_before_and_after_tests():
+    # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-def test_uint4_to_uint8_compile():
-    torch._dynamo.config.specialize_int = True
-    pack_compiled = torch.compile(pack, fullgraph=True)
-    unpack_compiled = torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda()
-    packed = pack_compiled(test_tensor, 8, 4)
-    unpacked = unpack_compiled(packed, 4)
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+    # setup (currently do nothing)
+
+    # tests will run here
+    yield
+
+    # teardown
+    # avoid dynamo cache limit issues
+    torch._dynamo.reset()
 
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("dim", dimensions)
+@pytest.mark.parametrize("order", orders)
+def test_CPU(dtype, dim, order):
+    element_bit_width, element_type, expected_pack_size = dtype
+    shape = [4, 4, 4]
+    if element_type == "trinary":
+        # trinary covers the three values -1, 0, 1 (randint's upper bound is exclusive)
+        test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8, device='cpu')
+    else:
+        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu')
+
+    packed = pack(test_tensor,
+                  element_bit_width,
+                  element_type=element_type,
+                  dim=dim,
+                  order=order,
+                  container_dtype=torch.uint8,
+                  device='cpu')
+    assert(packed.shape[dim] == expected_pack_size)
+    unpacked = unpack(packed,
+                      element_bit_width,
+                      element_type=element_type,
+                      dim=dim,
+                      order=order,
+                      device='cpu')
+    assert(unpacked.allclose(test_tensor))
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_uint3_to_int16():
-    test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
-    packed = pack(test_tensor,16, 3)
-    unpacked = unpack(packed, 3)
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("dim", dimensions)
+@pytest.mark.parametrize("order", orders)
+def test_GPU(dtype, dim, order):
+    element_bit_width, element_type, expected_pack_size = dtype
+    shape = [4, 4, 4]
+    if element_type == "trinary":
+        # trinary covers the three values -1, 0, 1 (randint's upper bound is exclusive)
+        test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda()
+    else:
+        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
+
+    packed = pack(test_tensor,
+                  element_bit_width,
+                  element_type=element_type,
+                  dim=dim,
+                  order=order,
+                  container_dtype=torch.uint8)
+    assert(packed.shape[dim] == expected_pack_size)
+    unpacked = unpack(packed,
+                      element_bit_width,
+                      element_type=element_type,
+                      order=order,
+                      dim=dim)
+    assert(unpacked.allclose(test_tensor))
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-def test_uint2_to_uint8_col_wise_compile():
-    torch._dynamo.config.specialize_int = True
-    pack_compiled = torch.compile(pack, fullgraph=True)
-    unpack_compiled = torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda()
-    packed = pack_compiled(test_tensor, 8, 2, False)
-    unpacked = unpack_compiled(packed, 2, False)
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("dim", dimensions)
+@pytest.mark.parametrize("order", orders)
+def test_padding(dtype, dim, order):
+    element_bit_width, element_type, expected_pack_size = dtype
+    torch._dynamo.config.specialize_int = True
+    shape = [4, 4, 4]
+    shape[dim] = 5
+
+    if element_type == "trinary":
+        # trinary covers the three values -1, 0, 1 (randint's upper bound is exclusive)
+        test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda()
+    else:
+        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda()
+
+    packed = pack(test_tensor,
+                  element_bit_width,
+                  element_type=element_type,
+                  dim=dim,
+                  container_dtype=torch.uint8,
+                  order=order,
+                  pad=True)
+    assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}"  # +1 because dim was padded from 5 up to the next multiple of the scale
+    unpacked = unpack(packed,
+                      element_bit_width,
+                      element_type=element_type,
+                      dim=dim,
+                      order=order)
+    slices = [slice(None)] * packed.ndim
+    slices[dim] = slice(None, 5)
+    assert unpacked[slices].allclose(test_tensor)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_uint3_to_int16_col_wise():
-    test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda()
-    packed = pack(test_tensor,16, 3, False)
-    unpacked = unpack(packed, 3, False)
-    unpadded = unpacked[:test_tensor.shape[0], ...]
-    assert(unpadded.allclose(test_tensor))
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("dim", dimensions)
+@pytest.mark.parametrize("order", orders)
+def test_compile(dtype, dim, order):
+    pack_compile = torch.compile(pack, fullgraph=True, dynamic=True)
+    unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True)
+    element_bit_width, element_type, expected_pack_size = dtype
+    torch._dynamo.config.specialize_int = True
+    shape = [4, 4, 4]
+    if element_type == "trinary":
+        # trinary covers the three values -1, 0, 1 (randint's upper bound is exclusive)
+        test_tensor = torch.randint(-1, 2, shape, dtype=torch.int8).cuda()
+    else:
+        test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda()
+
+    packed = pack_compile(test_tensor, element_bit_width,
+                          element_type=element_type,
+                          dim=dim,
+                          container_dtype=torch.int8,
+                          order=order)
+    assert(packed.shape[dim] == expected_pack_size)
+    unpacked = unpack_compile(packed,
+                              element_bit_width,
+                              element_type=element_type,
+                              dim=dim,
+                              order=order)
+    assert(unpacked.allclose(test_tensor))
diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py
index 35e471c347..60009d0e63 100644
--- a/torchao/prototype/common/bitpacking.py
+++ b/torchao/prototype/common/bitpacking.py
@@ -1,101 +1,120 @@
 import torch
-from functools import reduce
+from typing import Optional, Union
 
-
-
-def unpack(data, data_size, by_rows = True, device="cuda"):
+def mod_shape(shape, mod, dim):
+    """changes a select dimension of the input shape to mod"""
+    return (*shape[:dim], mod, *shape[dim+1:])
+
+def unpack(data: torch.Tensor,
+           element_bit_width: int,
+           element_type: Optional[str] = None,
+           dim: Optional[int] = 0,
+           order: Optional[bool] = True,
+           output_dtype: Optional[torch.dtype] = None,
+           device: Optional[str] = "cuda") -> torch.Tensor:
     """
     Unpacks small dtype elements from a larger dtype.
     
     Inputs:
-    data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.
-    data_size: int - the size of the small dtype in bits.
-    
-    optional:
-    by_rows: bool - specifies whether to unpack...
-    by rows: tensor(n,m) -> tensor(n*scale, m)
-    or by columns: tensor(n,m) -> tensor(n,m*scale)
-    
-    defaults to rows because quantization is typically done by rows
-    but choose the version which matches how you quantize as this improves memory accesses/performance
+    data: a tensor of packed elements
+    element_bit_width: the size in bits of the elements to unpack
+    element_type: the type of the elements to unpack (uint, trinary, float, etc.)
+    dim: the dimension to unpack along
+    output_dtype: dtype of the output tensor, if it should differ from the input tensor's dtype
+    order: must match the value that was passed to the pack function
 
     Returns: torch.Tensor - a tensor of the unpacked elements.
""" - if by_rows: - return _unpack_by_rows(data, data_size, device) - else: - return _unpack_by_cols(data, data_size) + container_size = torch.iinfo(data.dtype).bits + scale = container_size // element_bit_width + + unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device) + if element_type == "trinary": + unpacked = unpacked.to(torch.int8) - 1 + elif output_dtype is not None: + unpacked = unpacked.to(output_dtype) + + return unpacked + +def _unpack(data, element_size, container_size, scale, order, dim, device): + shape = data.shape + unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device) + nbits = (1 << element_size) - 1 # mask for the last dtype_size bits + for i in range(scale): + if order: + shift_amt = container_size - element_size * (i + 1) + else: + shift_amt = element_size * i + slices = [slice(None)] * unpacked_data.ndim + slices[dim] = slice(i, None, scale) + unpacked_data[slices] = ((data >> shift_amt) & (nbits)).to(data.dtype) + + # stack the unpacked data and reshape to the original shape + return unpacked_data.view(mod_shape(shape,scale*shape[dim], dim)) -def pack(data, container_size, data_size, by_rows = True, device="cuda"): + +def pack(data: torch.Tensor, + element_bit_width: int, + element_type: Optional[str] = None, + dim: Optional[int] = 0, + container_dtype: Optional[torch.dtype] = None, + pad: Optional[bool] = False, + order: Optional[bool] = True, + device: Optional[str] = "cuda") -> torch.Tensor: """ - Packs small dtype elements into a larger dtype. - Pads rows to be divisible by the scale. + Packs small dtype elements into a container of a larger dtype. Inputs: - data: torch.Tensor - a tensor of unpacked elements of a small dtype. - container_size: int - the size of the large dtype in bits. - data_size: int - the size of the small dtype in bits. + data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. + dim: the dimension to pack along + element_dtype: the dtype of the elements to pack + container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype + pad: if set to true, pads the dimension to be divisible by the scale + order: if set to true, packs elements such that the lower index elements occupy the most significant bits - optional: - by_rows: bool - specifies whether to pack values... - by rows: tensor(n,m) -> tensor(n//scale, m) - or by columns: tensor(n,m) -> tensor(n,m//scale) + Returns: torch.Tensor - a tensor of packed elements. - defaults to rows because quantization is typically done by rows - but choose the version which matches how you quantize as this improves memory accesses/performance - Returns: torch.Tensor - a tensor of packed elements. + For example, packing 4-bit elements into 8-bit containers. 
+    along dimension 0:          along dimension 1:
+
+    (0, 9, B, 4)                (0, 9, B, 4)  -->  (09, B4)
+    (3, 8, F, C)                (3, 8, F, C)  -->  (38, FC)
+     |  |  |  |
+     v  v  v  v
+    (03, 98, BF, 4C)
+
+    if order was set to False (along dimension 0):
+    (30, 89, FB, C4)
     """
-    if by_rows:
-        return _pack_by_rows(data, container_size, data_size, device)
-    else:
-        return _pack_by_cols(data, container_size, data_size, device)
-
-def _unpack_by_rows(data, data_size, device) -> torch.Tensor:
-    shape = data.shape
-    scale = data.element_size() * 8 // data_size
+    if element_type == "trinary":
+        data = data + 1
+
+    if container_dtype is not None:
+        data = data.to(container_dtype)
 
-    unpacked_data = torch.zeros((shape[0]*scale, *shape[1:]), dtype=data.dtype).to(device)
-    nbits = (1 << data_size) - 1 # mask for the last dtype_size bits
-    for i in range(scale):
-        shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
-        unpacked_data[i::scale] = ((data >> shift_amt) & (nbits))
-    return unpacked_data
-
-def _unpack_by_cols(data, data_size) -> torch.Tensor:
-    shape = data.shape
-    scale = data.element_size() * 8 // data_size
-    unpacked_data = []
-    nbits = (1 << data_size) - 1 # mask for the last dtype_size bits
-    for i in range(scale):
-        shift_amt = data.element_size() * 8 - data_size * (i + 1) # how much to shift to get the ith uint
-        unpacked_data.append(((data >> shift_amt) & (nbits)).to(data.dtype))
-    return torch.stack(unpacked_data,dim=-1).view(*shape[:-1],shape[-1]*scale) # stack the unpacked data and reshape to the original shape
-
-def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor:
+    container_size = torch.iinfo(data.dtype).bits
+    scale = container_size // element_bit_width
 
-    scale = container_size // data_size
-    assert scale > 1, f"container_size ({container_size}) is not larger than data_size ({data_size})"
-    assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})"
-    # pad the data to be divisible by scale
-    if data.shape[0] % scale != 0:
-        padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).to(device)
-        data = torch.cat([data, padding], dim=0).cuda()
+    if pad and data.shape[dim] % scale != 0:
+        padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device)
+        data = torch.cat([data, padding], dim=dim).to(device)
+
-    shape = data.shape
-    ret = reduce(lambda x,y: x|y,[data[i::scale, ...] << container_size-data_size*(i+1) for i in range(scale)])
-    return ret.view(shape[0] // scale, *shape[1:]).to(device)
+    torch._assert(data.shape[dim] >= scale, f"not enough values to pack along dimension {dim}")
+    torch._assert(data.shape[dim] % scale == 0, "size of pack dimension not divisible by scale")
+    return _pack(data, container_size, element_bit_width, scale, dim, order, device)
+
-def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor:
-    scale = container_size // data_size
-    assert scale > 1, f"container_size ({container_size}) not double the capacity ofdata_size ({data_size})"
-    # pad the data to be divisible by scale
-    if data.shape[-1] % scale != 0:
-        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device)
-        data = torch.cat([data, padding], dim=-1).cuda()
+
+def _pack(data, container_size, element_bit_width, scale, dim, order, device) -> torch.Tensor:
+    packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device)
+    slices = [slice(None)] * packed.ndim
+    for i in range(scale):
+        slices[dim] = slice(i, None, scale)
+        if order:
+            # lower-index elements land in the most significant bits
+            packed |= data[slices] << (container_size - element_bit_width * (i + 1))
+        else:
+            packed |= data[slices] << (element_bit_width * i)
+    return packed
 
-    shape = data.shape
-    data = data.contiguous().view(-1)
-    #shift the data to the different indexes within the larger dtype and then union them together
-    ret = reduce(lambda x,y: x|y,[data[i::scale] << container_size-data_size*(i+1) for i in range(scale)])
-    return ret.view(*shape[:-1],shape[-1] // scale).to(device)
\ No newline at end of file
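
For reference, a minimal round-trip sketch of the pack/unpack API introduced in this diff; the shape, the packing dimension, and the CPU device below are illustrative choices rather than anything taken from the patch itself.

# Round-trip sketch of the new torchao.prototype.common.bitpacking API.
# The shape, dim, and CPU device are illustrative only.
import torch
from torchao.prototype.common.bitpacking import pack, unpack

# 4-bit values stored loosely in a uint8 tensor
data = torch.randint(0, 2**4, (4, 4, 4), dtype=torch.uint8, device="cpu")

# two 4-bit elements share one uint8 container, so dim 1 shrinks from 4 to 2
packed = pack(data, 4, dim=1, container_dtype=torch.uint8, device="cpu")
assert packed.shape == (4, 2, 4)

# unpacking restores the original values; `order` must match the pack call (default True)
restored = unpack(packed, 4, dim=1, device="cpu")
assert torch.equal(restored, data)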