From 0d4166879c4a17c2de5e587aa45b28ff4ebc313b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 02:34:01 +0000 Subject: [PATCH 01/18] add test / benchmark --- benchmarks/benchmark_hqq.py | 134 ++++++++++++++++++++++++++++++++++++ test/hqq/test_triton_mm.py | 101 +++++++++++++++++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 benchmarks/benchmark_hqq.py create mode 100644 test/hqq/test_triton_mm.py diff --git a/benchmarks/benchmark_hqq.py b/benchmarks/benchmark_hqq.py new file mode 100644 index 0000000000..a51401a3ab --- /dev/null +++ b/benchmarks/benchmark_hqq.py @@ -0,0 +1,134 @@ +import torch +from termcolor import colored + +import pandas as pd +from hqq.core.quantize import HQQLinear, BaseQuantizeConfig +from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4 +from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 + +from triton.testing import do_bench + + +BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + "bitpack": False, + "axis": 1, +} + + +def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False): + packed_w = pack_2xint4(W_q.T) + + def fn(): + _ = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + group_size=group_size, + fp8_fast_accum=fp8_fast_accum, + kernel_type=kernel_type, + ) + + t = do_bench(fn) + return t + + +def bench_hqq(x, hqq_linear: HQQLinear): + def fn(): + _ = hqq_linear.forward(x) + + t = do_bench(fn) + return t + + +def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + x = torch.randn(M, K, dtype=dtype, device="cuda") + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False + ) + quant_config.update({"weight_quant_params": qcfg}) + + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) + + # Reference + ref_time = bench_hqq(x, hqq_linear) + + # Custom kernel + W_q, meta = hqq_linear.W_q, hqq_linear.meta + scales, zeros = meta["scale"], meta["zero"] + + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q + ) + W_q = W_q.to(dtype=quant_dtype) + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size) + + if dtype == torch.bfloat16: + _ = quant_config["weight_quant_params"].pop("bitpack") + hqq_int4mm = HQQLinearTorchWeightOnlyInt4( + linear, quant_config, compute_dtype=dtype, del_orig=False + ) + int4_time = bench_hqq(x, hqq_int4mm) + + print(colored(f"{shape=} {group_size=} {dtype=}:", attrs=["bold"])) + + print( + colored(f"Ref: {ref_time:.4f}", "blue"), + colored(f"Triton: {tt_time:.4f}", "green"), + colored(f"Torch int4mm: {int4_time:.4f}", "yellow") + if dtype == torch.bfloat16 + else "", + ) + print() + return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None + + +SHAPES = [ + [16, 4096, 4096], + [32, 4096, 4096], + [128, 4096, 4096], + [256, 4096, 4096], + [512, 4096, 4096], + [1024, 4096, 4096], +] + +DTYPES = [torch.bfloat16] # , torch.float16] +GROUP_SIZES = [128] + +print(torch.cuda.get_device_properties(0)) + +HEADERS = [ + "M", + "N", + "K", + "group_size", + "dtype", + "ref", + "triton", + "tinygemm", +] +data = [] +for shape in SHAPES: + for group_size in 
GROUP_SIZES: + for dtype in DTYPES: + timings = run_benchmark(shape, group_size, dtype) + data.append((*shape, group_size, dtype, *timings)) + + +df = pd.DataFrame(data, columns=HEADERS) +df.to_csv("benchmark_triton.csv", index=False) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py new file mode 100644 index 0000000000..a11a398896 --- /dev/null +++ b/test/hqq/test_triton_mm.py @@ -0,0 +1,101 @@ +import itertools + +import torch +from termcolor import colored + +from hqq.core.quantize import HQQLinear, BaseQuantizeConfig +from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4 +from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 +from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4 + + +#TODO: refactor to pytest + +#Test configs +SHAPES = [ + # [16, 128], + [16, 128, 128], + [16, 4096, 4096], + # [1024, 4096], + # [4096, 4096], + # [4096, 11008], +] + +DTYPES = [torch.bfloat16, torch.float16] +GROUP_SIZES = [64, 128] +AXES = [1] #Only axis = 1 supported +TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"] +TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRITON_KERNEL_TYPE)) + +BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + # "quant_dtype": torch.uint8, + "bitpack": False, + "axis": 1, +} + + +def check(expected, actual, cfg_str, max_diff=1e-3): + passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) + max_err = (expected - actual).abs().max() + if not passed: + print(colored(f"{cfg_str}: Failed! Max error: {max_err}", "red", attrs=["bold"])) + else: + print(colored(f"{cfg_str}: Passed! Max error: {max_err}", "green", attrs=["bold"])) + +def test_mixed_mm(shape, group_size, axis, dtype, kernel_type, quant_dtype=torch.uint8): + # print(f"Test: {shape}, {group_size}, {axis}, {dtype}") + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + x = torch.randn(M, K, dtype=dtype, device="cuda") + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False + ) + quant_config.update({"weight_quant_params": qcfg}) + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) + W_q, meta = hqq_linear.W_q, hqq_linear.meta + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q + ) + scales, zeros = meta["scale"], meta["zero"] + + #Reference + hqq_out = hqq_linear.forward(x) + + ##Triton + W_q = W_q.to(dtype=quant_dtype) + packed_w = pack_2xint4(W_q.T) + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + tt_out = triton_mixed_mm( + x, packed_w, scales.T, zeros.T, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + ) + + cfg_str = f"Test config {shape} {group_size} {dtype}" + # err = (hqq_out - tt_out).abs().max() + check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) + + if dtype == torch.bfloat16: + _ = quant_config["weight_quant_params"].pop("bitpack") + hqq_int4mm = HQQLinearTorchWeightOnlyInt4( + linear, quant_config, compute_dtype=dtype, del_orig=False + ) + hqq_int4_out = hqq_int4mm.forward(x) + err = (hqq_int4_out - hqq_out).abs().max() + check(hqq_out, hqq_int4_out, cfg_str + " torch_tinygemm", max_diff=1e-2) + + print() + + +for test in TEST_CONFIGS: + test_mixed_mm(*test) From 
497d8db40ee67e3df988f3603898f57d33b280d2 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 02:37:06 +0000 Subject: [PATCH 02/18] add kernels --- torchao/prototype/hqq/README.md | 41 +++ torchao/prototype/hqq/__init__.py | 1 + torchao/prototype/hqq/hqq_tinygemm_linear.py | 256 +++++++++++++++ torchao/prototype/hqq/kernels.py | 313 +++++++++++++++++++ torchao/prototype/hqq/mixed_mm.py | 97 ++++++ 5 files changed, 708 insertions(+) create mode 100644 torchao/prototype/hqq/README.md create mode 100644 torchao/prototype/hqq/__init__.py create mode 100644 torchao/prototype/hqq/hqq_tinygemm_linear.py create mode 100644 torchao/prototype/hqq/kernels.py create mode 100644 torchao/prototype/hqq/mixed_mm.py diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md new file mode 100644 index 0000000000..d3b8608052 --- /dev/null +++ b/torchao/prototype/hqq/README.md @@ -0,0 +1,41 @@ +## Fused `int4 / fp16` Quant Matmul + +Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. + +The kernel packs `u8 / s8` weights and fuses dequantization with the matmul. + +- tested for `float16 / bfloat16` activations, scales, and zeros +- autotuned for both compute-bound and io-bound configs + +### Performance + +Initial benchmarking demonstrates promising results, scaling well across io-bound and compute-bound workloads: + +| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | +| --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | +| 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 | +| 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 | +| 2 | 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 | +| 3 | 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 | +| 4 | 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 | +| 5 | 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 | + +- Times are in `ms`, see `benchmarks/benchmark_hqq.py`. +- `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. + +GPU details: + +``` +_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84) +``` + +### NOTE + +> This implementation requires `triton >= 3.0.0`. 
+ +- Running tests / benchmarks requires installation of `hqq`: + + ``` + pip install hqq + ``` diff --git a/torchao/prototype/hqq/__init__.py b/torchao/prototype/hqq/__init__.py new file mode 100644 index 0000000000..c97591c475 --- /dev/null +++ b/torchao/prototype/hqq/__init__.py @@ -0,0 +1 @@ +from .mixed_mm import triton_mixed_mm, pack_2xint4 \ No newline at end of file diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py new file mode 100644 index 0000000000..0c4ae45c61 --- /dev/null +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -0,0 +1,256 @@ + +#mobicham's tinygemm hqq eval script +import torch + +device = "cuda" + + +import torch, copy +from torch import nn, Tensor + +from hqq.core.quantize import * +from hqq.core.utils import * + +import torch.nn.functional as F + + +class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): + def __init__( + self, + linear_layer: nn.Module | None, + quant_config: dict, + del_orig: bool = True, + compute_dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", + initialize: bool = True, + inner_k_tiles=8, + padding=True, + ): + super().__init__() + + self.ready = False + self.in_gpu = False + self.bias = None + self.device = device + self.compute_dtype = compute_dtype + self.quant_config = copy.deepcopy(quant_config) + self.del_orig = del_orig + + weight_quant_params = self.quant_config["weight_quant_params"] + self.groupsize = weight_quant_params["group_size"] + self.nbits = weight_quant_params["nbits"] + self.inner_k_tiles = inner_k_tiles + self.padding = padding + + assert self.nbits in [1, 2, 4], "Unsupported nbits" + assert self.groupsize in [None, 32, 64, 128, 256], "Unsupported groupsize" + assert self.inner_k_tiles in [2, 4, 8], "Unsupported tile" + + self.linear_layer = linear_layer + self.compute_dtype = compute_dtype + + if initialize: + self.initialize() + + ###################### Initializers ###################### + def initialize_with_hqq_quants(self, W_q, meta, bias=None): + self.padding = ( + False # Force padding off, a bit tricky to post-pad with grouping + ) + + self.set_shape(meta["shape"]) + self.process_hqq_quants(W_q, meta) + self.bias = bias + self.ready = True + self.in_gpu = True + torch.cuda.empty_cache() + + return self + + def initialize(self): + if self.linear_layer is not None: + W = self.linear_layer.weight.data + self.set_shape(W.shape) + + if self.in_features_diff > 0: + W = F.pad(W, pad=(0, self.in_features_diff), value=0) + + W_q, meta = self.quantize(W, **self.quant_config) + self.process_hqq_quants(W_q, meta) + del W_q, meta + + self.bias = ( + None + if (self.linear_layer.bias is None) + else self.linear_layer.bias.to( + dtype=self.compute_dtype, device=self.device + ) + ) + + if self.del_orig: + del self.linear_layer + + self.ready = True + self.in_gpu = True + torch.cuda.empty_cache() + + return self + + ###################### Quantize/packing ###################### + + def quantize( + self, + W: Tensor, + weight_quant_params: dict, + scale_quant_params=dict | None, + zero_quant_params=dict | None, + offload_meta=False, + ): + W_q, meta = Quantizer.quantize( + W, + **weight_quant_params, + device=self.device, + compute_dtype=self.compute_dtype, + bitpack=False, + ) + + # ToDO: meta quantization + + return W_q, meta + + # TODO: move these to utils + @torch.no_grad() + def reshape_meta_axis1(self, meta_tensor, new_group_size, shape): + meta_tensor = meta_tensor.repeat([1, shape[1]]).reshape(shape) + meta_tensor = torch.mean( + meta_tensor.reshape([-1, 
new_group_size]), axis=1, keepdim=True + ) + return meta_tensor + + def find_multiple(self, n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + def set_shape(self, shape): + self.shape = shape + self.in_features = shape[1] + self.out_features = shape[0] + + self.origin_in_features = self.in_features + if self.padding: + self.in_features = self.find_multiple(self.in_features, 1024) + + self.in_features_diff = self.in_features - self.origin_in_features + + @torch.no_grad() + def process_hqq_quants(self, W_q, meta): + scales = meta["scale"] + zeros = meta["zero"] + shape = meta["shape"] + + if meta["packing"] is not None: + W_q = Quantizer.unpack[meta["packing"]](W_q) + + if self.groupsize is None: + self.groupsize = 128 + W_q = W_q.reshape([-1, self.groupsize]) + scales = self.reshape_meta_axis1(scales, self.groupsize, shape) + zeros = self.reshape_meta_axis1(zeros, self.groupsize, shape) + + W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( + W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits + ) + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) + self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) + + del W_q_torch, scales_torch, zeros_torch + torch.cuda.empty_cache() + + @torch.no_grad() + def hqq_quants_to_torch_quants( + self, W_q: Tensor, scales: Tensor, zeros: Tensor, shape, nbits=4 + ): + W_q = W_q.to(dtype=self.compute_dtype, device=self.device) + scales = scales.to(dtype=self.compute_dtype, device=self.device) + zeros = zeros.to(dtype=self.compute_dtype, device=self.device) + + max_int = 2**nbits - 1 + min_int = 0 + dump = 2 ** (nbits - 1) + + # HQQ -> torch logic + new_zeros = (scales * dump) - zeros * scales + + min_val = new_zeros - scales * dump + + # group_quantize_tensor_from_qparams + W_r = (W_q - zeros) * scales + + W_q = ( + W_r.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape(shape) + .contiguous() + ) + + # group_dequantize_tensor_from_qparams + # W_r = W_q*scales + min_val + + scales = scales.contiguous().reshape(shape[0], -1) + new_zeros = new_zeros.contiguous().reshape(shape[0], -1) + + return W_q, scales, new_zeros + + def pack_scales_and_zeros(self, scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + ###################### Forward/matmul ###################### + + @torch.jit.ignore() + def matmul(self, x): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (self.out_features,) + c = c.reshape(new_shape) + return c + + # TODO without matmul + def dequantize(self): + return self.matmul( + torch.eye(self.in_features, dtype=self.compute_dtype, device=self.device) + )[: self.origin_in_features].t() + + # TODO: backward + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(self.compute_dtype) + + if self.in_features_diff > 0: + x = F.pad(x, pad=(0, self.in_features_diff)) + + out = self.matmul(x) + + if self.bias is not None: + out += self.bias + return out diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py new file 
mode 100644 index 0000000000..2065e4af47 --- /dev/null +++ b/torchao/prototype/hqq/kernels.py @@ -0,0 +1,313 @@ +from triton import Config +import triton.language as tl +import triton + +#TODO: add early config prune and estimate_matmul_time to reduce autotuning time +# from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def get_configs_compute_bound(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +MIXED_MM_HEURISTICS = { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + "BLOCK_K": lambda args: min(args["BLOCK_K"], args["QGROUP_SIZE"]), + "SPLIT_K": lambda args: 1 + if args["IS_BFLOAT16"] + else args["SPLIT_K"], # atomic add not supported for bfloat16 +} + + + +@triton.jit +def _mixed_mm_kernel( + 
# Operands + A, + B, + scales_ptr, + zeros_ptr, + C, + # Matrix dims. + M, + N, + K, + # a, b, c, scales / zeros strides + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + stride_scale_k, + stride_scale_n, + # Meta-params + IS_BFLOAT16: tl.constexpr, + QGROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, # = 32, + BLOCK_N: tl.constexpr, # = 32, + BLOCK_K: tl.constexpr, # = 16, # + SPLIT_K: tl.constexpr, # = 1, + EVEN_K: tl.constexpr, # = True, + GROUP_M: tl.constexpr = 8, # 32, + # tl.dot options + acc_dtype: tl.constexpr = tl.float32, + input_precision: tl.constexpr = "ieee", + fp8_fast_accum: tl.constexpr = False, +): + """Mixed matmul kernel + + A has shape (M, K) and is float16, bfloat16, or float32 + + B is i4 / s4 and has shape (K // 2, N) and is packed as uint8 / int8. See `packed_2xint4` for details. + + Scales and zeros are of shape (NUM_GROUPS, N) and are same dtype as A, where NUM_GROUPS = (K // QGROUP_SIZE) + QGROUP_SIZE should be a multiple of BLOCK_K such that a vector of scales / zeros is loaded and broadcasted to block shape + per mainloop iteration. + + NOTE: Assumes that the quantization grouping was done along the K dimension originally (i.e., QGROUP_SIZE consecutive elements + of original weight matrix in the K dimension were grouped together when calculating min / max scaling factors). + """ + + # tl.static_assert(B.dtype.element_ty == tl.int8 or B.dtype.element_ty == tl.uint8) + tl.static_assert(QGROUP_SIZE % BLOCK_K == 0) + + # Threadblock swizzling + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + rak = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + + # BLOCK_K for b is effectively BLOCK_K // 2 + rbk = pid_z * BLOCK_K // 2 + tl.arange(0, BLOCK_K // 2) + + A = A + (ram[:, None] * stride_am + rak[None, :] * stride_ak) + B = B + (rbk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + scale_offset_n = pid_n * stride_scale_n * BLOCK_N + offsets_scale_n = scale_offset_n + tl.arange(0, BLOCK_N) * stride_scale_n + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + qb = tl.load(B) + else: + k_remaining_a = K - k * (BLOCK_K * SPLIT_K) + k_remaining_b = K - k * (BLOCK_K * SPLIT_K) // 2 # Note the division by 2 + + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rak[None, :] < k_remaining_a, other=_0) + qb = tl.load(B, mask=rbk[:, None] < k_remaining_b, other=_0) + + scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k // QGROUP_SIZE + scales = tl.load(scales_ptr + offsets_scale_n + scale_offset_k) + zeros = tl.load(zeros_ptr + offsets_scale_n + scale_offset_k) + + # Unpack qweights -- credit jlebar! 
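+    # The comment lines below describe the unpacking trick (added annotation, not in the original patch):
+    # each packed byte holds two 4-bit codes (see `pack_2xint4`, consecutive rows interleaved into
+    # low / high nibbles). Shifting left by 4 and then right by 4 recovers the low nibble
+    # (sign-extending for s4 when the tensor is int8, zero-extending for u4 when it is uint8),
+    # while a plain right shift by 4 recovers the high nibble.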
+ _4_i8 = tl.full((1,), 4, dtype=tl.int8) + qb_lo = (qb << _4_i8) >> _4_i8 + qb_hi = qb >> _4_i8 + + # Upcast to fp16 + # TODO add bfloat16 + if IS_BFLOAT16: + dq_b = ( + tl.join( + qb_lo.to(tl.float16).to(A.dtype.element_ty), + qb_hi.to(tl.float16).to(A.dtype.element_ty), + ) + .permute(0, 2, 1) + .reshape(BLOCK_K, BLOCK_N) + ) + else: + dq_b = ( + tl.join( + qb_lo.to(A.dtype.element_ty), + qb_hi.to(A.dtype.element_ty), + ) + .permute(0, 2, 1) + .reshape(BLOCK_K, BLOCK_N) + ) + + # Scale upcasted weights + # Note that we broadcast the scales --> the assumption is that all scales fall within a single QGROUP + # This condition is statically check (see assertions above) + dq_b = (dq_b - zeros[None, :]) * scales[None, :] + + if fp8_fast_accum: + acc = tl.dot( + a, dq_b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc += tl.dot(a, dq_b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + # Advance by half the block size, since each block is unpacked and upcasted into two fp16 values + B += BLOCK_K * SPLIT_K * stride_bk // 2 + + acc = acc.to(C.dtype.element_ty) + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +_mixed_mm = triton.heuristics(MIXED_MM_HEURISTICS)(_mixed_mm_kernel) +mixed_mm_kernel_max_autotune = triton.autotune(configs=get_configs_compute_bound() + get_configs_io_bound(), key=["M", "N", "K"])(_mixed_mm) +mixed_mm_kernel_compute_bound = triton.autotune(configs=get_configs_compute_bound(), key=["M", "N", "K"])(_mixed_mm) diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py new file mode 100644 index 0000000000..d56fa582b6 --- /dev/null +++ b/torchao/prototype/hqq/mixed_mm.py @@ -0,0 +1,97 @@ +import torch +from triton import cdiv +import triton.language as tl +from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune +#credit jlebar +def pack_2xint4(t): + """ + The packing format is such that consecutive rows are packed into a lower / upper bits + E.g., + Original, unpacked B (dtype i8): + [ + [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15] + ] + Packed B: + [ + [0|4, 1|5, 2|6, 3|7] + [8|12, 9|13, 10|14, 11|15] + ] + (Note each entry in `Packed B` is shown lsb->msb) + """ + assert t.dtype == torch.int8 or t.dtype == torch.uint8 + t = t.reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2) + return (t[0] & 0xF) | (t[1] << 4) + +def triton_mixed_mm( + a, + b, + scales, + zeros, + group_size, + acc_dtype=None, + input_precision="ieee", + fp8_fast_accum=False, + kernel_type="compute_bound", +): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0] * 2, "incompatible dimensions" + assert b.dtype == torch.int8 or b.dtype == torch.uint8, "b must be int8 or uint8" + assert scales.ndim == 2 + assert kernel_type in ["max_autotune", "compute_bound"] + + M, K = a.shape + _, N = b.shape + assert scales.shape[1] == N + assert scales.shape[0] == K // group_size + assert scales.dtype == a.dtype + assert scales.shape == zeros.shape + assert zeros.dtype == a.dtype + + # Assumes c is same type 
as a + c = torch.empty((M, N), device=device, dtype=a.dtype) + if acc_dtype is None: + acc_dtype = tl.float32 + + grid = lambda META: ( + cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + if kernel_type == "max_autotune": + kernel = mixed_mm_kernel_max_autotune + else: + kernel = mixed_mm_kernel_compute_bound + + kernel[grid]( + a, + b, + scales, + zeros, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), + scales.stride(0), + scales.stride(1), + IS_BFLOAT16=a.dtype == torch.bfloat16, + QGROUP_SIZE=group_size, + acc_dtype=acc_dtype, + input_precision=input_precision, + fp8_fast_accum=fp8_fast_accum, + ) + return c From 8db9f51fb82d5cca9d9c0e80dc7a93503e61fa2b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 02:48:41 +0000 Subject: [PATCH 03/18] update readme --- torchao/prototype/hqq/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index d3b8608052..26fe857086 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -6,6 +6,8 @@ The kernel packs `u8 / s8` weights and fuses dequantization with the matmul. - tested for `float16 / bfloat16` activations, scales, and zeros - autotuned for both compute-bound and io-bound configs +- assumes operand B of the `gemm` is is the quantized type. +- requires quantization along in-features, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. ### Performance From 2a18357813a661f9db05459e15af7a4ea0488997 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 02:52:16 +0000 Subject: [PATCH 04/18] more readme edits --- torchao/prototype/hqq/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 26fe857086..1fe133a5c4 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -7,7 +7,7 @@ The kernel packs `u8 / s8` weights and fuses dequantization with the matmul. - tested for `float16 / bfloat16` activations, scales, and zeros - autotuned for both compute-bound and io-bound configs - assumes operand B of the `gemm` is is the quantized type. -- requires quantization along in-features, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. +- requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. ### Performance @@ -34,7 +34,7 @@ _CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48 ### NOTE -> This implementation requires `triton >= 3.0.0`. +This implementation requires **`triton >= 3.0.0`**. - Running tests / benchmarks requires installation of `hqq`: From f11a59fa678eb86daa5b119d7cadcdd4733b5971 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 02:58:41 +0000 Subject: [PATCH 05/18] edit readme --- torchao/prototype/hqq/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 1fe133a5c4..d45a3bd888 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -2,7 +2,7 @@ Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. -The kernel packs `u8 / s8` weights and fuses dequantization with the matmul. +The kernel packs `u4 / s4` weights and fuses dequantization with the matmul. 
- tested for `float16 / bfloat16` activations, scales, and zeros - autotuned for both compute-bound and io-bound configs From 2e768392b3b3e421adbd6ba1a9a0590ea10dbbb0 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 19:06:10 +0000 Subject: [PATCH 06/18] add transpose test --- test/hqq/test_triton_mm.py | 60 +++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 24 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index a11a398896..dca55477ae 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -24,8 +24,10 @@ DTYPES = [torch.bfloat16, torch.float16] GROUP_SIZES = [64, 128] AXES = [1] #Only axis = 1 supported +TRANSPOSED = [False] TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"] -TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRITON_KERNEL_TYPE)) + +TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE)) BASE_QUANT_CONFIG = { "optimize": True, @@ -45,15 +47,14 @@ def check(expected, actual, cfg_str, max_diff=1e-3): else: print(colored(f"{cfg_str}: Passed! Max error: {max_err}", "green", attrs=["bold"])) -def test_mixed_mm(shape, group_size, axis, dtype, kernel_type, quant_dtype=torch.uint8): +def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8): # print(f"Test: {shape}, {group_size}, {axis}, {dtype}") qcfg = { **BASE_QUANT_CONFIG, **dict(group_size=group_size, axis=axis), } M, N, K = shape - - x = torch.randn(M, K, dtype=dtype, device="cuda") + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") quant_config = BaseQuantizeConfig( @@ -62,37 +63,48 @@ def test_mixed_mm(shape, group_size, axis, dtype, kernel_type, quant_dtype=torch quant_config.update({"weight_quant_params": qcfg}) hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) W_q, meta = hqq_linear.W_q, hqq_linear.meta + W_q = W_q.to(dtype=quant_dtype) W_q = ( W_q.reshape(meta["shape"]) if quant_config["weight_quant_params"]["bitpack"] == False else W_q ) + W_dq = hqq_linear.dequantize() + scales, zeros = meta["scale"], meta["zero"] - - #Reference - hqq_out = hqq_linear.forward(x) - - ##Triton - W_q = W_q.to(dtype=quant_dtype) - packed_w = pack_2xint4(W_q.T) scales = scales.reshape(N, -1) zeros = zeros.reshape(N, -1) - tt_out = triton_mixed_mm( - x, packed_w, scales.T, zeros.T, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type - ) - cfg_str = f"Test config {shape} {group_size} {dtype}" - # err = (hqq_out - tt_out).abs().max() - check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) + + if transposed: + x = torch.randn(M, N, dtype=dtype, device="cuda") + hqq_out = x @ W_dq + + #Pack uint8 W_q, then run fused dequant matmul + packed_w = pack_2xint4(W_q) + tt_out = triton_mixed_mm( + x, packed_w, scales, zeros, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + ) + else: + x = torch.randn(M, K, dtype=dtype, device="cuda") + hqq_out = x @ W_dq.T - if dtype == torch.bfloat16: - _ = quant_config["weight_quant_params"].pop("bitpack") - hqq_int4mm = HQQLinearTorchWeightOnlyInt4( - linear, quant_config, compute_dtype=dtype, del_orig=False + packed_w = pack_2xint4(W_q.T) + tt_out = triton_mixed_mm( + x, packed_w, scales.T, zeros.T, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type ) - hqq_int4_out = hqq_int4mm.forward(x) - err = (hqq_int4_out - hqq_out).abs().max() - 
check(hqq_out, hqq_int4_out, cfg_str + " torch_tinygemm", max_diff=1e-2) + + cfg_str = f"Test config {shape} {group_size} {dtype} {transposed} {kernel_type}" + check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) + + # if dtype == torch.bfloat16: + # _ = quant_config["weight_quant_params"].pop("bitpack") + # hqq_int4mm = HQQLinearTorchWeightOnlyInt4( + # linear, quant_config, compute_dtype=dtype, del_orig=False + # ) + # hqq_int4_out = hqq_int4mm.forward(x) + # err = (hqq_int4_out - hqq_out).abs().max() + # check(hqq_out, hqq_int4_out, cfg_str + " torch_tinygemm", max_diff=1e-2) print() From 19c43c2859e2201411b1bfa05b460b395995610d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 21:17:24 +0000 Subject: [PATCH 07/18] transpose test pass --- test/hqq/test_triton_mm.py | 18 ++++++------ torchao/prototype/hqq/kernels.py | 47 +++++++++++++++++++++++++------ torchao/prototype/hqq/mixed_mm.py | 7 +++-- 3 files changed, 52 insertions(+), 20 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index dca55477ae..de5286d43c 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -6,25 +6,18 @@ from hqq.core.quantize import HQQLinear, BaseQuantizeConfig from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4 from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 -from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4 -#TODO: refactor to pytest - #Test configs SHAPES = [ - # [16, 128], [16, 128, 128], [16, 4096, 4096], - # [1024, 4096], - # [4096, 4096], - # [4096, 11008], ] DTYPES = [torch.bfloat16, torch.float16] GROUP_SIZES = [64, 128] AXES = [1] #Only axis = 1 supported -TRANSPOSED = [False] +TRANSPOSED = [True] TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"] TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE)) @@ -83,7 +76,7 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant #Pack uint8 W_q, then run fused dequant matmul packed_w = pack_2xint4(W_q) tt_out = triton_mixed_mm( - x, packed_w, scales, zeros, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type ) else: x = torch.randn(M, K, dtype=dtype, device="cuda") @@ -91,10 +84,15 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant packed_w = pack_2xint4(W_q.T) tt_out = triton_mixed_mm( - x, packed_w, scales.T, zeros.T, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type ) cfg_str = f"Test config {shape} {group_size} {dtype} {transposed} {kernel_type}" + # print(cfg_str) + # print("packed_w", packed_w.shape) + # print("hqq_out", hqq_out.shape) + # print("tt_out", tt_out.shape) + check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) # if dtype == torch.bfloat16: diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py index 2065e4af47..fe42e38240 100644 --- a/torchao/prototype/hqq/kernels.py +++ b/torchao/prototype/hqq/kernels.py @@ -148,7 +148,8 @@ def init_to_zero(name): MIXED_MM_HEURISTICS = { "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, - "BLOCK_K": lambda args: min(args["BLOCK_K"], 
args["QGROUP_SIZE"]), + "BLOCK_K": lambda args: min(args["BLOCK_K"], args["QGROUP_SIZE"]) if not args["TRANSPOSED"] else args["BLOCK_K"], + "BLOCK_N": lambda args: min(args["BLOCK_N"], args["QGROUP_SIZE"]) if args["TRANSPOSED"] else args["BLOCK_N"], "SPLIT_K": lambda args: 1 if args["IS_BFLOAT16"] else args["SPLIT_K"], # atomic add not supported for bfloat16 @@ -185,6 +186,7 @@ def _mixed_mm_kernel( BLOCK_K: tl.constexpr, # = 16, # SPLIT_K: tl.constexpr, # = 1, EVEN_K: tl.constexpr, # = True, + TRANSPOSED: tl.constexpr = False, GROUP_M: tl.constexpr = 8, # 32, # tl.dot options acc_dtype: tl.constexpr = tl.float32, @@ -206,8 +208,11 @@ def _mixed_mm_kernel( """ # tl.static_assert(B.dtype.element_ty == tl.int8 or B.dtype.element_ty == tl.uint8) - tl.static_assert(QGROUP_SIZE % BLOCK_K == 0) - + if not TRANSPOSED: + tl.static_assert(QGROUP_SIZE % BLOCK_K == 0) + else: + tl.static_assert(QGROUP_SIZE % BLOCK_N == 0) + # Threadblock swizzling pid = tl.program_id(0) pid_z = tl.program_id(1) @@ -234,9 +239,22 @@ def _mixed_mm_kernel( A = A + (ram[:, None] * stride_am + rak[None, :] * stride_ak) B = B + (rbk[:, None] * stride_bk + rbn[None, :] * stride_bn) - scale_offset_n = pid_n * stride_scale_n * BLOCK_N - offsets_scale_n = scale_offset_n + tl.arange(0, BLOCK_N) * stride_scale_n - + #In the forward pass, we have a K x N matrix + #In the transposed (backward) pass, we have an N x K matrix + #Grouping is along K, so in the forward pass, each block loads a row vector of BLK_K x BLK_N + #where grouping varies along N, hence the mainloop marches down the K dimension, where + #group idx is given by K // QGROUP_SIZE + # FOr the transposed case, we load a column vector of BLK_N x BLK_K + # we march down the N dimension during the mainloop + # Hence blocks now load K // QGROUP_SIZE (slow varying) + # while each block now loads differen groups on each main loop iteration + # scale offsets is thus a single idx along N and range along K + if not TRANSPOSED: + # scale_offset_n = pid_n * stride_scale_n * BLOCK_N + offsets_scale_n = pid_n * stride_scale_n * BLOCK_N + tl.arange(0, BLOCK_N) * stride_scale_n + else: + offsets_scale_n = pid_n * stride_scale_n * BLOCK_N // QGROUP_SIZE + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): if EVEN_K: @@ -250,7 +268,11 @@ def _mixed_mm_kernel( a = tl.load(A, mask=rak[None, :] < k_remaining_a, other=_0) qb = tl.load(B, mask=rbk[:, None] < k_remaining_b, other=_0) - scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k // QGROUP_SIZE + if not TRANSPOSED: + scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k // QGROUP_SIZE + else: + scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k + tl.arange(0, BLOCK_K) * stride_scale_k + scales = tl.load(scales_ptr + offsets_scale_n + scale_offset_k) zeros = tl.load(zeros_ptr + offsets_scale_n + scale_offset_k) @@ -283,7 +305,16 @@ def _mixed_mm_kernel( # Scale upcasted weights # Note that we broadcast the scales --> the assumption is that all scales fall within a single QGROUP # This condition is statically check (see assertions above) - dq_b = (dq_b - zeros[None, :]) * scales[None, :] + if not TRANSPOSED: + zeros = zeros[None, :] + scales = scales[None, :] + else: + zeros = zeros[:, None] + scales = scales[:, None] + + dq_b = (dq_b - zeros) * scales + + # dq_b = (dq_b - zeros[None, :]) * scales[None, :] if fp8_fast_accum: acc = tl.dot( diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py index d56fa582b6..c249b7477d 100644 --- 
a/torchao/prototype/hqq/mixed_mm.py +++ b/torchao/prototype/hqq/mixed_mm.py @@ -31,6 +31,7 @@ def triton_mixed_mm( scales, zeros, group_size, + transposed=False, acc_dtype=None, input_precision="ieee", fp8_fast_accum=False, @@ -50,8 +51,9 @@ def triton_mixed_mm( M, K = a.shape _, N = b.shape - assert scales.shape[1] == N - assert scales.shape[0] == K // group_size + # N = b.shape[1] if not transposed else b.shape[0] + # assert scales.shape[1] == N if not transposed else scales.shape[0] == N + # assert scales.shape[0] == K // group_size if not transposed else scales.shape[1] == K // group_size assert scales.dtype == a.dtype assert scales.shape == zeros.shape assert zeros.dtype == a.dtype @@ -88,6 +90,7 @@ def triton_mixed_mm( c.stride(1), scales.stride(0), scales.stride(1), + TRANSPOSED=transposed, IS_BFLOAT16=a.dtype == torch.bfloat16, QGROUP_SIZE=group_size, acc_dtype=acc_dtype, From e0f378162f9c37baae1720548a7dc0648203cbd1 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 21:35:53 +0000 Subject: [PATCH 08/18] refactor test --- test/hqq/test_triton_mm.py | 49 +++++++++++++------------------------- 1 file changed, 17 insertions(+), 32 deletions(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index de5286d43c..181d6c61da 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,7 +1,6 @@ import itertools - +import pytest import torch -from termcolor import colored from hqq.core.quantize import HQQLinear, BaseQuantizeConfig from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4 @@ -26,22 +25,29 @@ "optimize": True, "view_as_float": False, "nbits": 4, - # "quant_dtype": torch.uint8, "bitpack": False, "axis": 1, } -def check(expected, actual, cfg_str, max_diff=1e-3): +def check(expected, actual, msg="", max_diff=1e-3, verbose=False): passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) - max_err = (expected - actual).abs().max() - if not passed: - print(colored(f"{cfg_str}: Failed! Max error: {max_err}", "red", attrs=["bold"])) - else: - print(colored(f"{cfg_str}: Passed! Max error: {max_err}", "green", attrs=["bold"])) + if verbose: + max_err = (expected - actual).abs().max() + if not passed: + print(f"{msg}: Failed! Max error: {max_err}") + else: + print(f"{msg}: Passed! 
Max error: {max_err}") + + return passed +def _arg_to_id(arg): + if isinstance(arg, list): + return "x".join([str(x) for x in arg]) + return str(arg) + +@pytest.mark.parametrize("shape, group_size, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id) def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8): - # print(f"Test: {shape}, {group_size}, {axis}, {dtype}") qcfg = { **BASE_QUANT_CONFIG, **dict(group_size=group_size, axis=axis), @@ -67,7 +73,6 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant scales, zeros = meta["scale"], meta["zero"] scales = scales.reshape(N, -1) zeros = zeros.reshape(N, -1) - if transposed: x = torch.randn(M, N, dtype=dtype, device="cuda") @@ -87,25 +92,5 @@ def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type ) - cfg_str = f"Test config {shape} {group_size} {dtype} {transposed} {kernel_type}" - # print(cfg_str) - # print("packed_w", packed_w.shape) - # print("hqq_out", hqq_out.shape) - # print("tt_out", tt_out.shape) - - check(hqq_out, tt_out, cfg_str + " triton", max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) - - # if dtype == torch.bfloat16: - # _ = quant_config["weight_quant_params"].pop("bitpack") - # hqq_int4mm = HQQLinearTorchWeightOnlyInt4( - # linear, quant_config, compute_dtype=dtype, del_orig=False - # ) - # hqq_int4_out = hqq_int4mm.forward(x) - # err = (hqq_int4_out - hqq_out).abs().max() - # check(hqq_out, hqq_int4_out, cfg_str + " torch_tinygemm", max_diff=1e-2) - - print() - + assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) -for test in TEST_CONFIGS: - test_mixed_mm(*test) From be718d68974c8d40efcebf7f5d5b52489251a47d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 21:39:40 +0000 Subject: [PATCH 09/18] add checks for CI --- test/hqq/test_triton_mm.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 181d6c61da..a99923624d 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,3 +1,11 @@ +# Skip entire test if triton is not available, otherwise CI failure +try: + import triton + if int(triton.__version__.split(".")[0]) < 3: + pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True) +except ImportError: + pytest.skip("triton is not installed", allow_module_level=True) + import itertools import pytest import torch From 793994bd3b59c0bb70adb0f6ffce8aa4b634a1ff Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 21:43:32 +0000 Subject: [PATCH 10/18] add more comments for transpose kernel --- torchao/prototype/hqq/kernels.py | 16 ++++++++++------ torchao/prototype/hqq/mixed_mm.py | 3 ++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py index fe42e38240..2b0c1958af 100644 --- a/torchao/prototype/hqq/kernels.py +++ b/torchao/prototype/hqq/kernels.py @@ -240,15 +240,19 @@ def _mixed_mm_kernel( B = B + (rbk[:, None] * stride_bk + rbn[None, :] * stride_bn) #In the forward pass, we have a K x N matrix - #In the transposed (backward) pass, we have an N x K matrix + #In the transposed (backward) pass, we have an N x K matrix, where N and K refer to the how the weight was originally quantized + #note that N refers to offsets_scale_k and K refers to 
offsets_scale_n when it comes to the gemm indexing logic below + #Grouping is along K, so in the forward pass, each block loads a row vector of BLK_K x BLK_N #where grouping varies along N, hence the mainloop marches down the K dimension, where #group idx is given by K // QGROUP_SIZE - # FOr the transposed case, we load a column vector of BLK_N x BLK_K - # we march down the N dimension during the mainloop - # Hence blocks now load K // QGROUP_SIZE (slow varying) - # while each block now loads differen groups on each main loop iteration - # scale offsets is thus a single idx along N and range along K + + # For the transposed case, we load a column vector of BLK_N x BLK_K + # we march down the N dimension during the mainloop ("K" in gemm) + # Hence blocks now load K // QGROUP_SIZE along pid_n (slow varying) + # while each block now loads column vector of groups along "K" gemm dim on each main loop iteration + # scale offsets is thus a single idx along "N" and range along "K" for the transposed case + if not TRANSPOSED: # scale_offset_n = pid_n * stride_scale_n * BLOCK_N offsets_scale_n = pid_n * stride_scale_n * BLOCK_N + tl.arange(0, BLOCK_N) * stride_scale_n diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py index c249b7477d..df6210b1c4 100644 --- a/torchao/prototype/hqq/mixed_mm.py +++ b/torchao/prototype/hqq/mixed_mm.py @@ -2,7 +2,8 @@ from triton import cdiv import triton.language as tl from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune -#credit jlebar + +#credit jlebar from triton slack discord discussion def pack_2xint4(t): """ The packing format is such that consecutive rows are packed into a lower / upper bits From 5c585b9c2a8038bcc4bb0ede89f32105adb41404 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 21:48:33 +0000 Subject: [PATCH 11/18] remove import in test --- test/hqq/test_triton_mm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index a99923624d..27fcf75772 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -11,7 +11,6 @@ import torch from hqq.core.quantize import HQQLinear, BaseQuantizeConfig -from hqq.kernels.custom_quant.triton import triton_mixed_mm, pack_2xint4 from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 From c80b2398995666981e1893f3ef496a79f2fce98d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 22:56:13 +0000 Subject: [PATCH 12/18] clean up benchmark --- benchmarks/benchmark_hqq.py | 39 ++++++++++++++++++++----------- test/hqq/test_triton_mm.py | 3 ++- torchao/prototype/hqq/kernels.py | 6 ++--- torchao/prototype/hqq/mixed_mm.py | 2 +- 4 files changed, 32 insertions(+), 18 deletions(-) diff --git a/benchmarks/benchmark_hqq.py b/benchmarks/benchmark_hqq.py index a51401a3ab..393481e95b 100644 --- a/benchmarks/benchmark_hqq.py +++ b/benchmarks/benchmark_hqq.py @@ -1,5 +1,14 @@ + +try: + import triton + import hqq + if int(triton.__version__.split(".")[0]) < 3: + raise "triton >= 3.0.0 is required to run this test" +except ImportError: + raise "triton and hqq required to run this benchmark" + import torch -from termcolor import colored +from io import StringIO import pandas as pd from hqq.core.quantize import HQQLinear, BaseQuantizeConfig @@ -85,12 +94,12 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): ) int4_time = bench_hqq(x, hqq_int4mm) - print(colored(f"{shape=} {group_size=} {dtype=}:", attrs=["bold"])) + print(f"{shape=} {group_size=} 
{dtype=}:") print( - colored(f"Ref: {ref_time:.4f}", "blue"), - colored(f"Triton: {tt_time:.4f}", "green"), - colored(f"Torch int4mm: {int4_time:.4f}", "yellow") + f"Ref: {ref_time:.4f}", + f"Triton: {tt_time:.4f}", + f"Torch int4mm: {int4_time:.4f}" if dtype == torch.bfloat16 else "", ) @@ -110,7 +119,6 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): DTYPES = [torch.bfloat16] # , torch.float16] GROUP_SIZES = [128] -print(torch.cuda.get_device_properties(0)) HEADERS = [ "M", @@ -123,12 +131,17 @@ def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): "tinygemm", ] data = [] -for shape in SHAPES: - for group_size in GROUP_SIZES: - for dtype in DTYPES: - timings = run_benchmark(shape, group_size, dtype) - data.append((*shape, group_size, dtype, *timings)) +if __name__ == "__main__": + print(torch.cuda.get_device_properties(0)) + + for shape in SHAPES: + for group_size in GROUP_SIZES: + for dtype in DTYPES: + timings = run_benchmark(shape, group_size, dtype) + data.append((*shape, group_size, dtype, *timings)) -df = pd.DataFrame(data, columns=HEADERS) -df.to_csv("benchmark_triton.csv", index=False) + output = StringIO() + df = pd.DataFrame(data, columns=HEADERS) + df.to_csv(output, index=False) + print(output.getvalue()) \ No newline at end of file diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 27fcf75772..7892b6ff17 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,10 +1,11 @@ # Skip entire test if triton is not available, otherwise CI failure try: import triton + import hqq if int(triton.__version__.split(".")[0]) < 3: pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True) except ImportError: - pytest.skip("triton is not installed", allow_module_level=True) + pytest.skip("triton and hqq required to run this test", allow_module_level=True) import itertools import pytest diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py index 2b0c1958af..077fc94108 100644 --- a/torchao/prototype/hqq/kernels.py +++ b/torchao/prototype/hqq/kernels.py @@ -280,14 +280,14 @@ def _mixed_mm_kernel( scales = tl.load(scales_ptr + offsets_scale_n + scale_offset_k) zeros = tl.load(zeros_ptr + offsets_scale_n + scale_offset_k) - # Unpack qweights -- credit jlebar! + # Unpack qweights -- h/t jlebar! _4_i8 = tl.full((1,), 4, dtype=tl.int8) qb_lo = (qb << _4_i8) >> _4_i8 qb_hi = qb >> _4_i8 # Upcast to fp16 - # TODO add bfloat16 - if IS_BFLOAT16: + # TODO: better bfloat16 conversion? 
compilation error if direct conversion from int8 to bfloat16 + if IS_BFLOAT16: dq_b = ( tl.join( qb_lo.to(tl.float16).to(A.dtype.element_ty), diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py index df6210b1c4..099698ffd2 100644 --- a/torchao/prototype/hqq/mixed_mm.py +++ b/torchao/prototype/hqq/mixed_mm.py @@ -3,7 +3,7 @@ import triton.language as tl from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune -#credit jlebar from triton slack discord discussion +#h/t jlebar for the bit packing / unpacking logic (source: Triton Slack thread) def pack_2xint4(t): """ The packing format is such that consecutive rows are packed into a lower / upper bits From 48a153fff8b1e0d197b5f310f3c1be97c35b1fdf Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 23:17:13 +0000 Subject: [PATCH 13/18] fix test import order --- test/hqq/test_triton_mm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py index 7892b6ff17..23f6c60f70 100644 --- a/test/hqq/test_triton_mm.py +++ b/test/hqq/test_triton_mm.py @@ -1,4 +1,5 @@ # Skip entire test if triton is not available, otherwise CI failure +import pytest try: import triton import hqq @@ -8,7 +9,6 @@ pytest.skip("triton and hqq required to run this test", allow_module_level=True) import itertools -import pytest import torch from hqq.core.quantize import HQQLinear, BaseQuantizeConfig From b3a9ab8697e22fcacc8bc173335866216ab3732b Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 22 Apr 2024 23:55:03 +0000 Subject: [PATCH 14/18] minor README edits --- torchao/prototype/hqq/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index d45a3bd888..6e558c2a78 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -2,10 +2,11 @@ Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. -The kernel packs `u4 / s4` weights and fuses dequantization with the matmul. +The kernel takes packed `u4 / s4` weights and fuses dequantization with matmul. +- bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) - tested for `float16 / bfloat16` activations, scales, and zeros -- autotuned for both compute-bound and io-bound configs +- autotuned for both compute-bound and memory-bound configs - assumes operand B of the `gemm` is is the quantized type. - requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. From bbe9083d9f87bffdc62941a43932265854f487f8 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 23 Apr 2024 00:01:55 +0000 Subject: [PATCH 15/18] additional readme edits --- torchao/prototype/hqq/README.md | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 6e558c2a78..dbc9cf8f93 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -1,18 +1,24 @@ ## Fused `int4 / fp16` Quant Matmul -Fused gemm for asymmetric quantized weights. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. +Fused kernel that combines asymmetric dequantization and gemm: -The kernel takes packed `u4 / s4` weights and fuses dequantization with matmul. 
+- Dequantization: upcasts `u4 / s4` weights to `float16 / bfloat16`, followed by groupwise scaling and shifting by scales / zeropoints +- GEMM: standard matmul on dequantized weights and activations. -- bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) -- tested for `float16 / bfloat16` activations, scales, and zeros -- autotuned for both compute-bound and memory-bound configs -- assumes operand B of the `gemm` is is the quantized type. -- requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. +Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. + +### Implementation Details + +- Bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) +- Tested for `float16 / bfloat16` activations, scales, and zeros +- Autotuned for both compute-bound and memory-bound configs +- Assumes operand B of the `gemm` is is the quantized type. +- Requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. +- Implementation handles both transposed and non-tranposed quantized weights, useful for forward / backward training passes. ### Performance -Initial benchmarking demonstrates promising results, scaling well across io-bound and compute-bound workloads: +Initial benchmarking demonstrates promising results, scaling well across memory-bound and compute-bound workloads: | | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | | --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | From 38dbc3e48af1a7baaca8d4eca57696905b208a26 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Wed, 24 Apr 2024 19:25:16 +0000 Subject: [PATCH 16/18] update readme --- torchao/prototype/hqq/README.md | 8 +++++--- torchao/prototype/hqq/mixed_mm.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index dbc9cf8f93..709f97ea22 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -1,9 +1,11 @@ ## Fused `int4 / fp16` Quant Matmul -Fused kernel that combines asymmetric dequantization and gemm: +Fused kernel that combines asymmetric dequantization and gemm. Useful primarily for compute-bound (M > 16) scenarios and not for memory-bound / inference scenarios. + +The kernel fuses two ops: - Dequantization: upcasts `u4 / s4` weights to `float16 / bfloat16`, followed by groupwise scaling and shifting by scales / zeropoints -- GEMM: standard matmul on dequantized weights and activations. +- GEMM: matmul on dequantized weights and activations. Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. 
@@ -18,7 +20,7 @@ Tested and benchmarked for `HQQ` but could theoretically be used for any asymmet ### Performance -Initial benchmarking demonstrates promising results, scaling well across memory-bound and compute-bound workloads: +Initial benchmarking (on `A6000`) demonstrates promising results, scaling well for compute-bound workloads: | | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | | --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py index 099698ffd2..e3ccaeb46e 100644 --- a/torchao/prototype/hqq/mixed_mm.py +++ b/torchao/prototype/hqq/mixed_mm.py @@ -4,6 +4,7 @@ from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune #h/t jlebar for the bit packing / unpacking logic (source: Triton Slack thread) +#https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2 def pack_2xint4(t): """ The packing format is such that consecutive rows are packed into a lower / upper bits From d89fa74b46a158aa272730f89b29cc46721d9857 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 25 Apr 2024 02:00:10 +0000 Subject: [PATCH 17/18] update readme --- torchao/prototype/hqq/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 709f97ea22..71a2b8f3eb 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -9,6 +9,9 @@ The kernel fuses two ops: Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. +> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. +> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). + ### Implementation Details - Bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) From cd68d38c9347d62277edeaaae40e755cf87c0890 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 25 Apr 2024 02:04:38 +0000 Subject: [PATCH 18/18] add note about cudamode --- torchao/prototype/hqq/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md index 71a2b8f3eb..22c40fd246 100644 --- a/torchao/prototype/hqq/README.md +++ b/torchao/prototype/hqq/README.md @@ -10,7 +10,7 @@ The kernel fuses two ops: Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. > **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. -> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). +> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel. ### Implementation Details