diff --git a/.github/workflows/regression_test.yml b/.github/workflows/regression_test.yml
index d9649b7f7..0488e6d92 100644
--- a/.github/workflows/regression_test.yml
+++ b/.github/workflows/regression_test.yml
@@ -70,6 +70,7 @@ jobs:
            torch-spec: 'torch==2.5.1 --index-url https://download.pytorch.org/whl/cu121'
            gpu-arch-type: "cuda"
            gpu-arch-version: "12.1"
+
          - name: CPU 2.3
            runs-on: linux.4xlarge
            torch-spec: 'torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu'
diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index e049500e3..9e9144c60 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -8,7 +8,7 @@
     run_tests,
 )
 
-from torchao.dtypes import SemiSparseLayout
+from torchao.dtypes import Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -17,12 +17,12 @@
     int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-def get_quantization_functions(do_sparse: bool, do_int4: bool):
+def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
@@ -30,7 +30,12 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC),
     ]
     if do_int4:
-        base_functions.append(int4_weight_only(group_size=32))
+        if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
+            base_functions.append(
+                int4_weight_only(group_size=32, layout=Int4CPULayout())
+            )
+        else:
+            base_functions.append(int4_weight_only(group_size=32))
 
     if do_sparse:
         base_functions.append(
@@ -152,30 +157,28 @@ class TestAffineQuantizedBasic(TestCase):
     COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
     COMMON_DTYPES = [torch.bfloat16]
 
-    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
-    def test_flatten_unflatten(self, apply_quant, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(linear)
-        lp_tensor = ql.weight
-        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {
-            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
-        }
-        outer_size = lp_tensor.size()
-        outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(
-            tensor_data_dict, tensor_attributes, outer_size, outer_stride
-        )
-        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
-        ref = ql(*example_inputs)
-        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
-        reconstruct_res = ql(*example_inputs)
-        self.assertEqual(reconstruct_res, ref)
+    def test_flatten_unflatten(self, device, dtype):
+        apply_quant_list = get_quantization_functions(False, True, device)
+        for apply_quant in apply_quant_list:
+            linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+            ql = apply_quant(linear)
+            lp_tensor = ql.weight
+            tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+            tensor_data_dict = {
+                name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+            }
+            outer_size = lp_tensor.size()
+            outer_stride = lp_tensor.stride()
+            reconstructed = type(lp_tensor).__tensor_unflatten__(
+                tensor_data_dict, tensor_attributes, outer_size, outer_stride
+            )
+            example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+            ref = ql(*example_inputs)
+            ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+            reconstruct_res = ql(*example_inputs)
+            self.assertEqual(reconstruct_res, ref)
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
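For reference, a minimal sketch (shapes and the CPU-only setting are illustrative, and `get_quantization_functions` is assumed to be in scope) of how the device-aware helper above behaves: on CPU with PyTorch >= 2.6 the int4 entry is built with `Int4CPULayout`, otherwise the default layout is kept.

```python
import torch

# Illustrative loop over the device-aware configs; mirrors what the parametrized
# test does for a single (device, dtype) combination.
for apply_quant in get_quantization_functions(do_sparse=False, do_int4=True, device="cpu"):
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cpu")
    ql = apply_quant(linear)  # module whose weight is now a quantized tensor subclass
    _ = ql(torch.randn(32, 128, dtype=torch.bfloat16))
```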
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 663db20b7..df20c5f03 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -19,7 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
-from torchao.dtypes import TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout, Int4CPULayout
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -93,6 +93,7 @@
     is_fbcode,
     benchmark_model
 )
+from torchao.dtypes.utils import is_device
 
 logger = logging.getLogger("INFO")
 
@@ -133,7 +134,10 @@ def _int8da_int8w_api(mod):
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
-    if TORCH_VERSION_AT_LEAST_2_4:
+    if is_device(next(mod.parameters()).device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        quantize_(mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False)
+        unwrap_tensor_subclass(mod)
+    elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
@@ -935,10 +939,16 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
             self.skipTest(f"Temporarily skipping for {device}")
         if dtype != torch.bfloat16:
             self.skipTest(f"Fails for {dtype}")
+        layout_list = []
+        if device == 'cpu' and TORCH_VERSION_AT_LEAST_2_6:
+            layout_list.append(Int4CPULayout())
+        else:
+            for inner_k_tiles in [4, 2]:
+                layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
-                for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "layout": TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles)}
+                for layout in layout_list:
+                    kwargs = {"groupsize": groupsize, "layout": layout}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 4e0663eb8..78556772d 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_6,
     is_fbcode,
 )
+from torchao.dtypes.utils import is_device
 
 _SEED = 1234
 torch.manual_seed(_SEED)
@@ -102,7 +103,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .reshape_as(w)
     )
     if TORCH_VERSION_AT_LEAST_2_5:
-        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+            w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -524,8 +526,10 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         groupsize = 128
 
         if TORCH_VERSION_AT_LEAST_2_5:
-            input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
+            input_tmp = input
+            if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_tmp, scales, zeros, n_bit, groupsize)
         else:
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
         w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
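The guards above skip nibble packing on CPU because, per the assertions later in this patch, `_convert_weight_to_int4pack_for_cpu` expects unpacked `int32` data, while the torch 2.5+ CUDA path expects two int4 values packed into each `uint8`. A standalone sketch of that packing and its inverse (values and shapes are illustrative):

```python
import torch

# Pack pairs of int4 values into one uint8: even columns -> high nibble, odd columns -> low nibble.
w = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = (w[:, ::2] << 4 | w[:, 1::2]).to(torch.uint8)  # shape (4, 4)

# Unpack by reversing the shift/mask and re-interleaving the columns.
high = (packed >> 4) & 0xF
low = packed & 0xF
unpacked = torch.stack([high, low], dim=-1).reshape(w.shape).to(torch.int32)
assert torch.equal(unpacked, w)
```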
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index d1fbacdcb..00305db34 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -16,6 +16,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
     BlockSparseLayout,
+    Int4CPULayout,
     MarlinQQQLayout,
     MarlinSparseLayout,
     SemiSparseLayout,
@@ -48,4 +49,5 @@
     "UintxLayout",
     "MarlinQQQTensor",
     "MarlinQQQLayout",
+    "Int4CPULayout",
 ]
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index a6059f93a..8fba2bb67 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -11,6 +11,7 @@
     SemiSparseLayout,
 )
 from .tensor_core_tiled_layout import (
+    Int4CPULayout,
     TensorCoreTiledLayout,
 )
 from .uintx_layout import (
@@ -23,5 +24,6 @@
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
+    "Int4CPULayout",
     "MarlinQQQLayout",
 ]
diff --git a/torchao/dtypes/uintx/tensor_core_tiled_layout.py b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
index ced3fc892..df79b653e 100644
--- a/torchao/dtypes/uintx/tensor_core_tiled_layout.py
+++ b/torchao/dtypes/uintx/tensor_core_tiled_layout.py
@@ -13,7 +13,12 @@
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
 from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, fill_defaults, find_multiple
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    fill_defaults,
+    find_multiple,
+)
 
 aten = torch.ops.aten
 
@@ -71,9 +76,14 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
 
     # groupwise int4 quantization
     groupsize = weight_tensor.block_size[1]
-    y = torch.ops.aten._weight_int4pack_mm(
-        act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
-    )
+    if is_device(input_tensor.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        y = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )
+    else:
+        y = torch.ops.aten._weight_int4pack_mm(
+            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
+        )
 
     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]
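With the dispatch above in place, a CPU model quantized with the new layout routes its linears to `_weight_int4pack_mm_for_cpu`. A minimal usage sketch, assuming PyTorch >= 2.6 and a torchao build containing this patch (model and shapes are illustrative):

```python
import torch
from torchao.dtypes import Int4CPULayout
from torchao.quantization import int4_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval().to(torch.bfloat16)
quantize_(model, int4_weight_only(group_size=128, layout=Int4CPULayout()))

x = torch.randn(8, 1024, dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)  # the linear now runs through torch.ops.aten._weight_int4pack_mm_for_cpu
```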
@@ -383,3 +393,251 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
     def get_layout(self) -> Layout:
         return self._layout
+
+
+@dataclass(frozen=True)
+class Int4CPULayout(Layout):
+    """Only for PyTorch version at least 2.6"""
+
+    pass
+
+
+@register_layout(Int4CPULayout)
+class Int4CPUAQTTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for the int4 CPU layout of affine quantized tensors. It is used for int4 only,
+    by the tinygemm CPU kernel `_weight_int4pack_mm_for_cpu`.
+    It stores the original tensor of dimension [n][k] (int32 dtype) as a packed 2-d tensor of
+    dimension [n][k / 2] (uint8 dtype),
+    i.e. the unpacked tensor has n * k elements.
+    Note: scale and zero point are also packed together here for the tinygemm kernel.
+    Note: technically the int4 CPU layout should only describe the underlying packed weight
+    (int tensor), but since scale and zero_point are packed into the same tensor here, which is
+    not the case for the plain layout, we created a layout for the whole AQT for now; this could
+    be improved by splitting the int4 AQT into a separate tensor subclass.
+    fields:
+      packed_weight (torch.Tensor): the 2-d packed tensor in the int4 CPU layout
+      scale_and_zero (torch.Tensor): the combined tensor of scales and zero_points used to map between the floating point tensor and the quantized tensor
+    """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: torch.Tensor,
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        self.transposed = False
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        (
+            transposed,
+            _layout,
+        ) = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4CPULayout)
+
+        if TORCH_VERSION_AT_LEAST_2_6:
+            assert (
+                int_data.dtype == torch.int32
+            ), "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                int_data,
+                1,  # TODO:remove
+            )
+        elif TORCH_VERSION_AT_LEAST_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert (
+                int_data.dtype == torch.uint8
+            ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+        else:
+            assert (
+                int_data.dtype == torch.int32
+            ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
+
+        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        return cls(packed_weight, scale_and_zero, False, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        if not is_device(torch.device(self.device).type, device):
+            raise ValueError(
+                f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
+            )
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            """we don't need to repack the weight and just rely on external
+            shape being changed and record the status of transpose/no-transpose
+            """
+            transposed = Int4CPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        zero_point_domain = ZeroPointDomain.FLOAT
+        assert len(block_size) == 2 and block_size[0] == 1
+        dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            torch.eye(eye_shape, device=device, dtype=original_dtype),
+            self.packed_weight,
+            groupsize,
+            self.scale_and_zero,
+        )
+        dequantized = dequantized.t().contiguous()
+        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+        scale = scale.reshape(scale.shape[:-1]).contiguous()
+        zero = zero.reshape(zero.shape[:-1]).contiguous()
+        int_data = quantize_affine(
+            dequantized,
+            block_size,
+            scale,
+            zero,
+            target_dtype,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+        )
+        return int_data, scale, zero
+
+    def get_layout(self) -> Layout:
+        return self._layout
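A small sketch, assuming PyTorch >= 2.6, of the two aten ops the new TensorImpl builds on: `_convert_weight_to_int4pack_for_cpu` packs an unpacked `int32` [n, k] weight into a `uint8` [n, k/2] buffer, and `get_plain()` recovers a dequantized weight by feeding an identity activation to `_weight_int4pack_mm_for_cpu` (shapes and the random qparams are illustrative):

```python
import torch
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

n, k, group_size = 64, 128, 32
int_data = torch.randint(0, 16, (n, k), dtype=torch.int32)          # unpacked int4 values
packed = torch.ops.aten._convert_weight_to_int4pack_for_cpu(int_data, 1)
print(packed.shape, packed.dtype)                                   # expected: (n, k // 2), torch.uint8

scales = torch.rand(n, k // group_size, dtype=torch.bfloat16)
zeros = torch.rand(n, k // group_size, dtype=torch.bfloat16)
scale_and_zero = pack_tinygemm_scales_and_zeros(scales, zeros)

# eye(k) @ dequant(W).t() is just dequant(W).t(), so transposing the result back
# yields the dequantized [n, k] weight without a dedicated unpacking kernel.
dequantized = torch.ops.aten._weight_int4pack_mm_for_cpu(
    torch.eye(k, dtype=torch.bfloat16), packed, group_size, scale_and_zero
).t().contiguous()
print(dequantized.shape, dequantized.dtype)                          # expected: (n, k), torch.bfloat16
```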
diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md
index 8bf1d3426..1bdbcd96e 100644
--- a/torchao/prototype/hqq/README.md
+++ b/torchao/prototype/hqq/README.md
@@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f
 
 - Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
 - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
-- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
+- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
 
 GPU details:
 
diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py
index 8abdad039..743c6128a 100644
--- a/torchao/prototype/hqq/hqq_tinygemm_linear.py
+++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,7 +12,8 @@
 from hqq.core.utils import *
 
 import torch.nn.functional as F
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
+from torchao.dtypes.utils import is_device
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -162,9 +163,14 @@ def process_hqq_quants(self, W_q, meta):
         W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
             W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
         )
-        self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-            W_q_torch, self.inner_k_tiles
-        )
+        if is_device(W_q.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                W_q_torch, self.inner_k_tiles
+            )
+        else:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                W_q_torch, self.inner_k_tiles
+            )
         self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch)
 
         del W_q_torch, scales_torch, zeros_torch
@@ -200,7 +206,8 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+            if not is_device(W_q.device.type, "cpu"):
+                W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
@@ -232,9 +239,14 @@ def pack_scales_and_zeros(self, scales, zeros):
     def matmul(self, x):
         origin_x_size = x.size()
         x = x.reshape(-1, origin_x_size[-1])
-        c = torch.ops.aten._weight_int4pack_mm(
-            x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
-        )
+        if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
+        else:
+            c = torch.ops.aten._weight_int4pack_mm(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
         new_shape = origin_x_size[:-1] + (self.out_features,)
         c = c.reshape(new_shape)
         return c
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index dc68f59ce..c169271e8 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -17,8 +17,10 @@
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
+from torchao.dtypes.utils import is_device
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
+    TORCH_VERSION_AT_LEAST_2_6,
     find_multiple,
 )
 
@@ -537,12 +539,20 @@ def linear_forward_int4(
 ):
     origin_x_size = x.size()
     x = x.reshape(-1, origin_x_size[-1])
-    c = torch.ops.aten._weight_int4pack_mm(
-        x.to(precision),
-        weight_int4pack,
-        groupsize,
-        scales_and_zeros.to(scales_precision),
-    ).to(dtype=x.dtype)
+    if is_device(x.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+        c = torch.ops.aten._weight_int4pack_mm_for_cpu(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
+    else:
+        c = torch.ops.aten._weight_int4pack_mm(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros.to(scales_precision),
+        ).to(dtype=x.dtype)
     new_shape = origin_x_size[:-1] + (out_features,)
     c = c.reshape(new_shape)
     return c
@@ -591,19 +601,32 @@ def __init__(
         assert (
             in_features % (inner_k_tiles * 16) == 0
         ), "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.zeros(
-                (
-                    out_features // 8,
-                    in_features // (inner_k_tiles * 16),
-                    32,
-                    inner_k_tiles // 2,
+        if is_device(device.type, "cpu"):
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features,
+                        in_features // 2,
+                    ),
+                    dtype=torch.uint8,
+                    device=device,
                 ),
-                dtype=torch.int32,
-                device=device,
-            ),
-        )
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.zeros(
+                    (
+                        out_features // 8,
+                        in_features // (inner_k_tiles * 16),
+                        32,
+                        inner_k_tiles // 2,
+                    ),
+                    dtype=torch.int32,
+                    device=device,
+                ),
+            )
         self.dtype = dtype
         self.register_buffer(
             "scales_and_zeros",
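For clarity, the two packed-weight buffer shapes registered above, side by side (values are illustrative; the CPU branch stores two int4 values per `uint8` byte, while the CUDA branch keeps the tensor-core tiled `int32` format):

```python
import torch

out_features, in_features, inner_k_tiles = 256, 128, 8

# CPU (PyTorch >= 2.6): plain 2-d uint8 buffer, [out_features, in_features // 2]
cpu_weight = torch.zeros((out_features, in_features // 2), dtype=torch.uint8)

# CUDA tensor-core tiled layout: 4-d int32 buffer
cuda_weight = torch.zeros(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)
print(cpu_weight.shape, cuda_weight.shape)  # (256, 64) and (32, 1, 32, 4)
```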
@@ -760,9 +783,19 @@ def _create_quantized_state_dict(
                     self.precision,  # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
-                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                    w_int4x8.to(self.device), self.inner_k_tiles
-                )
+                if (
+                    is_device(w_int4x8.device.type, "cpu")
+                    and TORCH_VERSION_AT_LEAST_2_6
+                ):
+                    weight_int4pack = (
+                        torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                            w_int4x8.to(self.device), self.inner_k_tiles
+                        )
+                    )
+                else:
+                    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                        w_int4x8.to(self.device), self.inner_k_tiles
+                    )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                     self.device
@@ -846,9 +879,14 @@ def make_names_and_values_dict_func(q, qparams):
             # how much we need to pad the weight
             delta_k = int((new_k - k) / 2)
             q = q.to(self.device)
-            final_q = torch.ops.aten._convert_weight_to_int4pack(
-                F.pad(q, pad=(0, delta_k)), inner_k_tiles
-            )
+            if is_device(self.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+                final_q = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
+            else:
+                final_q = torch.ops.aten._convert_weight_to_int4pack(
+                    F.pad(q, pad=(0, delta_k)), inner_k_tiles
+                )
             scales = qparams[0].to(torch.bfloat16).to(self.device)
             zeros = qparams[1].to(torch.bfloat16).to(self.device)
             scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index cbe629640..d5f2dca5b 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -9,6 +9,7 @@
 import torch
 import torch.nn.functional as F
 
+from torchao.dtypes.utils import is_device
 from torchao.quantization.GPTQ import (
     Int8DynActInt4WeightLinear,
     WeightOnlyInt4Linear,
@@ -23,6 +24,7 @@
 )
 from torchao.quantization.unified import TwoStepQuantizer
 from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
 
 from .api import FakeQuantizeConfig
 from .fake_quantizer import FakeQuantizer
@@ -363,6 +365,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                 inner_k_tiles=inner_k_tiles,
                 precision=child.weight.dtype,
                 scales_precision=config.scale_precision,
+                device=next(child.parameters()).device,
             )
             setattr(module, name, quantized_linear)
 
@@ -373,10 +376,19 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
                 n_bit,
                 config.group_size,
             )
-            q_weight = torch.ops.aten._convert_weight_to_int4pack(
-                q_weight.to(child.weight.device),
-                child.inner_k_tiles,
-            )
+            if (
+                is_device(q_weight.device.type, "cpu")
+                and TORCH_VERSION_AT_LEAST_2_6
+            ):
+                q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
+                    q_weight.to(child.weight.device),
+                    child.inner_k_tiles,
+                )
+            else:
+                q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                    q_weight.to(child.weight.device),
+                    child.inner_k_tiles,
+                )
             quantized_linear.weight = q_weight
             quantized_linear.scales_and_zeros = scales_and_zeros
         else:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index c730ec904..ddeb4ef2f 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -630,7 +630,8 @@ def int4_weight_only(
         "tensor_core_tiled" layout for speedup with tinygemm kernel
 
     Note:
-        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`), the main difference
+        This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`
+        and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference
         of quantization algorithm compared to the more traditional type of integer quantization is the following:
         1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`)
         2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 036109bc8..9715d99e0 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -8,6 +8,7 @@
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
+from torchao.dtypes.utils import is_device
 from torchao.quantization.utils import (
     dequantize_per_channel,
     dynamically_quantize_per_channel,
@@ -15,7 +16,7 @@
     quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
 )
-from torchao.utils import find_multiple
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple
 
 __all__ = [
     "Int8DynamicallyQuantizedLinearWeight",
@@ -458,12 +459,20 @@ def _quantized_op(act_mat, w_qtensor, bias):
             act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))
 
         # matmul
-        y = aten._weight_int4pack_mm(
-            act_mat.contiguous(),
-            w_qtensor.int_data,
-            w_qtensor.groupsize,
-            w_qtensor.scales_and_zeros,
-        )
+        if is_device(act_mat.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            y = aten._weight_int4pack_mm_for_cpu(
+                act_mat.contiguous(),
+                w_qtensor.int_data,
+                w_qtensor.groupsize,
+                w_qtensor.scales_and_zeros,
+            )
+        else:
+            y = aten._weight_int4pack_mm(
+                act_mat.contiguous(),
+                w_qtensor.int_data,
+                w_qtensor.groupsize,
+                w_qtensor.scales_and_zeros,
+            )
 
         # remove out_feature padding
         orig_out_features = (
@@ -609,5 +618,10 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         input_int4x8, scales_and_zeros = groupwise_affine_quantize_tensor(
             input_float, 4, groupsize, dtype=input_float.dtype
         )
-        int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
+        if is_device(input_float.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6:
+            int_data = aten._convert_weight_to_int4pack_for_cpu(
+                input_int4x8, inner_k_tiles
+            )
+        else:
+            int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 9083dd762..e1cf98b54 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -9,6 +9,7 @@
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
 
+from torchao.dtypes.utils import is_device
 from torchao.kernel import (
     int_scaled_matmul,
 )
@@ -19,7 +20,7 @@
     dequantize_affine,
     quantize_affine,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
 
 __all__ = [
     "compute_error",
@@ -402,13 +403,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         zero_point_domain=ZeroPointDomain.FLOAT,
     )
     if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
-        int_data_device_type = int_data.device.type
-        # Move to cpu, until issue with MPS memory management of temporary tensors is resolved
-        if int_data_device_type == "mps":
-            int_data = int_data.cpu()
-        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-        if int_data_device_type == "mps":
-            int_data = int_data.to(device="mps")
+        if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
 
 
@@ -422,8 +418,10 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     assert groupsize > 1
     assert w_int4x8.dim() == 2
     # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
-    if TORCH_VERSION_AT_LEAST_2_5 and (
-        w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
+    if (
+        TORCH_VERSION_AT_LEAST_2_5
+        and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
+        and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
     ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
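A round-trip sketch of the two helpers changed above, assuming a CPU tensor and PyTorch >= 2.6 so the new branches are taken; with this patch the quantized tensor stays unpacked on CPU and the dequantize path consumes it directly (shapes and group size are illustrative):

```python
import torch
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_dequantize_tensor_from_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
)

w = torch.randn(256, 128, dtype=torch.bfloat16)
scales, zeros = get_groupwise_affine_qparams(w, 4, 64)                  # n_bit=4, groupsize=64
w_q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, 4, 64)
w_dq = groupwise_affine_dequantize_tensor_from_qparams(w_q, scales, zeros, 4, 64)
# On CPU with torch >= 2.6, w_q keeps one value per element; on other devices
# (or older torch) it is nibble-packed into uint8 with half as many columns.
print(w_q.dtype, w_q.shape, w_dq.shape)
```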