
Commit c53dad0

link new API with old API using version 2

1 parent 9a51cae

File tree: 3 files changed, +152 additions, -31 deletions

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 99 additions & 2 deletions
```diff
@@ -4,11 +4,19 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import unittest
+from contextlib import nullcontext
+from typing import Tuple
 
 import torch
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal import common_utils
 
+from torchao.quantization import (
+    Int8DynamicActivationInt8WeightConfig,
+    Int8WeightOnlyConfig,
+    quantize_,
+)
 from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
     Int8Tensor,
     QuantizeTensorToInt8Kwargs,
@@ -17,7 +25,46 @@
 from torchao.testing.utils import TorchAOIntegrationTestCase
 
 
+# TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged
+class ToyTwoLinearModel(torch.nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        hidden_dim,
+        output_dim,
+        has_bias=False,
+        dtype=None,
+        device=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.device = device
+        self.linear1 = torch.nn.Linear(
+            input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device
+        )
+        self.linear2 = torch.nn.Linear(
+            hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device
+        )
+
+    # Note: tinygemm kernel only uses bfloat16 inputs
+    def example_inputs(self, batch_size=1):
+        return (
+            torch.randn(
+                batch_size,
+                self.linear1.in_features,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+        )
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@common_utils.instantiate_parametrized_tests
 class TestInt8Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         super().setUp()
@@ -37,6 +84,56 @@ def test_creation_and_attributes(self):
             torch.all(tensor.qdata >= -128) and torch.all(tensor.qdata <= 127)
         )
 
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
+    @common_utils.parametrize("compile", [False, True])
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    @common_utils.parametrize(
+        "config",
+        [
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8WeightOnlyConfig(version=2),
+        ],
+    )
+    def test_int8_linear_variants(
+        self,
+        dtype: torch.dtype,
+        compile: bool,
+        sizes: Tuple,
+        config,
+    ):
+        error_message = None
+
+        error_context = (
+            self.assertRaisesRegex(AssertionError, error_message)
+            if error_message
+            else nullcontext()
+        )
+
+        with error_context:
+            M, N, K = sizes
+            input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
+
+            # Create a two-linear model and a copy to quantize
+            m = ToyTwoLinearModel(K, N, K).eval().to(dtype).to("cuda")
+            m_q = copy.deepcopy(m)
+
+            # Quantize
+            quantize_(m_q, config)
+
+            output_original = m(input_tensor)
+            output_quantized = m_q(input_tensor)
+
+            error = compute_error(output_original, output_quantized)
+            assert error > 20, (
+                f"Quantization error is too high, got SQNR of {error}"
+            )
+
     def test_linear_operations(self):
         """Test fp+int8 and int8+int8 linear ops with quantization error check"""
         weight_q8 = Int8Tensor.from_hp(self.weight_fp, self.block_size)
@@ -85,4 +182,4 @@ def test_error_handling_and_dequant(self):
 
 
 if __name__ == "__main__":
-    run_tests()
+    common_utils.run_tests()
```
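For context, a minimal sketch of the end-to-end path the new test exercises, reusing the `ToyTwoLinearModel` defined above. It assumes a CUDA device and that `compute_error` (the SQNR helper from `torchao.quantization.utils`) is available, as the test file already relies on it:

```python
import copy

import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

# Build a small bfloat16 model on CUDA and keep a high-precision reference copy
model = ToyTwoLinearModel(128, 256, 128).eval().to(torch.bfloat16).to("cuda")
model_q = copy.deepcopy(model)

# version=2 routes quantization through the new Int8Tensor workflow
quantize_(model_q, Int8WeightOnlyConfig(version=2))

x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
sqnr = compute_error(model(x), model_q(x))
assert sqnr > 20  # same SQNR threshold the test asserts
```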

torchao/quantization/quant_api.py

Lines changed: 51 additions & 29 deletions
```diff
@@ -78,6 +78,7 @@
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
+    Int8Tensor,
     IntxOpaqueTensor,
     IntxPackingFormat,
     IntxUnpackedToInt8Tensor,
@@ -1352,10 +1353,12 @@ class Int8WeightOnlyConfig(AOBaseConfig):
             Otherwise, applies per-group quantization with the specified group size.
         set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
             for better performance with this quantization scheme.
+        version: int = 1 - Version of the config to use. Version 1 uses AffineQuantizedTensor, version 2 uses Int8Tensor.
     """
 
     group_size: Optional[int] = None
     set_inductor_config: bool = True
+    version: int = 1
 
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
@@ -1366,22 +1369,30 @@ def __post_init__(self):
 
 
 def _int8_weight_only_quantize_tensor(weight, config):
-    mapping_type = MappingType.SYMMETRIC
-    target_dtype = torch.int8
-    eps = torch.finfo(torch.float32).eps
-    zero_point_dtype = torch.int64
-    group_size = config.group_size
-    if group_size is None:
-        group_size = weight.shape[-1]
-    block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
-    new_weight = to_affine_quantized_intx(
-        weight,
-        mapping_type,
-        block_size,
-        target_dtype,
-        eps=eps,
-        zero_point_dtype=zero_point_dtype,
-    )
+    if config.version == 1:
+        warnings.warn(
+            "Config Deprecation: version 1 of Int8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
+        mapping_type = MappingType.SYMMETRIC
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
+        group_size = config.group_size
+        if group_size is None:
+            group_size = weight.shape[-1]
+        block_size = tuple([1 for x in range(weight.dim() - 1)] + [group_size])
+        new_weight = to_affine_quantized_intx(
+            weight,
+            mapping_type,
+            block_size,
+            target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype,
+        )
+    else:
+        assert config.version == 2, f"Unexpected version: {config.version}"
+        block_size = [weight.shape[0], weight.shape[1]]
+        new_weight = Int8Tensor.from_hp(weight, block_size=block_size)
     return new_weight
 
 
@@ -1509,12 +1520,14 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
             in original precision during decode operations.
         set_inductor_config: bool = True - If True, adjusts `torchinductor` settings to recommended values
             for better performance with this quantization scheme.
+        version (int): the version of the config; version 1 uses AffineQuantizedTensor (which we plan to deprecate/split), version 2 uses Int8Tensor
     """
 
     layout: Optional[Layout] = PlainLayout()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     weight_only_decode: bool = False
     set_inductor_config: bool = True
+    version: int = 1
 
     def __post_init__(self):
         torch._C._log_api_usage_once(
@@ -1562,19 +1575,28 @@ def get_weight_block_size(x):
     else:
         input_quant_func = _int8_asymm_per_token_quant
 
-    block_size = get_weight_block_size(weight)
-    new_weight = to_affine_quantized_intx(
-        weight,
-        mapping_type,
-        block_size,
-        target_dtype,
-        eps=eps,
-        zero_point_dtype=zero_point_dtype,
-        _layout=layout,
-        zero_point_domain=weight_zero_point_domain,
-    )
-    new_weight = to_linear_activation_quantized(new_weight, input_quant_func)
-    return new_weight
+    if config.version == 1:
+        block_size = get_weight_block_size(weight)
+        quantized_weight = to_affine_quantized_intx(
+            weight,
+            mapping_type,
+            block_size,
+            target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype,
+            _layout=layout,
+            zero_point_domain=weight_zero_point_domain,
+        )
+        quantized_weight = to_linear_activation_quantized(
+            quantized_weight, input_quant_func
+        )
+    else:
+        quantized_weight = Int8Tensor.from_hp(
+            weight,
+            block_size=get_weight_block_size(weight),
+        )
+
+    return quantized_weight
 
 
 @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
```
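To illustrate the version gate, here is a small sketch (mine, not from the commit) of the observable difference between the two paths: version 1 still quantizes through `to_affine_quantized_intx` but now emits the deprecation warning, while version 2 swaps the weight for the new `Int8Tensor` subclass:

```python
import copy
import warnings

import torch

from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor

m = torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")

# version=1 (still the default) keeps the AffineQuantizedTensor path,
# but now warns that it is deprecated
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    quantize_(copy.deepcopy(m), Int8WeightOnlyConfig(version=1))
assert any("Config Deprecation" in str(w.message) for w in caught)

# version=2 replaces the weight with the new Int8Tensor subclass
quantize_(m, Int8WeightOnlyConfig(version=2))
assert isinstance(m.weight, Int8Tensor)
```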

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -29,6 +29,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
     """
 
     block_size: Optional[list[int]] = None
+    kernel_preference: Optional[str] = None
 
 
 # TODO: Implement block-wise quantization using block_size
@@ -102,6 +103,7 @@ def from_hp(
         w: torch.Tensor,
         block_size: list[int],
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        kernel_preference: Optional[str] = None,
     ):
         if w.dim() != 2 or len(block_size) != 2:
             raise ValueError("Expected 2D tensor and block_size length 2")
```
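A short sketch of the extended constructor (my example, not from the commit). This change only threads `kernel_preference` through the signature and the kwargs dataclass, so passing it is optional and `None` keeps the default kernel selection; `dequantize()` is assumed here because the test file above exercises it in `test_error_handling_and_dequant`:

```python
import torch

from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor

w = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

# from_hp requires a 2D weight and a length-2 block_size;
# here the block spans the whole tensor (a single per-tensor scale)
w_q8 = Int8Tensor.from_hp(w, block_size=[256, 128], kernel_preference=None)
w_dq = w_q8.dequantize()  # back to high precision for inspection
```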
