From dc8a0a1988f724f292cca0742afb2bc831b32cc7 Mon Sep 17 00:00:00 2001 From: youn17 Date: Fri, 24 Oct 2025 21:29:05 +0900 Subject: [PATCH 01/19] introduce new int8 quantization API --- docs/source/quantization_overview.rst | 4 +- .../workflows/int8/test_int8_tensor.py | 266 ++++++++++++++++ torchao/quantization/__init__.py | 2 + torchao/quantization/quant_api.py | 98 ++++-- .../common/quantize_tensor_kwargs.py | 8 + .../quantize_/workflows/__init__.py | 6 + .../quantize_/workflows/int8/int8_tensor.py | 300 ++++++++++++++++++ 7 files changed, 651 insertions(+), 33 deletions(-) create mode 100644 test/quantization/quantize_/workflows/int8/test_int8_tensor.py create mode 100644 torchao/quantization/quantize_/workflows/int8/int8_tensor.py diff --git a/docs/source/quantization_overview.rst b/docs/source/quantization_overview.rst index f5c82bfe5f..df0a924b11 100644 --- a/docs/source/quantization_overview.rst +++ b/docs/source/quantization_overview.rst @@ -5,7 +5,7 @@ First we want to lay out the torchao stack:: Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc. --------------------------------------------------------------------------------------------- - Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor + Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor --------------------------------------------------------------------------------------------- Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize --------------------------------------------------------------------------------------------- @@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma - scaled int4 - preshuffled (special format to optimize for loading) - float8 act + int4 weight dynamic quantization and int4 weight only quantization + * - Int8Tensor + - plain .. note:: We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options. diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py new file mode 100644 index 0000000000..0fc959beb7 --- /dev/null +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +from torch.testing._internal import common_utils + +from torchao.quantization import ( + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + PerRow, + PerTensor, + quantize_, +) +from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine +from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) +from torchao.quantization.utils import compute_error +from torchao.testing.utils import TorchAOIntegrationTestCase + + +# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged +class ToyTwoLinearModel(torch.nn.Module): + def __init__( + self, + input_dim, + hidden_dim, + output_dim, + has_bias=False, + dtype=None, + device=None, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.linear1 = torch.nn.Linear( + input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device + ) + self.linear2 = torch.nn.Linear( + hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@common_utils.instantiate_parametrized_tests +class TestInt8Tensor(TorchAOIntegrationTestCase): + def setUp(self): + super().setUp() + + self.test_shape = (4, 3) + self.dtype = torch.bfloat16 + self.batch_size = 32 + self.int8_min = -128 + self.int8_max = 127 + + torch.manual_seed(42) + self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype) + self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype) + self.bias = torch.randn(self.test_shape[0], dtype=self.dtype) + self.block_size = list(self.test_shape) + + def test_creation_and_attributes(self): + """Test tensor creation, dtypes, and ranges""" + tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) + + self.assertEqual(tensor.shape, self.test_shape) + self.assertEqual(tensor.qdata.dtype, torch.int8) + self.assertTrue( + torch.all(tensor.qdata >= self.int8_min) + and torch.all(tensor.qdata <= self.int8_max) + ) + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), + ], + ) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_int8_linear_quantization_accuracy( + self, + dtype: torch.dtype, + sizes: tuple, + config, + ): + """Test quantization preserves reasonable accuracy""" + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + + # Create a linear layer + m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") + m_q = copy.deepcopy(m) + + # Quantize + quantize_(m_q, config) + + output_original = m(input_tensor) + output_quantized = m_q(input_tensor) + + error = compute_error(output_original, output_quantized) + assert error > 20, ( + f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)" + ) + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_quantization_shapes(self, dtype): + """Test static and dynamic quantization output shapes""" + K, N = 128, 64 + weight = torch.randn(N, K, dtype=dtype, device="cuda") + input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda") + + # Dynamic quantization (runtime scale computation) + dynamic_tensor = Int8Tensor.from_hp(weight, block_size=[N, K]) + + # Static quantization (pre-computed scale) + act_scale, _ = choose_qparams_affine( + input=input_tensor, + mapping_type=MappingType.SYMMETRIC, + block_size=(input_tensor.shape[0], K), + target_dtype=torch.int8, + quant_min=self.int8_min, + quant_max=self.int8_max, + scale_dtype=dtype, + zero_point_dtype=torch.int8, + ) + + # Static quantization (with pre-computed scale) + static_tensor = Int8Tensor.from_hp( + weight, + block_size=[N, K], + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + block_size=[input_tensor.shape[0], K], + static_scale=act_scale, + ), + ) + + dynamic_output = torch.nn.functional.linear(input_tensor, dynamic_tensor) + static_output = torch.nn.functional.linear(input_tensor, static_tensor) + + expected_shape = (self.batch_size, N) + self.assertEqual(dynamic_output.shape, expected_shape) + self.assertEqual(static_output.shape, expected_shape) + self.assertEqual(dynamic_output.dtype, dtype) + self.assertEqual(static_output.dtype, dtype) + + @unittest.skip("granularity parameter not supported in current API") + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_slice_preserves_aliasing(self, granularity): + slice_size = 512 + tensor_size = 1024 + + config = Int8DynamicActivationInt8WeightConfig( + granularity=granularity, version=2 + ) + l = torch.nn.Linear(tensor_size, tensor_size).to("cuda").to(torch.bfloat16) + l.weight = torch.nn.Parameter( + torch.zeros(tensor_size, tensor_size, dtype=torch.bfloat16, device="cuda") + ) + quantize_(l, config) + param = l.weight + param_data = param.data + param_data = param_data.narrow(0, 0, slice_size) + # Making sure the aliasing is preserved in sliced quantized Tensor + assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() + assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + @common_utils.parametrize("device", ["cpu", "cuda"]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_slice(self, config, device, dtype): + """Test tensor slicing""" + tensor_size = 256 + slice_sizes = (64, 128) + + dummy = torch.nn.Linear( + tensor_size, tensor_size, bias=False, dtype=dtype, device=device + ) + quantize_(dummy, config) + + weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0]) + weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1]) + + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) + + # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow) + # Int8WeightOnlyConfig uses per-tensor (PerTensor) + if isinstance(config, Int8DynamicActivationInt8WeightConfig): + # PerRow: dim 0 slicing affects scale, dim 1 doesn't + self.assertEqual( + weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]) + ) + self.assertEqual(weight2.scale, dummy.weight.scale) + else: + # PerTensor: scale unchanged by slicing + self.assertEqual(weight1.scale, dummy.weight.scale) + self.assertEqual(weight2.scale, dummy.weight.scale) + with self.assertRaises(NotImplementedError): + _ = dummy.weight[::2] + + def test_index_select(self): + """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`.""" + N, K = 256, 512 + x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) + x_int8 = Int8Tensor.from_hp(x, block_size=[N, K]) + x_int8_0 = x_int8[0] + torch.testing.assert_close( + x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 + ) + + def test_invalid_input_handling(self): + """Test input validation with specific error types""" + invalid_tensor = torch.randn(5) + incompatible_block_size = [1] + + with self.assertRaises( + ValueError, msg="Should reject incompatible tensor dimensions" + ): + Int8Tensor.from_hp(invalid_tensor, incompatible_block_size) + + with self.assertRaises( + ValueError, msg="Should reject mismatched block size dimensions" + ): + Int8Tensor.from_hp(self.weight_fp, [1]) + + def test_dequantization_accuracy(self): + """Test dequantization accuracy separately""" + test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16) + tensor = Int8Tensor.from_hp(test_data, [1, 2]) + + dequantized = tensor.dequantize() + self.assertEqual(dequantized.shape, test_data.shape) + self.assertLess( + torch.abs(dequantized - test_data).max().item(), + 0.1, + msg=f"Dequantization error exceeds tolerance of {0.1}", + ) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 577ac40721..77de6732f7 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -100,6 +100,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxOpaqueTensor, IntxUnpackedToInt8Tensor, ) @@ -173,6 +174,7 @@ "IntxOpaqueTensor", "IntxUnpackedToInt8Tensor", "Int4TilePackedTo4dTensor", + "Int8Tensor", "Float8Tensor", "Int4OpaqueTensor", "Float8OpaqueTensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1e176f9e9b..f8939b19fa 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -85,6 +85,7 @@ Int4PreshuffledTensor, Int4Tensor, Int4TilePackedTo4dTensor, + Int8Tensor, IntxChooseQParamsAlgorithm, IntxOpaqueTensor, IntxPackingFormat, @@ -1346,6 +1347,7 @@ class Int8WeightOnlyConfig(AOBaseConfig): group_size: Optional[int] = None set_inductor_config: bool = True + version: int = 2 def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") @@ -1356,22 +1358,30 @@ def __post_init__(self): def _int8_weight_only_quantize_tensor(weight, config): - mapping_type = MappingType.SYMMETRIC - target_dtype = torch.int8 - eps = torch.finfo(torch.float32).eps - zero_point_dtype = torch.int64 - group_size = config.group_size - if group_size is None: - group_size = weight.shape[-1] - block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - ) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + group_size = config.group_size + if group_size is None: + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + ) + else: + assert config.version == 2, f"Unexpected version: {config.version}" + block_size = [weight.shape[0], weight.shape[1]] + new_weight = Int8Tensor.from_hp(weight, block_size=block_size) return new_weight @@ -1517,12 +1527,15 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): in original precision during decode operations. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. + version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Int8Tensor """ layout: Optional[Layout] = PlainLayout() act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC weight_only_decode: bool = False + granularity: Optional[Union[PerRow, PerTensor]] = PerRow() set_inductor_config: bool = True + version: int = 2 def __post_init__(self): torch._C._log_api_usage_once( @@ -1554,9 +1567,6 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC weight_zero_point_domain = ZeroPointDomain.NONE - def get_weight_block_size(x): - return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]]) - target_dtype = torch.int8 eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 @@ -1570,19 +1580,43 @@ def get_weight_block_size(x): else: input_quant_func = _int8_asymm_per_token_quant - block_size = get_weight_block_size(weight) - new_weight = to_affine_quantized_intx( - weight, - mapping_type, - block_size, - target_dtype, - eps=eps, - zero_point_dtype=zero_point_dtype, - _layout=layout, - zero_point_domain=weight_zero_point_domain, - ) - new_weight = to_linear_activation_quantized(new_weight, input_quant_func) - return new_weight + if isinstance(config.granularity, PerTensor): + # Tensor granularity + block_size = weight.shape + else: + # Per row granularity + block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) + + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + quantized_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + quantized_weight = to_linear_activation_quantized( + quantized_weight, input_quant_func + ) + else: + from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + QuantizeTensorToInt8Kwargs, + ) + + assert config.version == 2, f"Unexpected version: {config.version}" + quantized_weight = Int8Tensor.from_hp( + weight, + block_size, + act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size), + ) + + return quantized_weight @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 0adc8c786d..44dd09ff62 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, + Int8Tensor, QuantizeTensorToFloat8Kwargs, + QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -52,5 +54,11 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.hp_value_ub, quant_kwargs.kernel_preference, ) + elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): + return Int8Tensor.from_hp( + tensor, + quant_kwargs.block_size, + act_quant_kwargs=quant_kwargs, + ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index e379327689..e1553adeb6 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -24,6 +24,10 @@ Int4Tensor, ) from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor +from .int8.int8_tensor import ( + Int8Tensor, + QuantizeTensorToInt8Kwargs, +) from .intx.intx_choose_qparams_algorithm import IntxChooseQParamsAlgorithm from .intx.intx_opaque_tensor import ( IntxOpaqueTensor, @@ -42,6 +46,8 @@ "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", "Float8OpaqueTensor", + "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Float8Tensor", "Float8PackingFormat", "QuantizeTensorToFloat8Kwargs", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py new file mode 100644 index 0000000000..eee493b3e3 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -0,0 +1,300 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import Optional + +import torch +from torch.utils._python_dispatch import return_and_correct_aliasing + +from torchao.float8.inference import _slice_scale_for_dimension +from torchao.quantization.quant_primitives import ( + MappingType, + _maybe_expand_scale_to_tensor_shape, + choose_qparams_affine, + quantize_affine, +) +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) +from torchao.utils import TorchAOBaseTensor, fill_defaults + +__all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] + +aten = torch.ops.aten + + +@dataclass +class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): + """Tensor kwargs for creating int8 tensor (either activation or weight) + + Args: + block_size (list[int]): block size for quantization granularity + static_scale (Optional[torch.Tensor]): pre-computed scale for static quantization + """ + + block_size: list[int] + static_scale: Optional[torch.Tensor] = None + + +class Int8Tensor(TorchAOBaseTensor): + """ + int8 quantized tensor with plain layout + + Tensor Attributes: + qdata: (N, K) int8 quantized weight data + scale: scale factors for dequantization + + Non-Tensor Attributes: + block_size: block size for quantization granularity + act_quant_kwargs: flags for static/dynamic activation quantization + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["block_size"] + optional_tensor_attribute_names = [ + "act_quant_kwargs", + "dtype", + ] + + def __new__( + cls: type, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + act_quant_kwargs=None, + dtype=None, + ): + kwargs = { + "device": qdata.device, + "dtype": dtype or scale.dtype, + "requires_grad": False, + } + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs) + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: list[int], + act_quant_kwargs=None, + dtype=None, + ): + super().__init__() + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.act_quant_kwargs = act_quant_kwargs + + def __repr__(self): + return ( + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " + f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + ) + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: list[int], + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + ): + if w.dim() != 2 or len(block_size) != 2: + raise ValueError("Expected 2D tensor and block_size length 2") + + if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None: + # INT8 × INT8 (static) + scale = act_quant_kwargs.static_scale + zero_point = torch.zeros_like(scale, dtype=torch.int8) + else: + # INT8 × INT8 (dynamic): compute scale at runtime + scale, zero_point = choose_qparams_affine( + input=w, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w.dtype, + zero_point_dtype=torch.int8, + ) + + int_data = quantize_affine( + w, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) + + if tuple(block_size) == w.shape: + # per-tensor + pass + elif len(scale.shape) == 1: + # per-row, 1D -> 2D + scale = scale.unsqueeze(-1) + + return cls( + int_data, + scale, + block_size, + act_quant_kwargs=act_quant_kwargs, + dtype=w.dtype, + ) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Dequantize int8 tensor to floating point""" + + if output_dtype is None: + output_dtype = self.dtype + + qdata_fp = self.qdata.to(output_dtype) + # Reshape scale to broadcast if granularity is block-wise + scale_expanded = _maybe_expand_scale_to_tensor_shape( + self.scale, self.qdata.shape + ) + return qdata_fp * scale_expanded.to(output_dtype) + + +implements = Int8Tensor.implements +implements_torch_function = Int8Tensor.implements_torch_function + + +@implements(aten.linear.default) +@implements_torch_function(torch.nn.functional.linear) +def _(func, types, args, kwargs): + """quantization: dynamic, static, weight-only int8 quantization""" + activation_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + + assert isinstance(weight_tensor, Int8Tensor), ( + f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" + ) + + if weight_tensor.act_quant_kwargs is not None: + if not isinstance(activation_tensor, Int8Tensor): + # Activation quantization + activation_tensor = _choose_quant_func_and_quantize_tensor( + activation_tensor, weight_tensor.act_quant_kwargs + ) + + x_vals = activation_tensor.qdata + x_scales = activation_tensor.scale + w_vals_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + + tmp_shape = (-1, x_vals.shape[-1]) + tmp = x_vals.view(tmp_shape) + + # Cast fp16 scale to float + intermediate_dtype = ( + torch.float if x_scales.dtype == torch.half else x_scales.dtype + ) + # Note: CUDA doesn't support int32/int64 matmul, so we convert to float + # Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int' + # This may introduce minor numerical differences compared to int arithmetic + y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype)) + + # Apply activation scale + is_per_tensor_act = x_scales.numel() == 1 + if is_per_tensor_act: + y_dot.mul_(x_scales.to(intermediate_dtype)) + else: + # For block-wise activation scale, reshape to match y_dot + x_scales_reshaped = x_scales.view(y_dot.shape[0], -1) + y_dot.mul_(x_scales_reshaped.to(intermediate_dtype)) + + # Apply weight scale + is_per_tensor_weight = w_scales.numel() == 1 + if is_per_tensor_weight: + result = y_dot.mul_(w_scales.to(intermediate_dtype)) + else: + # Per-row weight scale - transpose and broadcast + w_scales_broadcast = w_scales.t().expand_as(y_dot) + result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype)) + + # Reshape back to original shape + result = result.view(*x_vals.shape[:-1], result.shape[-1]) + result = result.to(activation_tensor.dtype) + else: + # FP × INT8 (weight-only) + result = func( + activation_tensor, weight_tensor.dequantize(activation_tensor.dtype), None + ) + + return result + bias if bias is not None else result + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + """Slice operation for Int8Tensor""" + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + + if step != 1: + raise NotImplementedError("Slicing with step > 1 is not supported") + + if end >= self.shape[dim]: + end = self.shape[dim] + + sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) + + if self.scale.numel() == 1: + # Per-tensor quantization - scale doesn't change + sliced_scale = self.scale + elif dim < self.scale.ndim and self.scale.shape[dim] > 1: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = aten.slice.Tensor(self.scale, dim, start, end, step) + else: + # Block-wise quantization - need to slice the scale appropriately + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) + + block_size = list(self.block_size) + for i in range(len(block_size)): + block_size[i] = min(block_size[i], sliced_qdata.shape[i]) + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + sliced_qdata, + sliced_scale, + block_size, + self.act_quant_kwargs, + dtype=self.dtype, + ), + ) + + +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = args + assert dim == 0, f"Only dim=0 supported, got {dim}" + + selected_scale = self.scale if self.scale.ndim == 0 else self.scale[index] + + return return_and_correct_aliasing( + func, + args, + kwargs, + Int8Tensor( + self.qdata[index], + selected_scale, + self.block_size, + self.act_quant_kwargs, + self.dtype, + ), + ) + + +Int8Tensor.__module__ = "torchao.quantization" +torch.serialization.add_safe_globals([Int8Tensor, QuantizeTensorToInt8Kwargs]) From 5d0431884c7680656d1fed788e90b30135c54a60 Mon Sep 17 00:00:00 2001 From: younn17 Date: Sun, 26 Oct 2025 14:59:38 +0900 Subject: [PATCH 02/19] refactor ops and update test cases --- .../workflows/int8/test_int8_tensor.py | 95 +++++++------------ .../quantize_/workflows/int8/int8_tensor.py | 94 ++++++++---------- 2 files changed, 70 insertions(+), 119 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 0fc959beb7..6e225ee108 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -13,14 +13,10 @@ from torchao.quantization import ( Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, - PerRow, - PerTensor, quantize_, ) -from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( Int8Tensor, - QuantizeTensorToInt8Kwargs, ) from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @@ -62,8 +58,6 @@ def setUp(self): self.test_shape = (4, 3) self.dtype = torch.bfloat16 self.batch_size = 32 - self.int8_min = -128 - self.int8_max = 127 torch.manual_seed(42) self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype) @@ -78,8 +72,7 @@ def test_creation_and_attributes(self): self.assertEqual(tensor.shape, self.test_shape) self.assertEqual(tensor.qdata.dtype, torch.int8) self.assertTrue( - torch.all(tensor.qdata >= self.int8_min) - and torch.all(tensor.qdata <= self.int8_max) + torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) @@ -122,66 +115,42 @@ def test_int8_linear_quantization_accuracy( ) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_quantization_shapes(self, dtype): - """Test static and dynamic quantization output shapes""" - K, N = 128, 64 - weight = torch.randn(N, K, dtype=dtype, device="cuda") - input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda") - - # Dynamic quantization (runtime scale computation) - dynamic_tensor = Int8Tensor.from_hp(weight, block_size=[N, K]) - - # Static quantization (pre-computed scale) - act_scale, _ = choose_qparams_affine( - input=input_tensor, - mapping_type=MappingType.SYMMETRIC, - block_size=(input_tensor.shape[0], K), - target_dtype=torch.int8, - quant_min=self.int8_min, - quant_max=self.int8_max, - scale_dtype=dtype, - zero_point_dtype=torch.int8, - ) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_per_row_scale_shape(self, dtype, config): + """Test per-row quantization maintains 1D scale""" + N, K = 64, 128 + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + quantize_(linear, config) - # Static quantization (with pre-computed scale) - static_tensor = Int8Tensor.from_hp( - weight, - block_size=[N, K], - act_quant_kwargs=QuantizeTensorToInt8Kwargs( - block_size=[input_tensor.shape[0], K], - static_scale=act_scale, - ), - ) + # Dynamic: per-row (1D scale [N]), Weight-only: per-tensor (scalar) + if isinstance(config, Int8DynamicActivationInt8WeightConfig): + self.assertEqual(linear.weight.scale.shape, (N,)) + self.assertEqual(linear.weight.scale.ndim, 1) + else: + self.assertEqual(linear.weight.scale.numel(), 1) - dynamic_output = torch.nn.functional.linear(input_tensor, dynamic_tensor) - static_output = torch.nn.functional.linear(input_tensor, static_tensor) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + @common_utils.parametrize("has_bias", [True, False]) + def test_weight_only_linear_with_bias(self, dtype, has_bias): + """Test weight-only quantization with and without bias""" + K, N = 128, 64 + linear = torch.nn.Linear(K, N, bias=has_bias, dtype=dtype, device="cuda") + input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda") - expected_shape = (self.batch_size, N) - self.assertEqual(dynamic_output.shape, expected_shape) - self.assertEqual(static_output.shape, expected_shape) - self.assertEqual(dynamic_output.dtype, dtype) - self.assertEqual(static_output.dtype, dtype) + output_fp = linear(input_tensor) - @unittest.skip("granularity parameter not supported in current API") - @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) - def test_slice_preserves_aliasing(self, granularity): - slice_size = 512 - tensor_size = 1024 + quantize_(linear, Int8WeightOnlyConfig(version=2)) + output_q = linear(input_tensor) - config = Int8DynamicActivationInt8WeightConfig( - granularity=granularity, version=2 - ) - l = torch.nn.Linear(tensor_size, tensor_size).to("cuda").to(torch.bfloat16) - l.weight = torch.nn.Parameter( - torch.zeros(tensor_size, tensor_size, dtype=torch.bfloat16, device="cuda") - ) - quantize_(l, config) - param = l.weight - param_data = param.data - param_data = param_data.narrow(0, 0, slice_size) - # Making sure the aliasing is preserved in sliced quantized Tensor - assert param.data.qdata.data_ptr() == param_data.qdata.data_ptr() - assert param.data.scale.data_ptr() == param_data.scale.data_ptr() + self.assertEqual(output_q.shape, output_fp.shape) + error = compute_error(output_fp, output_q) + self.assertGreater(error, 20) @common_utils.parametrize( "config", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index eee493b3e3..ded7020263 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -11,6 +11,7 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.float8.inference import _slice_scale_for_dimension +from torchao.kernel import int_scaled_matmul from torchao.quantization.quant_primitives import ( MappingType, _maybe_expand_scale_to_tensor_shape, @@ -34,11 +35,10 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: block_size (list[int]): block size for quantization granularity - static_scale (Optional[torch.Tensor]): pre-computed scale for static quantization + # TODO: Static quantization support using `static_scale`, `static_zero_point` """ block_size: list[int] - static_scale: Optional[torch.Tensor] = None class Int8Tensor(TorchAOBaseTensor): @@ -51,7 +51,7 @@ class Int8Tensor(TorchAOBaseTensor): Non-Tensor Attributes: block_size: block size for quantization granularity - act_quant_kwargs: flags for static/dynamic activation quantization + act_quant_kwargs: flags for dynamic activation quantization """ tensor_data_names = ["qdata", "scale"] @@ -106,22 +106,16 @@ def from_hp( if w.dim() != 2 or len(block_size) != 2: raise ValueError("Expected 2D tensor and block_size length 2") - if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None: - # INT8 × INT8 (static) - scale = act_quant_kwargs.static_scale - zero_point = torch.zeros_like(scale, dtype=torch.int8) - else: - # INT8 × INT8 (dynamic): compute scale at runtime - scale, zero_point = choose_qparams_affine( - input=w, - mapping_type=MappingType.SYMMETRIC, - block_size=block_size, - target_dtype=torch.int8, - quant_min=-128, - quant_max=127, - scale_dtype=w.dtype, - zero_point_dtype=torch.int8, - ) + scale, zero_point = choose_qparams_affine( + input=w, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w.dtype, + zero_point_dtype=torch.int8, + ) int_data = quantize_affine( w, @@ -131,12 +125,9 @@ def from_hp( output_dtype=torch.int8, ) - if tuple(block_size) == w.shape: - # per-tensor + if tuple(block_size) != w.shape and len(scale.shape) == 1: + # per-row pass - elif len(scale.shape) == 1: - # per-row, 1D -> 2D - scale = scale.unsqueeze(-1) return cls( int_data, @@ -167,7 +158,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): - """quantization: dynamic, static, weight-only int8 quantization""" + """quantization: dynamic, weight-only int8 quantization""" activation_tensor, weight_tensor, bias = ( args[0], args[1], @@ -190,44 +181,29 @@ def _(func, types, args, kwargs): w_vals_t = weight_tensor.qdata.contiguous().t() w_scales = weight_tensor.scale - tmp_shape = (-1, x_vals.shape[-1]) - tmp = x_vals.view(tmp_shape) - - # Cast fp16 scale to float + tmp = x_vals.reshape(-1, x_vals.shape[-1]) intermediate_dtype = ( torch.float if x_scales.dtype == torch.half else x_scales.dtype ) - # Note: CUDA doesn't support int32/int64 matmul, so we convert to float - # Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int' - # This may introduce minor numerical differences compared to int arithmetic - y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype)) - - # Apply activation scale - is_per_tensor_act = x_scales.numel() == 1 - if is_per_tensor_act: - y_dot.mul_(x_scales.to(intermediate_dtype)) - else: - # For block-wise activation scale, reshape to match y_dot - x_scales_reshaped = x_scales.view(y_dot.shape[0], -1) - y_dot.mul_(x_scales_reshaped.to(intermediate_dtype)) - - # Apply weight scale - is_per_tensor_weight = w_scales.numel() == 1 - if is_per_tensor_weight: - result = y_dot.mul_(w_scales.to(intermediate_dtype)) - else: - # Per-row weight scale - transpose and broadcast - w_scales_broadcast = w_scales.t().expand_as(y_dot) - result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype)) - # Reshape back to original shape - result = result.view(*x_vals.shape[:-1], result.shape[-1]) + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) + y_dot_scaled = y_dot_scaled.to(x_scales.dtype) + + result = (y_dot_scaled * w_scales).reshape( + *x_vals.shape[:-1], y_dot_scaled.shape[-1] + ) result = result.to(activation_tensor.dtype) else: # FP × INT8 (weight-only) - result = func( - activation_tensor, weight_tensor.dequantize(activation_tensor.dtype), None + w_vals_int8_t = weight_tensor.qdata.t() + m = torch.mm( + activation_tensor.reshape(-1, activation_tensor.shape[-1]), + w_vals_int8_t.to(activation_tensor.dtype), ) + result = m * weight_tensor.scale.to(m.dtype) + result = result.reshape(*activation_tensor.shape[:-1], result.shape[-1]) return result + bias if bias is not None else result @@ -248,11 +224,17 @@ def _(func, types, args, kwargs): if self.scale.numel() == 1: # Per-tensor quantization - scale doesn't change sliced_scale = self.scale + elif self.scale.ndim == 1: + # Per-row: 1D scale - only slice if dim=0 + if dim == 0: + sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step) + else: + sliced_scale = self.scale elif dim < self.scale.ndim and self.scale.shape[dim] > 1: # Block-wise quantization - need to slice the scale appropriately sliced_scale = aten.slice.Tensor(self.scale, dim, start, end, step) else: - # Block-wise quantization - need to slice the scale appropriately + # Block-wise quantization with different dimensions sliced_scale = _slice_scale_for_dimension( self.scale, self.qdata.shape, dim, start, end, step ) From 8c7afcba19f521cb54f7cc43c35390dda2f6c35b Mon Sep 17 00:00:00 2001 From: younn17 Date: Sun, 26 Oct 2025 16:22:20 +0900 Subject: [PATCH 03/19] update granularity slicing support --- torchao/float8/inference.py | 25 ++++++++++++----- .../quantize_/workflows/int8/int8_tensor.py | 27 +++---------------- 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 212df9c5db..42edfa6afa 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -142,7 +142,18 @@ def _slice_scale_for_dimension( """ aten = torch.ops.aten - # Unsupported case for now, this would be 1 scale per data element + # Per-tensor quantization (scalar scale) + if scale.numel() == 1: + return scale + + # Per-row quantization (1D scale) + if scale.ndim == 1: + if dim == 0: + return aten.slice.Tensor(scale, 0, start, end, step) + else: + return scale + + # Block-wise quantization (2D scale) if scale.shape == data_shape: return aten.slice.Tensor(scale, dim, start, end, step) @@ -160,6 +171,12 @@ def _slice_scale_for_dimension( # Slice away as normal return aten.slice.Tensor(scale, dim, start, end, step) else: + # Error on Step > 1 + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." + ) + # There is blocking in this dimension # Calculate which scale elements correspond to the sliced data scale_start = start // block_size_for_dim if start is not None else None @@ -169,12 +186,6 @@ def _slice_scale_for_dimension( else None ) - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." - ) - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index ded7020263..622d69ae09 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -125,10 +125,6 @@ def from_hp( output_dtype=torch.int8, ) - if tuple(block_size) != w.shape and len(scale.shape) == 1: - # per-row - pass - return cls( int_data, scale, @@ -158,7 +154,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): - """quantization: dynamic, weight-only int8 quantization""" + """INT8 quantization: dynamic activation or weight-only""" activation_tensor, weight_tensor, bias = ( args[0], args[1], @@ -220,24 +216,9 @@ def _(func, types, args, kwargs): end = self.shape[dim] sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) - - if self.scale.numel() == 1: - # Per-tensor quantization - scale doesn't change - sliced_scale = self.scale - elif self.scale.ndim == 1: - # Per-row: 1D scale - only slice if dim=0 - if dim == 0: - sliced_scale = aten.slice.Tensor(self.scale, 0, start, end, step) - else: - sliced_scale = self.scale - elif dim < self.scale.ndim and self.scale.shape[dim] > 1: - # Block-wise quantization - need to slice the scale appropriately - sliced_scale = aten.slice.Tensor(self.scale, dim, start, end, step) - else: - # Block-wise quantization with different dimensions - sliced_scale = _slice_scale_for_dimension( - self.scale, self.qdata.shape, dim, start, end, step - ) + sliced_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, start, end, step + ) block_size = list(self.block_size) for i in range(len(block_size)): From fcad08247cf4ba0afa530c66946c69732d65f21c Mon Sep 17 00:00:00 2001 From: youn17 Date: Fri, 31 Oct 2025 18:29:49 +0900 Subject: [PATCH 04/19] add 3D support to api, build linear variants test --- .../workflows/int8/test_int8_tensor.py | 69 ++++++++++++------ torchao/quantization/quant_api.py | 4 +- .../quantize_/workflows/int8/int8_tensor.py | 73 ++++++++++++++----- 3 files changed, 102 insertions(+), 44 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 6e225ee108..fb091a0e22 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -75,13 +75,48 @@ def test_creation_and_attributes(self): torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) + @common_utils.parametrize("compile", [True, False]) + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) @common_utils.parametrize( "sizes", [ - ((128,), 256, 128), + ((128,), 256, 128), # 2D + ((32, 128), 64, 256), # 3D ], ) + def test_int8_linear_variants( + self, + dtype: torch.dtype, + config, + compile: bool, + sizes: tuple, + ): + """Test linear operation supports including shape and compile""" + M, N, K = sizes + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") + model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() + model_q = copy.deepcopy(model) + + quantize_(model_q, config) + + if compile: + model_q = torch.compile(model_q, fullgraph=True) + + output_fp = model(input_tensor) + output_quantized = model_q(input_tensor) + + assert compute_error(output_fp, output_quantized) > 20, ( + f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" + ) + + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) @common_utils.parametrize( "config", [ @@ -89,6 +124,13 @@ def test_creation_and_attributes(self): Int8WeightOnlyConfig(version=2), ], ) + @common_utils.parametrize( + "sizes", + [ + ((128,), 256, 128), # 2D + ((32, 128), 64, 256), # 3D + ], + ) def test_int8_linear_quantization_accuracy( self, dtype: torch.dtype, @@ -100,16 +142,16 @@ def test_int8_linear_quantization_accuracy( input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") # Create a linear layer - m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda") + m = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() m_q = copy.deepcopy(m) # Quantize quantize_(m_q, config) - output_original = m(input_tensor) + output_fp = m(input_tensor) output_quantized = m_q(input_tensor) - error = compute_error(output_original, output_quantized) + error = compute_error(output_fp, output_quantized) assert error > 20, ( f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)" ) @@ -135,23 +177,6 @@ def test_per_row_scale_shape(self, dtype, config): else: self.assertEqual(linear.weight.scale.numel(), 1) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) - @common_utils.parametrize("has_bias", [True, False]) - def test_weight_only_linear_with_bias(self, dtype, has_bias): - """Test weight-only quantization with and without bias""" - K, N = 128, 64 - linear = torch.nn.Linear(K, N, bias=has_bias, dtype=dtype, device="cuda") - input_tensor = torch.randn(self.batch_size, K, dtype=dtype, device="cuda") - - output_fp = linear(input_tensor) - - quantize_(linear, Int8WeightOnlyConfig(version=2)) - output_q = linear(input_tensor) - - self.assertEqual(output_q.shape, output_fp.shape) - error = compute_error(output_fp, output_q) - self.assertGreater(error, 20) - @common_utils.parametrize( "config", [ diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f8939b19fa..55dca97206 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1347,7 +1347,7 @@ class Int8WeightOnlyConfig(AOBaseConfig): group_size: Optional[int] = None set_inductor_config: bool = True - version: int = 2 + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") @@ -1535,7 +1535,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig): weight_only_decode: bool = False granularity: Optional[Union[PerRow, PerTensor]] = PerRow() set_inductor_config: bool = True - version: int = 2 + version: int = 1 def __post_init__(self): torch._C._log_api_usage_once( diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 622d69ae09..a6e747e7e1 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -10,7 +10,10 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.float8.inference import _slice_scale_for_dimension +from torchao.float8.inference import ( + _slice_scale_for_dimension, + preprocess_scale, +) from torchao.kernel import int_scaled_matmul from torchao.quantization.quant_primitives import ( MappingType, @@ -46,7 +49,7 @@ class Int8Tensor(TorchAOBaseTensor): int8 quantized tensor with plain layout Tensor Attributes: - qdata: (N, K) int8 quantized weight data + qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D) scale: scale factors for dequantization Non-Tensor Attributes: @@ -103,8 +106,8 @@ def from_hp( block_size: list[int], act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, ): - if w.dim() != 2 or len(block_size) != 2: - raise ValueError("Expected 2D tensor and block_size length 2") + if w.dim() not in [2, 3] or len(block_size) != w.dim(): + raise ValueError("Expected 2D or 3D tensor with same block_size length") scale, zero_point = choose_qparams_affine( input=w, @@ -165,41 +168,66 @@ def _(func, types, args, kwargs): f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" ) + # Store original shape for reshaping result + original_weight_shape = weight_tensor.qdata.shape + + # Reshape 3D weights to 2D: (B, N, K) -> (B*N, K) + if weight_tensor.qdata.dim() == 3: + w_q_2d = weight_tensor.qdata.reshape(-1, original_weight_shape[-1]) + w_scale_2d = ( + weight_tensor.scale.reshape(-1) + if weight_tensor.scale.numel() > 1 + else weight_tensor.scale + ) + else: + w_q_2d = weight_tensor.qdata + w_scale_2d = weight_tensor.scale + if weight_tensor.act_quant_kwargs is not None: if not isinstance(activation_tensor, Int8Tensor): - # Activation quantization + # Dynamic activation quantization + act_kwargs = weight_tensor.act_quant_kwargs + input_ndim = activation_tensor.ndim + + # Ensure block_size matches input tensor dimensions + if len(act_kwargs.block_size) != input_ndim: + if input_ndim == 3 and len(act_kwargs.block_size) == 2: + block_size_updated = [1] + list(act_kwargs.block_size) + else: + block_size_updated = list(act_kwargs.block_size)[-input_ndim:] + act_kwargs = QuantizeTensorToInt8Kwargs(block_size=block_size_updated) + activation_tensor = _choose_quant_func_and_quantize_tensor( - activation_tensor, weight_tensor.act_quant_kwargs + activation_tensor, act_kwargs ) - x_vals = activation_tensor.qdata - x_scales = activation_tensor.scale - w_vals_t = weight_tensor.qdata.contiguous().t() - w_scales = weight_tensor.scale - - tmp = x_vals.reshape(-1, x_vals.shape[-1]) + x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1]) + x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape) + w_vals_t = w_q_2d.contiguous().t() intermediate_dtype = ( torch.float if x_scales.dtype == torch.half else x_scales.dtype ) y_dot_scaled = int_scaled_matmul( - tmp, w_vals_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + x_vals, w_vals_t, x_scales.to(intermediate_dtype) ) - y_dot_scaled = y_dot_scaled.to(x_scales.dtype) + y_dot_scaled = y_dot_scaled.to(activation_tensor.scale.dtype) - result = (y_dot_scaled * w_scales).reshape( - *x_vals.shape[:-1], y_dot_scaled.shape[-1] + result = (y_dot_scaled * w_scale_2d).reshape( + *activation_tensor.shape[:-1], *original_weight_shape[:-1] ) result = result.to(activation_tensor.dtype) else: # FP × INT8 (weight-only) - w_vals_int8_t = weight_tensor.qdata.t() + w_vals_int8_t = w_q_2d.t() m = torch.mm( activation_tensor.reshape(-1, activation_tensor.shape[-1]), w_vals_int8_t.to(activation_tensor.dtype), ) - result = m * weight_tensor.scale.to(m.dtype) - result = result.reshape(*activation_tensor.shape[:-1], result.shape[-1]) + result = m * w_scale_2d.to(m.dtype) + result = result.reshape( + *activation_tensor.shape[:-1], *original_weight_shape[:-1] + ) return result + bias if bias is not None else result @@ -212,6 +240,11 @@ def _(func, types, args, kwargs): if step != 1: raise NotImplementedError("Slicing with step > 1 is not supported") + assert dim in [0, 1, 2], f"Only dim=0,1,2 are supported, got: dim={dim}" + assert self.qdata.ndim in [2, 3], ( + f"Expected qdata to have dim=2,3 got: dim={self.qdata.ndim}" + ) + if end >= self.shape[dim]: end = self.shape[dim] @@ -252,7 +285,7 @@ def _(func, types, args, kwargs): Int8Tensor( self.qdata[index], selected_scale, - self.block_size, + self.block_size[1:], self.act_quant_kwargs, self.dtype, ), From 5d8543b5a6e469adc0fd3ceaf1b737f9e2ec28eb Mon Sep 17 00:00:00 2001 From: youn17 Date: Fri, 31 Oct 2025 18:54:35 +0900 Subject: [PATCH 05/19] add kernel detection test case --- .../workflows/int8/test_int8_tensor.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index fb091a0e22..83ca569af4 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -8,6 +8,7 @@ import unittest import torch +from torch._inductor.utils import run_and_get_code from torch.testing._internal import common_utils from torchao.quantization import ( @@ -255,6 +256,43 @@ def test_dequantization_accuracy(self): msg=f"Dequantization error exceeds tolerance of {0.1}", ) + def test_available_gpu_kernels(self): + """Check which GPU kernels are available""" + M, K, N = 128, 256, 512 + m = torch.nn.Sequential( + torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) + ) + config = Int8DynamicActivationInt8WeightConfig(version=2) + quantize_(m, config) + m = torch.compile(m) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + try: + out, code = run_and_get_code(m, x) + kernels_found = {} + + # Check for Triton kernels + if "torch.ops.triton" in code[0]: + kernels_found["triton"] = True + print("Triton kernels are available for int8 quantization") + else: + kernels_found["triton"] = False + print("Triton kernels are NOT available for int8 quantization") + + # Check for FBGEMM kernels + if "torch.ops.fbgemm" in code[0]: + kernels_found["fbgemm"] = True + print("FBGEMM kernels are available for int8 quantization") + else: + kernels_found["fbgemm"] = False + print("FBGEMM kernels are NOT available for int8 quantization") + + # Just log what we found, don't fail the test + print(f"Available kernels for int8 quantization: {kernels_found}") + + except Exception as e: + print(f"Could not check available kernels: {e}") + if __name__ == "__main__": common_utils.run_tests() From 40a99f4f7e060b22c4a27b708b738e243e463524 Mon Sep 17 00:00:00 2001 From: youn17 Date: Fri, 31 Oct 2025 19:05:19 +0900 Subject: [PATCH 06/19] refactor kernel test --- .../workflows/int8/test_int8_tensor.py | 43 ++++++++----------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 83ca569af4..1ea3b22948 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -267,31 +267,24 @@ def test_available_gpu_kernels(self): m = torch.compile(m) x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - try: - out, code = run_and_get_code(m, x) - kernels_found = {} - - # Check for Triton kernels - if "torch.ops.triton" in code[0]: - kernels_found["triton"] = True - print("Triton kernels are available for int8 quantization") - else: - kernels_found["triton"] = False - print("Triton kernels are NOT available for int8 quantization") - - # Check for FBGEMM kernels - if "torch.ops.fbgemm" in code[0]: - kernels_found["fbgemm"] = True - print("FBGEMM kernels are available for int8 quantization") - else: - kernels_found["fbgemm"] = False - print("FBGEMM kernels are NOT available for int8 quantization") - - # Just log what we found, don't fail the test - print(f"Available kernels for int8 quantization: {kernels_found}") - - except Exception as e: - print(f"Could not check available kernels: {e}") + out, code = run_and_get_code(m, x) + kernels = {} + + # Check for Triton kernels + if "torch.ops.triton" in code[0]: + kernels["triton"] = True + print("Triton kernels are available for int8 quantization") + else: + kernels["triton"] = False + print("Triton kernels are NOT available for int8 quantization") + + # Check for FBGEMM kernels + if "torch.ops.fbgemm" in code[0]: + kernels["fbgemm"] = True + print("FBGEMM kernels are available for int8 quantization") + else: + kernels["fbgemm"] = False + print("FBGEMM kernels are NOT available for int8 quantization") if __name__ == "__main__": From 051030a1f06add67f1f2f62eecb3813819d2b872 Mon Sep 17 00:00:00 2001 From: younn17 Date: Tue, 4 Nov 2025 22:47:31 +0900 Subject: [PATCH 07/19] update linear variant, kernel detection test - Configs are updated to global variants --- .../workflows/int8/test_int8_tensor.py | 137 +++++++----------- .../quantize_/workflows/int8/int8_tensor.py | 18 ++- 2 files changed, 65 insertions(+), 90 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 1ea3b22948..a7411cc967 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -16,9 +16,6 @@ Int8WeightOnlyConfig, quantize_, ) -from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( - Int8Tensor, -) from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @@ -56,7 +53,7 @@ class TestInt8Tensor(TorchAOIntegrationTestCase): def setUp(self): super().setUp() - self.test_shape = (4, 3) + self.test_shape = (32, 20) self.dtype = torch.bfloat16 self.batch_size = 32 @@ -66,9 +63,26 @@ def setUp(self): self.bias = torch.randn(self.test_shape[0], dtype=self.dtype) self.block_size = list(self.test_shape) - def test_creation_and_attributes(self): + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_creation_and_attributes(self, config): """Test tensor creation, dtypes, and ranges""" - tensor = Int8Tensor.from_hp(self.weight_fp, self.block_size) + linear = torch.nn.Linear( + self.test_shape[1], + self.test_shape[0], + bias=False, + dtype=self.dtype, + device="cuda", + ) + linear.weight.data = self.weight_fp.cuda() + quantize_(linear, config) + + tensor = linear.weight self.assertEqual(tensor.shape, self.test_shape) self.assertEqual(tensor.qdata.dtype, torch.int8) @@ -117,46 +131,6 @@ def test_int8_linear_variants( f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) - @common_utils.parametrize( - "sizes", - [ - ((128,), 256, 128), # 2D - ((32, 128), 64, 256), # 3D - ], - ) - def test_int8_linear_quantization_accuracy( - self, - dtype: torch.dtype, - sizes: tuple, - config, - ): - """Test quantization preserves reasonable accuracy""" - M, N, K = sizes - input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") - - # Create a linear layer - m = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() - m_q = copy.deepcopy(m) - - # Quantize - quantize_(m_q, config) - - output_fp = m(input_tensor) - output_quantized = m_q(input_tensor) - - error = compute_error(output_fp, output_quantized) - assert error > 20, ( - f"Quantization quality is too low, SQNR: {error}dB (expected > {20}dB)" - ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) @common_utils.parametrize( "config", @@ -218,36 +192,42 @@ def test_slice(self, config, device, dtype): with self.assertRaises(NotImplementedError): _ = dummy.weight[::2] - def test_index_select(self): - """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`.""" + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_index_select(self, config): + """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" N, K = 256, 512 x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) - x_int8 = Int8Tensor.from_hp(x, block_size=[N, K]) + linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") + linear.weight.data = x + quantize_(linear, config) + + x_int8 = linear.weight x_int8_0 = x_int8[0] torch.testing.assert_close( x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 ) - def test_invalid_input_handling(self): - """Test input validation with specific error types""" - invalid_tensor = torch.randn(5) - incompatible_block_size = [1] - - with self.assertRaises( - ValueError, msg="Should reject incompatible tensor dimensions" - ): - Int8Tensor.from_hp(invalid_tensor, incompatible_block_size) - - with self.assertRaises( - ValueError, msg="Should reject mismatched block size dimensions" - ): - Int8Tensor.from_hp(self.weight_fp, [1]) - - def test_dequantization_accuracy(self): + @common_utils.parametrize( + "config", + [ + Int8DynamicActivationInt8WeightConfig(version=2), + Int8WeightOnlyConfig(version=2), + ], + ) + def test_dequantization_accuracy(self, config): """Test dequantization accuracy separately""" - test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16) - tensor = Int8Tensor.from_hp(test_data, [1, 2]) + test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda") + linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda") + linear.weight.data = test_data + quantize_(linear, config) + tensor = linear.weight dequantized = tensor.dequantize() self.assertEqual(dequantized.shape, test_data.shape) self.assertLess( @@ -268,23 +248,14 @@ def test_available_gpu_kernels(self): x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) out, code = run_and_get_code(m, x) - kernels = {} - - # Check for Triton kernels - if "torch.ops.triton" in code[0]: - kernels["triton"] = True - print("Triton kernels are available for int8 quantization") - else: - kernels["triton"] = False - print("Triton kernels are NOT available for int8 quantization") + has_triton = "triton" in code[0].lower() # Trition + has_fbgemm = "fbgemm" in code[0].lower() # FB-GEMM + has_int_mm = "_int_mm" in code[0] # Int8 MatMul - # Check for FBGEMM kernels - if "torch.ops.fbgemm" in code[0]: - kernels["fbgemm"] = True - print("FBGEMM kernels are available for int8 quantization") - else: - kernels["fbgemm"] = False - print("FBGEMM kernels are NOT available for int8 quantization") + self.assertTrue( + has_triton or has_fbgemm or has_int_mm, + f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}", + ) if __name__ == "__main__": diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index a6e747e7e1..11633ebc24 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -143,10 +143,11 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype = self.dtype qdata_fp = self.qdata.to(output_dtype) - # Reshape scale to broadcast if granularity is block-wise - scale_expanded = _maybe_expand_scale_to_tensor_shape( - self.scale, self.qdata.shape - ) + scale = self.scale + while scale.ndim < qdata_fp.ndim: + scale = scale.unsqueeze(-1) + + scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape) return qdata_fp * scale_expanded.to(output_dtype) @@ -276,16 +277,19 @@ def _(func, types, args, kwargs): self, dim, index = args assert dim == 0, f"Only dim=0 supported, got {dim}" - selected_scale = self.scale if self.scale.ndim == 0 else self.scale[index] + selected_qdata = self.qdata[index] + selected_scale = _slice_scale_for_dimension( + self.scale, self.qdata.shape, dim, index, index + 1, step=1 + ).squeeze(0) return return_and_correct_aliasing( func, args, kwargs, Int8Tensor( - self.qdata[index], + selected_qdata, selected_scale, - self.block_size[1:], + [selected_qdata.shape[-1]], self.act_quant_kwargs, self.dtype, ), From f74bb9a811bcf318f1e15020fcc61995ce66b579 Mon Sep 17 00:00:00 2001 From: youn17 Date: Tue, 11 Nov 2025 16:24:11 +0900 Subject: [PATCH 08/19] update default granularity, kernel test --- .../workflows/int8/test_int8_tensor.py | 47 ++++++------------- torchao/quantization/quant_api.py | 5 +- 2 files changed, 18 insertions(+), 34 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index a7411cc967..d687107b8c 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -9,6 +9,7 @@ import torch from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck from torch.testing._internal import common_utils from torchao.quantization import ( @@ -145,12 +146,8 @@ def test_per_row_scale_shape(self, dtype, config): linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") quantize_(linear, config) - # Dynamic: per-row (1D scale [N]), Weight-only: per-tensor (scalar) - if isinstance(config, Int8DynamicActivationInt8WeightConfig): - self.assertEqual(linear.weight.scale.shape, (N,)) - self.assertEqual(linear.weight.scale.ndim, 1) - else: - self.assertEqual(linear.weight.scale.numel(), 1) + self.assertEqual(linear.weight.scale.shape, (N,)) + self.assertEqual(linear.weight.scale.ndim, 1) @common_utils.parametrize( "config", @@ -162,7 +159,7 @@ def test_per_row_scale_shape(self, dtype, config): @common_utils.parametrize("device", ["cpu", "cuda"]) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_slice(self, config, device, dtype): - """Test tensor slicing""" + """Test tensor slicing with per-row quantization""" tensor_size = 256 slice_sizes = (64, 128) @@ -176,19 +173,8 @@ def test_slice(self, config, device, dtype): self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) - - # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow) - # Int8WeightOnlyConfig uses per-tensor (PerTensor) - if isinstance(config, Int8DynamicActivationInt8WeightConfig): - # PerRow: dim 0 slicing affects scale, dim 1 doesn't - self.assertEqual( - weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]) - ) - self.assertEqual(weight2.scale, dummy.weight.scale) - else: - # PerTensor: scale unchanged by slicing - self.assertEqual(weight1.scale, dummy.weight.scale) - self.assertEqual(weight2.scale, dummy.weight.scale) + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) + self.assertEqual(weight2.scale, dummy.weight.scale) with self.assertRaises(NotImplementedError): _ = dummy.weight[::2] @@ -230,13 +216,15 @@ def test_dequantization_accuracy(self, config): tensor = linear.weight dequantized = tensor.dequantize() self.assertEqual(dequantized.shape, test_data.shape) - self.assertLess( - torch.abs(dequantized - test_data).max().item(), - 0.1, - msg=f"Dequantization error exceeds tolerance of {0.1}", + assert compute_error(dequantized, test_data) > 20, ( + f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}" ) - def test_available_gpu_kernels(self): + @common_utils.parametrize( + "kernel", + ["triton_per_fused", "extern_kernels._int_mm", "triton_poi_fused"], + ) + def test_available_gpu_kernels(self, kernel): """Check which GPU kernels are available""" M, K, N = 128, 256, 512 m = torch.nn.Sequential( @@ -248,14 +236,7 @@ def test_available_gpu_kernels(self): x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) out, code = run_and_get_code(m, x) - has_triton = "triton" in code[0].lower() # Trition - has_fbgemm = "fbgemm" in code[0].lower() # FB-GEMM - has_int_mm = "_int_mm" in code[0] # Int8 MatMul - - self.assertTrue( - has_triton or has_fbgemm or has_int_mm, - f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}", - ) + FileCheck().check(kernel).run(code[0]) if __name__ == "__main__": diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 55dca97206..c9146e24cb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1380,7 +1380,10 @@ def _int8_weight_only_quantize_tensor(weight, config): ) else: assert config.version == 2, f"Unexpected version: {config.version}" - block_size = [weight.shape[0], weight.shape[1]] + group_size = config.group_size + if group_size is None: + group_size = weight.shape[-1] + block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) new_weight = Int8Tensor.from_hp(weight, block_size=block_size) return new_weight From 55c9cb75d8d624907442a3cf8738fc09346effa3 Mon Sep 17 00:00:00 2001 From: youn17 Date: Tue, 11 Nov 2025 19:01:41 +0900 Subject: [PATCH 09/19] fix quantization ops --- .../quantize_/workflows/int8/int8_tensor.py | 67 ++++++++++--------- 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 11633ebc24..947b9c872e 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -12,7 +12,6 @@ from torchao.float8.inference import ( _slice_scale_for_dimension, - preprocess_scale, ) from torchao.kernel import int_scaled_matmul from torchao.quantization.quant_primitives import ( @@ -169,21 +168,6 @@ def _(func, types, args, kwargs): f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" ) - # Store original shape for reshaping result - original_weight_shape = weight_tensor.qdata.shape - - # Reshape 3D weights to 2D: (B, N, K) -> (B*N, K) - if weight_tensor.qdata.dim() == 3: - w_q_2d = weight_tensor.qdata.reshape(-1, original_weight_shape[-1]) - w_scale_2d = ( - weight_tensor.scale.reshape(-1) - if weight_tensor.scale.numel() > 1 - else weight_tensor.scale - ) - else: - w_q_2d = weight_tensor.qdata - w_scale_2d = weight_tensor.scale - if weight_tensor.act_quant_kwargs is not None: if not isinstance(activation_tensor, Int8Tensor): # Dynamic activation quantization @@ -202,35 +186,52 @@ def _(func, types, args, kwargs): activation_tensor, act_kwargs ) - x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1]) - x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape) - w_vals_t = w_q_2d.contiguous().t() + # 1. do the matrix form of dot(X_i, W_j) + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a FP16 scale is greater than the maximum + # value of a FP16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = activation_tensor.qdata + x_scales = activation_tensor.scale + w_vals_int8_t = weight_tensor.qdata.contiguous().t() + w_scales = weight_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + x_scales_dtype = x_scales.dtype + # Cast FP16 scale to float to avoid overflow in int_scaled_matmul intermediate_dtype = ( - torch.float if x_scales.dtype == torch.half else x_scales.dtype + torch.float if x_scales_dtype == torch.half else x_scales_dtype ) - y_dot_scaled = int_scaled_matmul( - x_vals, w_vals_t, x_scales.to(intermediate_dtype) + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) ) - y_dot_scaled = y_dot_scaled.to(activation_tensor.scale.dtype) + y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - result = (y_dot_scaled * w_scale_2d).reshape( - *activation_tensor.shape[:-1], *original_weight_shape[:-1] + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) - result = result.to(activation_tensor.dtype) + + # can downcast only at the very end + output_dtype = activation_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y else: # FP × INT8 (weight-only) - w_vals_int8_t = w_q_2d.t() + w_vals_int8_t = weight_tensor.qdata.t() m = torch.mm( activation_tensor.reshape(-1, activation_tensor.shape[-1]), w_vals_int8_t.to(activation_tensor.dtype), ) - result = m * w_scale_2d.to(m.dtype) - result = result.reshape( - *activation_tensor.shape[:-1], *original_weight_shape[:-1] - ) - - return result + bias if bias is not None else result + y = m * weight_tensor.scale.to(m.dtype) + y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0]) + if bias is not None: + y += bias + return y @implements(aten.slice.Tensor) From 2a483fda4cfc85e341f3b86991882effcb0bbfe5 Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 13 Nov 2025 01:55:11 +0900 Subject: [PATCH 10/19] merge test cases with cleanup --- .../workflows/int8/test_int8_tensor.py | 35 ++++--------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index d687107b8c..2c24f0085e 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -59,10 +59,6 @@ def setUp(self): self.batch_size = 32 torch.manual_seed(42) - self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype) - self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype) - self.bias = torch.randn(self.test_shape[0], dtype=self.dtype) - self.block_size = list(self.test_shape) @common_utils.parametrize( "config", @@ -80,16 +76,13 @@ def test_creation_and_attributes(self, config): dtype=self.dtype, device="cuda", ) - linear.weight.data = self.weight_fp.cuda() quantize_(linear, config) - tensor = linear.weight + w = linear.weight - self.assertEqual(tensor.shape, self.test_shape) - self.assertEqual(tensor.qdata.dtype, torch.int8) - self.assertTrue( - torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127) - ) + self.assertEqual(w.shape, self.test_shape) + self.assertEqual(w.qdata.dtype, torch.int8) + self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) @common_utils.parametrize("compile", [True, False]) @@ -122,6 +115,9 @@ def test_int8_linear_variants( quantize_(model_q, config) + self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) + self.assertEqual(model_q.linear2.weight.scale.ndim, 1) + if compile: model_q = torch.compile(model_q, fullgraph=True) @@ -132,23 +128,6 @@ def test_int8_linear_variants( f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" ) - @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) - @common_utils.parametrize( - "config", - [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), - ], - ) - def test_per_row_scale_shape(self, dtype, config): - """Test per-row quantization maintains 1D scale""" - N, K = 64, 128 - linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") - quantize_(linear, config) - - self.assertEqual(linear.weight.scale.shape, (N,)) - self.assertEqual(linear.weight.scale.ndim, 1) - @common_utils.parametrize( "config", [ From 496191a9bb6c43c48a572e571b17691823f09f2f Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 16 Nov 2025 22:26:55 +0900 Subject: [PATCH 11/19] update `block_size` args to `granularity` --- .../workflows/int8/test_int8_tensor.py | 31 ++++++++++++----- torchao/float8/inference.py | 25 ++++---------- torchao/quantization/quant_api.py | 34 ++++++++++--------- .../common/quantize_tensor_kwargs.py | 2 +- .../quantize_/workflows/int8/int8_tensor.py | 22 +++++++----- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 2c24f0085e..b492b189ff 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -17,6 +17,7 @@ Int8WeightOnlyConfig, quantize_, ) +from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase @@ -160,24 +161,35 @@ def test_slice(self, config, device, dtype): @common_utils.parametrize( "config", [ - Int8DynamicActivationInt8WeightConfig(version=2), - Int8WeightOnlyConfig(version=2), + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, ], ) - def test_index_select(self, config): + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) + def test_index_select(self, config, granularity): """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" N, K = 256, 512 x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") linear.weight.data = x + + config = config(version=2, granularity=granularity) quantize_(linear, config) x_int8 = linear.weight x_int8_0 = x_int8[0] + + # Test dequantization consistency torch.testing.assert_close( x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 ) + # Test block_size granularity + if isinstance(granularity, PerRow): + self.assertEqual(x_int8.block_size, [1, K]) + elif isinstance(granularity, PerTensor): + self.assertEqual(x_int8.block_size, [N, K]) + @common_utils.parametrize( "config", [ @@ -187,16 +199,17 @@ def test_index_select(self, config): ) def test_dequantization_accuracy(self, config): """Test dequantization accuracy separately""" - test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda") - linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda") - linear.weight.data = test_data + linear = torch.nn.Linear( + 256, 512, bias=False, dtype=torch.bfloat16, device="cuda" + ) + weight_fp = copy.deepcopy(linear.weight) quantize_(linear, config) tensor = linear.weight dequantized = tensor.dequantize() - self.assertEqual(dequantized.shape, test_data.shape) - assert compute_error(dequantized, test_data) > 20, ( - f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}" + self.assertEqual(dequantized.shape, weight_fp.shape) + assert compute_error(dequantized, weight_fp) > 20, ( + f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}" ) @common_utils.parametrize( diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 42edfa6afa..212df9c5db 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -142,18 +142,7 @@ def _slice_scale_for_dimension( """ aten = torch.ops.aten - # Per-tensor quantization (scalar scale) - if scale.numel() == 1: - return scale - - # Per-row quantization (1D scale) - if scale.ndim == 1: - if dim == 0: - return aten.slice.Tensor(scale, 0, start, end, step) - else: - return scale - - # Block-wise quantization (2D scale) + # Unsupported case for now, this would be 1 scale per data element if scale.shape == data_shape: return aten.slice.Tensor(scale, dim, start, end, step) @@ -171,12 +160,6 @@ def _slice_scale_for_dimension( # Slice away as normal return aten.slice.Tensor(scale, dim, start, end, step) else: - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." - ) - # There is blocking in this dimension # Calculate which scale elements correspond to the sliced data scale_start = start // block_size_for_dim if start is not None else None @@ -186,6 +169,12 @@ def _slice_scale_for_dimension( else None ) + # Error on Step > 1 + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." + ) + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c9146e24cb..8e5dc20b7f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1339,13 +1339,16 @@ class Int8WeightOnlyConfig(AOBaseConfig): Configuration for applying int8 weight-only symmetric per-channel quantization to linear layers. Args: - group_size: Optional[int] = None - Controls the granularity of quantization. If None, applies per-channel quantization. - Otherwise, applies per-group quantization with the specified group size. + group_size (version 1) - Controls the granularity of quantization. + If None, applies per-channel quantization. Otherwise, applies per-group quantization with the specified group size. + granularity (version 2) - Quantization granularity. + PerRow() for per-channel quantization, PerTensor() for per-tensor quantization. set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values for better performance with this quantization scheme. """ group_size: Optional[int] = None + granularity: Optional[Union[PerRow, PerTensor]] = PerRow() set_inductor_config: bool = True version: int = 1 @@ -1380,11 +1383,7 @@ def _int8_weight_only_quantize_tensor(weight, config): ) else: assert config.version == 2, f"Unexpected version: {config.version}" - group_size = config.group_size - if group_size is None: - group_size = weight.shape[-1] - block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size]) - new_weight = Int8Tensor.from_hp(weight, block_size=block_size) + new_weight = Int8Tensor.from_hp(weight, granularity=config.granularity) return new_weight @@ -1583,17 +1582,17 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): else: input_quant_func = _int8_asymm_per_token_quant - if isinstance(config.granularity, PerTensor): - # Tensor granularity - block_size = weight.shape - else: - # Per row granularity - block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]) - if config.version == 1: warnings.warn( "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" ) + if isinstance(config.granularity, PerTensor): + block_size = weight.shape + else: + block_size = tuple( + [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]] + ) + quantized_weight = to_affine_quantized_intx( weight, mapping_type, @@ -1613,10 +1612,13 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): ) assert config.version == 2, f"Unexpected version: {config.version}" + # Compute block_size from granularity for activation quantization kwargs + block_size = get_block_size(weight.shape, config.granularity) + quantized_weight = Int8Tensor.from_hp( weight, - block_size, - act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size), + granularity=config.granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=list(block_size)), ) return quantized_weight diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 44dd09ff62..15540e34c8 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -57,7 +57,7 @@ def _choose_quant_func_and_quantize_tensor( elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): return Int8Tensor.from_hp( tensor, - quant_kwargs.block_size, + granularity=quant_kwargs.granularity, act_quant_kwargs=quant_kwargs, ) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 947b9c872e..b9b4477842 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -14,6 +14,7 @@ _slice_scale_for_dimension, ) from torchao.kernel import int_scaled_matmul +from torchao.quantization.granularity import PerRow from torchao.quantization.quant_primitives import ( MappingType, _maybe_expand_scale_to_tensor_shape, @@ -24,6 +25,7 @@ QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, ) +from torchao.quantization.utils import get_block_size from torchao.utils import TorchAOBaseTensor, fill_defaults __all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"] @@ -37,10 +39,12 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: block_size (list[int]): block size for quantization granularity + granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() # TODO: Static quantization support using `static_scale`, `static_zero_point` """ block_size: list[int] + granularity = PerRow() class Int8Tensor(TorchAOBaseTensor): @@ -101,26 +105,28 @@ def __repr__(self): @classmethod def from_hp( cls, - w: torch.Tensor, - block_size: list[int], + w_hp: torch.Tensor, + granularity=PerRow(), act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, ): - if w.dim() not in [2, 3] or len(block_size) != w.dim(): + block_size = list(get_block_size(w_hp.shape, granularity)) + + if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim(): raise ValueError("Expected 2D or 3D tensor with same block_size length") scale, zero_point = choose_qparams_affine( - input=w, + input=w_hp, mapping_type=MappingType.SYMMETRIC, block_size=block_size, target_dtype=torch.int8, quant_min=-128, quant_max=127, - scale_dtype=w.dtype, + scale_dtype=w_hp.dtype, zero_point_dtype=torch.int8, ) int_data = quantize_affine( - w, + w_hp, block_size=block_size, scale=scale, zero_point=zero_point, @@ -132,7 +138,7 @@ def from_hp( scale, block_size, act_quant_kwargs=act_quant_kwargs, - dtype=w.dtype, + dtype=w_hp.dtype, ) def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: @@ -290,7 +296,7 @@ def _(func, types, args, kwargs): Int8Tensor( selected_qdata, selected_scale, - [selected_qdata.shape[-1]], + self.block_size[1:], self.act_quant_kwargs, self.dtype, ), From 83634f0a6237a06c16a9a0b738612cef4f244e8e Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 16 Nov 2025 23:09:29 +0900 Subject: [PATCH 12/19] update expected kernel test --- .../workflows/int8/test_int8_tensor.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index b492b189ff..07dcc225fd 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -212,23 +212,27 @@ def test_dequantization_accuracy(self, config): f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}" ) - @common_utils.parametrize( - "kernel", - ["triton_per_fused", "extern_kernels._int_mm", "triton_poi_fused"], - ) - def test_available_gpu_kernels(self, kernel): - """Check which GPU kernels are available""" + def test_available_gpu_kernels(self): + """Check which GPU kernels are used""" + torch.compiler.reset() + M, K, N = 128, 256, 512 m = torch.nn.Sequential( torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) ) + config = Int8DynamicActivationInt8WeightConfig(version=2) quantize_(m, config) + m = torch.compile(m) x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) out, code = run_and_get_code(m, x) - FileCheck().check(kernel).run(code[0]) + + # Check expected kernels are present + FileCheck().check_count("triton_per_fused", 1).check_count( + "extern_kernels._int_mm", 1 + ).check_count("triton_poi_fused", 1).run(code[0]) if __name__ == "__main__": From 9bc1ad56dbf9e8d81408e6e52b4d0af64cfc8f4d Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 20 Nov 2025 19:19:55 +0900 Subject: [PATCH 13/19] use Granularity for slicing logic instead of block_size --- .../workflows/int8/test_int8_tensor.py | 10 ++-- torchao/quantization/quant_api.py | 4 +- .../quantize_/workflows/int8/int8_tensor.py | 54 ++++++++----------- 3 files changed, 30 insertions(+), 38 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index 07dcc225fd..c2f099fcde 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -18,7 +18,7 @@ quantize_, ) from torchao.quantization.granularity import PerRow, PerTensor -from torchao.quantization.utils import compute_error +from torchao.quantization.utils import compute_error, get_block_size from torchao.testing.utils import TorchAOIntegrationTestCase @@ -186,9 +186,13 @@ def test_index_select(self, config, granularity): # Test block_size granularity if isinstance(granularity, PerRow): - self.assertEqual(x_int8.block_size, [1, K]) + self.assertEqual( + list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K] + ) elif isinstance(granularity, PerTensor): - self.assertEqual(x_int8.block_size, [N, K]) + self.assertEqual( + list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K] + ) @common_utils.parametrize( "config", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8e5dc20b7f..0dc65d561f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1612,13 +1612,11 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): ) assert config.version == 2, f"Unexpected version: {config.version}" - # Compute block_size from granularity for activation quantization kwargs - block_size = get_block_size(weight.shape, config.granularity) quantized_weight = Int8Tensor.from_hp( weight, granularity=config.granularity, - act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=list(block_size)), + act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity), ) return quantized_weight diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index b9b4477842..51739f9229 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -17,8 +17,8 @@ from torchao.quantization.granularity import PerRow from torchao.quantization.quant_primitives import ( MappingType, - _maybe_expand_scale_to_tensor_shape, choose_qparams_affine, + dequantize_affine, quantize_affine, ) from torchao.quantization.quantize_.common import ( @@ -38,13 +38,11 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): """Tensor kwargs for creating int8 tensor (either activation or weight) Args: - block_size (list[int]): block size for quantization granularity granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() # TODO: Static quantization support using `static_scale`, `static_zero_point` """ - block_size: list[int] - granularity = PerRow() + granularity: object = PerRow() class Int8Tensor(TorchAOBaseTensor): @@ -56,12 +54,12 @@ class Int8Tensor(TorchAOBaseTensor): scale: scale factors for dequantization Non-Tensor Attributes: - block_size: block size for quantization granularity + granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) act_quant_kwargs: flags for dynamic activation quantization """ tensor_data_names = ["qdata", "scale"] - tensor_attribute_names = ["block_size"] + tensor_attribute_names = ["granularity"] optional_tensor_attribute_names = [ "act_quant_kwargs", "dtype", @@ -86,20 +84,20 @@ def __init__( self, qdata: torch.Tensor, scale: torch.Tensor, - block_size: list[int], + granularity, act_quant_kwargs=None, dtype=None, ): super().__init__() self.qdata = qdata self.scale = scale - self.block_size = block_size + self.granularity = granularity self.act_quant_kwargs = act_quant_kwargs def __repr__(self): return ( f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " - f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})" + f"{self.granularity=}, {self.shape=}, {self.device=}, {self.dtype=})" ) @classmethod @@ -109,7 +107,7 @@ def from_hp( granularity=PerRow(), act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, ): - block_size = list(get_block_size(w_hp.shape, granularity)) + block_size = get_block_size(w_hp.shape, granularity) if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim(): raise ValueError("Expected 2D or 3D tensor with same block_size length") @@ -136,7 +134,7 @@ def from_hp( return cls( int_data, scale, - block_size, + granularity, act_quant_kwargs=act_quant_kwargs, dtype=w_hp.dtype, ) @@ -147,13 +145,18 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor if output_dtype is None: output_dtype = self.dtype - qdata_fp = self.qdata.to(output_dtype) - scale = self.scale - while scale.ndim < qdata_fp.ndim: - scale = scale.unsqueeze(-1) + block_size = get_block_size(self.qdata.shape, self.granularity) - scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape) - return qdata_fp * scale_expanded.to(output_dtype) + return dequantize_affine( + input=self.qdata, + block_size=block_size, + scale=self.scale, + zero_point=None, + input_dtype=torch.int8, + quant_min=-128, + quant_max=127, + output_dtype=output_dtype, + ) implements = Int8Tensor.implements @@ -178,15 +181,6 @@ def _(func, types, args, kwargs): if not isinstance(activation_tensor, Int8Tensor): # Dynamic activation quantization act_kwargs = weight_tensor.act_quant_kwargs - input_ndim = activation_tensor.ndim - - # Ensure block_size matches input tensor dimensions - if len(act_kwargs.block_size) != input_ndim: - if input_ndim == 3 and len(act_kwargs.block_size) == 2: - block_size_updated = [1] + list(act_kwargs.block_size) - else: - block_size_updated = list(act_kwargs.block_size)[-input_ndim:] - act_kwargs = QuantizeTensorToInt8Kwargs(block_size=block_size_updated) activation_tensor = _choose_quant_func_and_quantize_tensor( activation_tensor, act_kwargs @@ -261,10 +255,6 @@ def _(func, types, args, kwargs): self.scale, self.qdata.shape, dim, start, end, step ) - block_size = list(self.block_size) - for i in range(len(block_size)): - block_size[i] = min(block_size[i], sliced_qdata.shape[i]) - return return_and_correct_aliasing( func, args, @@ -272,7 +262,7 @@ def _(func, types, args, kwargs): Int8Tensor( sliced_qdata, sliced_scale, - block_size, + self.granularity, self.act_quant_kwargs, dtype=self.dtype, ), @@ -296,7 +286,7 @@ def _(func, types, args, kwargs): Int8Tensor( selected_qdata, selected_scale, - self.block_size[1:], + self.granularity, self.act_quant_kwargs, self.dtype, ), From 56faf2fb8dd637bd36b158b4650ea0960b08aa56 Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 20 Nov 2025 19:20:19 +0900 Subject: [PATCH 14/19] update tensor slicing for per-tensor/row/block scales --- torchao/float8/inference.py | 61 +++++++++++++++++++++---------------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index 212df9c5db..25e0dfd7aa 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -139,43 +139,50 @@ def _slice_scale_for_dimension( Slice the scale tensor appropriately based on the data tensor slicing. This function calculates how the scale should be sliced when the data tensor is sliced along a given dimension, taking into account the block structure. - """ - aten = torch.ops.aten - # Unsupported case for now, this would be 1 scale per data element - if scale.shape == data_shape: - return aten.slice.Tensor(scale, dim, start, end, step) + Example: + If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling), + slicing along any dimension should return the same scale tensor. - # Reconstruct block sizes based on data shape and scale shape - block_sizes = tuple(data_shape[i] // scale.shape[i] for i in range(len(data_shape))) + If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling), + and we slice data along dim=0 from 64 to 192, the corresponding scale + """ + aten = torch.ops.aten - if dim >= len(block_sizes): - # Slicing beyond the dimensions we care about + # Case 1: Per-tensor quantization (scalar scale) + if scale.numel() <= 1: return scale + # Case 2: Per-row quantization (1D scale) + # Scale is per-element along this dimension + if scale.ndim == 1: + if dim == 0: + return aten.slice.Tensor(scale, 0, start, end, step) + else: + return scale + + # Case 3: Per-block quantization (2D scale) + block_sizes = tuple( + data_shape[i] // scale.shape[i] for i in range(len(scale.shape)) + ) + block_size_for_dim = block_sizes[dim] - if block_size_for_dim == 1: - # Scale is per-element along this dimension - # Slice away as normal - return aten.slice.Tensor(scale, dim, start, end, step) - else: - # There is blocking in this dimension - # Calculate which scale elements correspond to the sliced data - scale_start = start // block_size_for_dim if start is not None else None - scale_end = ( - (end + block_size_for_dim - 1) // block_size_for_dim - if end is not None - else None + if step > 1: + raise NotImplementedError( + "Slicing with step > 1 is not implemented for scale tensors." ) - # Error on Step > 1 - if step > 1: - raise NotImplementedError( - "Slicing with step > 1 is not implemented for scale tensors." - ) + # There is blocking in this dimension + # Calculate which scale elements correspond to the sliced data + scale_start = start // block_size_for_dim if start is not None else None + scale_end = ( + (end + block_size_for_dim - 1) // block_size_for_dim + if end is not None + else None + ) - return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) + return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) def _is_rowwise_scaled(x: torch.Tensor) -> bool: From 3fe2282ddb846bda45940b3d033e3762c6199eac Mon Sep 17 00:00:00 2001 From: youn17 Date: Thu, 20 Nov 2025 19:37:53 +0900 Subject: [PATCH 15/19] fix __new__/__init__ signatures and formatting --- .../quantize_/workflows/int8/int8_tensor.py | 94 ++++++++++--------- 1 file changed, 50 insertions(+), 44 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 51739f9229..dceb0964aa 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -10,11 +10,9 @@ import torch from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.float8.inference import ( - _slice_scale_for_dimension, -) +from torchao.float8.inference import _slice_scale_for_dimension from torchao.kernel import int_scaled_matmul -from torchao.quantization.granularity import PerRow +from torchao.quantization.granularity import Granularity, PerRow from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, @@ -39,10 +37,9 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() - # TODO: Static quantization support using `static_scale`, `static_zero_point` """ - granularity: object = PerRow() + granularity: Granularity = PerRow() class Int8Tensor(TorchAOBaseTensor): @@ -52,26 +49,25 @@ class Int8Tensor(TorchAOBaseTensor): Tensor Attributes: qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D) scale: scale factors for dequantization + # TODO: Static quantization support using `static_scale` Non-Tensor Attributes: granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) act_quant_kwargs: flags for dynamic activation quantization """ + # TODO: Static quantization support using `static_scale` tensor_data_names = ["qdata", "scale"] tensor_attribute_names = ["granularity"] - optional_tensor_attribute_names = [ - "act_quant_kwargs", - "dtype", - ] + optional_tensor_attribute_names = ["act_quant_kwargs", "dtype"] def __new__( cls: type, qdata: torch.Tensor, scale: torch.Tensor, - block_size: list[int], - act_quant_kwargs=None, - dtype=None, + granularity: Granularity, + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + dtype: Optional[torch.dtype] = None, ): kwargs = { "device": qdata.device, @@ -84,9 +80,9 @@ def __init__( self, qdata: torch.Tensor, scale: torch.Tensor, - granularity, - act_quant_kwargs=None, - dtype=None, + granularity: Granularity, + act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, + dtype: Optional[torch.dtype] = None, ): super().__init__() self.qdata = qdata @@ -96,21 +92,31 @@ def __init__( def __repr__(self): return ( - f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, " - f"{self.granularity=}, {self.shape=}, {self.device=}, {self.dtype=})" + f"{self.__class__.__name__}(" + f"act_quant_kwargs={self.act_quant_kwargs}, " + f"qdata={self.qdata}, " + f"scale={self.scale}, " + f"granularity={self.granularity}, " + f"shape={self.shape}, " + f"device={self.device}, " + f"dtype={self.dtype})" ) @classmethod def from_hp( cls, w_hp: torch.Tensor, - granularity=PerRow(), + granularity: Granularity = PerRow(), act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, ): + """Create Int8Tensor from high-precision tensor""" block_size = get_block_size(w_hp.shape, granularity) if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim(): - raise ValueError("Expected 2D or 3D tensor with same block_size length") + raise ValueError( + f"Expected 2D or 3D tensor with matching block_size dimensions, " + f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" + ) scale, zero_point = choose_qparams_affine( input=w_hp, @@ -141,7 +147,6 @@ def from_hp( def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Dequantize int8 tensor to floating point""" - if output_dtype is None: output_dtype = self.dtype @@ -173,17 +178,16 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) - assert isinstance(weight_tensor, Int8Tensor), ( - f"Expected weight to be Int8Tensor, got {type(weight_tensor)}" - ) + if not isinstance(weight_tensor, Int8Tensor): + raise TypeError(f"Expected weight to be Int8Tensor, got {type(weight_tensor)}") + + output_dtype = activation_tensor.dtype if weight_tensor.act_quant_kwargs is not None: + # Dynamic activation quantization path if not isinstance(activation_tensor, Int8Tensor): - # Dynamic activation quantization - act_kwargs = weight_tensor.act_quant_kwargs - activation_tensor = _choose_quant_func_and_quantize_tensor( - activation_tensor, act_kwargs + activation_tensor, weight_tensor.act_quant_kwargs ) # 1. do the matrix form of dot(X_i, W_j) @@ -199,6 +203,7 @@ def _(func, types, args, kwargs): x_scales = activation_tensor.scale w_vals_int8_t = weight_tensor.qdata.contiguous().t() w_scales = weight_tensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) x_scales_dtype = x_scales.dtype # Cast FP16 scale to float to avoid overflow in int_scaled_matmul @@ -214,12 +219,6 @@ def _(func, types, args, kwargs): *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) - # can downcast only at the very end - output_dtype = activation_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y else: # FP × INT8 (weight-only) w_vals_int8_t = weight_tensor.qdata.t() @@ -229,9 +228,11 @@ def _(func, types, args, kwargs): ) y = m * weight_tensor.scale.to(m.dtype) y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0]) - if bias is not None: - y += bias - return y + + if bias is not None: + y += bias + + return y.to(output_dtype) @implements(aten.slice.Tensor) @@ -240,14 +241,17 @@ def _(func, types, args, kwargs): self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if step != 1: - raise NotImplementedError("Slicing with step > 1 is not supported") + raise NotImplementedError( + f"Slicing with step != 1 is not supported, got step={step}" + ) - assert dim in [0, 1, 2], f"Only dim=0,1,2 are supported, got: dim={dim}" - assert self.qdata.ndim in [2, 3], ( - f"Expected qdata to have dim=2,3 got: dim={self.qdata.ndim}" - ) + if dim not in [0, 1, 2]: + raise ValueError(f"Only dim in [0, 1, 2] supported, got dim={dim}") + + if self.qdata.ndim not in [2, 3]: + raise ValueError(f"Expected qdata to be 2D or 3D, got {self.qdata.ndim}D") - if end >= self.shape[dim]: + if end is None or end > self.shape[dim]: end = self.shape[dim] sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step) @@ -271,8 +275,10 @@ def _(func, types, args, kwargs): @implements(aten.select.int) def _(func, types, args, kwargs): + """Select operation for Int8Tensor""" self, dim, index = args - assert dim == 0, f"Only dim=0 supported, got {dim}" + if dim != 0: + raise NotImplementedError(f"Only dim=0 supported, got dim={dim}") selected_qdata = self.qdata[index] selected_scale = _slice_scale_for_dimension( From 3ad4898f61076098008e9f25e73d0e22c0b2a2bd Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 22 Nov 2025 20:53:34 +0900 Subject: [PATCH 16/19] reland toy linear model --- .../workflows/int8/test_int8_tensor.py | 28 +------------------ 1 file changed, 1 insertion(+), 27 deletions(-) diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py index c2f099fcde..70cfd49253 100644 --- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py +++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py @@ -19,36 +19,10 @@ ) from torchao.quantization.granularity import PerRow, PerTensor from torchao.quantization.utils import compute_error, get_block_size +from torchao.testing.model_architectures import ToyTwoLinearModel from torchao.testing.utils import TorchAOIntegrationTestCase -# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged -class ToyTwoLinearModel(torch.nn.Module): - def __init__( - self, - input_dim, - hidden_dim, - output_dim, - has_bias=False, - dtype=None, - device=None, - ): - super().__init__() - self.dtype = dtype - self.device = device - self.linear1 = torch.nn.Linear( - input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device - ) - self.linear2 = torch.nn.Linear( - hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device - ) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @common_utils.instantiate_parametrized_tests class TestInt8Tensor(TorchAOIntegrationTestCase): From 1b817ec9be342d036a34048d4cbe0a7e496e415c Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 22 Nov 2025 21:05:17 +0900 Subject: [PATCH 17/19] update casting logic --- .../quantization/quantize_/workflows/int8/int8_tensor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index dceb0964aa..0c1ec93af0 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -205,16 +205,13 @@ def _(func, types, args, kwargs): w_scales = weight_tensor.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - x_scales_dtype = x_scales.dtype # Cast FP16 scale to float to avoid overflow in int_scaled_matmul intermediate_dtype = ( - torch.float if x_scales_dtype == torch.half else x_scales_dtype + torch.float if x_scales.dtype == torch.half else x_scales.dtype ) y_dot_scaled = int_scaled_matmul( tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) - ) - y_dot_scaled = y_dot_scaled.to(x_scales_dtype) - + ).to(output_dtype) y = (y_dot_scaled * w_scales).reshape( *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] ) @@ -224,7 +221,7 @@ def _(func, types, args, kwargs): w_vals_int8_t = weight_tensor.qdata.t() m = torch.mm( activation_tensor.reshape(-1, activation_tensor.shape[-1]), - w_vals_int8_t.to(activation_tensor.dtype), + w_vals_int8_t.to(output_dtype), ) y = m * weight_tensor.scale.to(m.dtype) y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0]) From 3af0a3cbb646c675f0ec7f7024875c3661e1f25e Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 22 Nov 2025 22:50:58 +0900 Subject: [PATCH 18/19] add block_size attribute, separate version 1 from 2 --- torchao/quantization/quant_api.py | 20 +++++++++---------- .../quantize_/workflows/int8/int8_tensor.py | 15 ++++++++++---- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 0dc65d561f..f1c1f59023 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1348,7 +1348,7 @@ class Int8WeightOnlyConfig(AOBaseConfig): """ group_size: Optional[int] = None - granularity: Optional[Union[PerRow, PerTensor]] = PerRow() + granularity: Optional[Granularity] = PerRow() set_inductor_config: bool = True version: int = 1 @@ -1573,15 +1573,6 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 - if weight_only_decode: - input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode - else: - # input settings - if act_mapping_type == MappingType.SYMMETRIC: - input_quant_func = _int8_symm_per_token_reduced_range_quant - else: - input_quant_func = _int8_asymm_per_token_quant - if config.version == 1: warnings.warn( "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" @@ -1593,6 +1584,15 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]] ) + if weight_only_decode: + input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode + else: + # input settings + if act_mapping_type == MappingType.SYMMETRIC: + input_quant_func = _int8_symm_per_token_reduced_range_quant + else: + input_quant_func = _int8_asymm_per_token_quant + quantized_weight = to_affine_quantized_intx( weight, mapping_type, diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 0c1ec93af0..60fa66e957 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -59,13 +59,14 @@ class Int8Tensor(TorchAOBaseTensor): # TODO: Static quantization support using `static_scale` tensor_data_names = ["qdata", "scale"] tensor_attribute_names = ["granularity"] - optional_tensor_attribute_names = ["act_quant_kwargs", "dtype"] + optional_tensor_attribute_names = ["act_quant_kwargs", "block_size", "dtype"] def __new__( cls: type, qdata: torch.Tensor, scale: torch.Tensor, granularity: Granularity, + block_size: Optional[torch.Size] = None, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, dtype: Optional[torch.dtype] = None, ): @@ -81,6 +82,7 @@ def __init__( qdata: torch.Tensor, scale: torch.Tensor, granularity: Granularity, + block_size: Optional[torch.Size] = None, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, dtype: Optional[torch.dtype] = None, ): @@ -88,6 +90,7 @@ def __init__( self.qdata = qdata self.scale = scale self.granularity = granularity + self.block_size = block_size or get_block_size(qdata.shape, granularity) self.act_quant_kwargs = act_quant_kwargs def __repr__(self): @@ -97,6 +100,7 @@ def __repr__(self): f"qdata={self.qdata}, " f"scale={self.scale}, " f"granularity={self.granularity}, " + f"block_size={self.block_size}, " f"shape={self.shape}, " f"device={self.device}, " f"dtype={self.dtype})" @@ -141,6 +145,7 @@ def from_hp( int_data, scale, granularity, + block_size=block_size, act_quant_kwargs=act_quant_kwargs, dtype=w_hp.dtype, ) @@ -264,7 +269,8 @@ def _(func, types, args, kwargs): sliced_qdata, sliced_scale, self.granularity, - self.act_quant_kwargs, + block_size=get_block_size(sliced_qdata.shape, self.granularity), + act_quant_kwargs=self.act_quant_kwargs, dtype=self.dtype, ), ) @@ -290,8 +296,9 @@ def _(func, types, args, kwargs): selected_qdata, selected_scale, self.granularity, - self.act_quant_kwargs, - self.dtype, + block_size=get_block_size(selected_qdata.shape, self.granularity), + act_quant_kwargs=self.act_quant_kwargs, + dtype=self.dtype, ), ) From c0f090f23d5093d70eea5bc22ee8664b3555ddd8 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sun, 23 Nov 2025 04:31:31 +0900 Subject: [PATCH 19/19] fix activation kwargs --- .../quantize_/common/quantize_tensor_kwargs.py | 8 -------- .../quantize_/workflows/int8/int8_tensor.py | 14 +++++--------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py index 15540e34c8..0adc8c786d 100644 --- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py +++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py @@ -39,9 +39,7 @@ def _choose_quant_func_and_quantize_tensor( """ from torchao.quantization.quantize_.workflows import ( Float8Tensor, - Int8Tensor, QuantizeTensorToFloat8Kwargs, - QuantizeTensorToInt8Kwargs, ) if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs): @@ -54,11 +52,5 @@ def _choose_quant_func_and_quantize_tensor( quant_kwargs.hp_value_ub, quant_kwargs.kernel_preference, ) - elif isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs): - return Int8Tensor.from_hp( - tensor, - granularity=quant_kwargs.granularity, - act_quant_kwargs=quant_kwargs, - ) raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}") diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 60fa66e957..4d902b9ae9 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -19,10 +19,7 @@ dequantize_affine, quantize_affine, ) -from torchao.quantization.quantize_.common import ( - QuantizeTensorKwargs, - _choose_quant_func_and_quantize_tensor, -) +from torchao.quantization.quantize_.common import QuantizeTensorKwargs from torchao.quantization.utils import get_block_size from torchao.utils import TorchAOBaseTensor, fill_defaults @@ -65,7 +62,7 @@ def __new__( cls: type, qdata: torch.Tensor, scale: torch.Tensor, - granularity: Granularity, + granularity: Optional[Granularity] = None, block_size: Optional[torch.Size] = None, act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None, dtype: Optional[torch.dtype] = None, @@ -189,11 +186,10 @@ def _(func, types, args, kwargs): output_dtype = activation_tensor.dtype if weight_tensor.act_quant_kwargs is not None: + activation_tensor = Int8Tensor.from_hp( + activation_tensor, weight_tensor.act_quant_kwargs.granularity + ) # Dynamic activation quantization path - if not isinstance(activation_tensor, Int8Tensor): - activation_tensor = _choose_quant_func_and_quantize_tensor( - activation_tensor, weight_tensor.act_quant_kwargs - ) # 1. do the matrix form of dot(X_i, W_j) #