
Commit 062f3cc

update int8
1 parent 49a7a89 commit 062f3cc

File tree

3 files changed (+63, -46 lines)


test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 18 additions & 8 deletions
@@ -111,12 +111,16 @@ def test_int8_linear_variants(
         assert error > 20, f"Quantization error is too high got a SQNR of {error}"
 
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
-    def test_static_quantization(self, dtype):
-        """Test static quantization with pre-computed scale"""
+    def test_static_dynamic_quantization(self, dtype):
+        """Test static and dynamic quantization"""
         K, N = 128, 64
         weight = torch.randn(N, K, dtype=dtype, device="cuda")
         input_tensor = torch.randn(32, K, dtype=dtype, device="cuda")
 
+        # Dynamic quantization (runtime scale computation)
+        dynamic_tensor = Int8Tensor.from_hp(weight, block_size=[N, K])
+
+        # Static quantization (pre-computed scale)
         act_scale, _ = choose_qparams_affine(
             input=input_tensor,
             mapping_type=MappingType.SYMMETRIC,
@@ -128,8 +132,8 @@ def test_static_quantization(self, dtype):
             zero_point_dtype=torch.int8,
         )
 
-        # Create weight with static quantization
-        weight_int8 = Int8Tensor.from_hp(
+        # Static quantization (with pre-computed scale)
+        static_tensor = Int8Tensor.from_hp(
             weight,
             block_size=[N, K],
             act_quant_kwargs=QuantizeTensorToInt8Kwargs(
@@ -138,9 +142,13 @@ def test_static_quantization(self, dtype):
             ),
         )
 
-        output = torch.nn.functional.linear(input_tensor, weight_int8)
-        self.assertEqual(output.shape, (32, N))
-        self.assertEqual(output.dtype, dtype)
+        dynamic_output = torch.nn.functional.linear(input_tensor, dynamic_tensor)
+        static_output = torch.nn.functional.linear(input_tensor, static_tensor)
+
+        self.assertEqual(dynamic_output.shape, (32, N))
+        self.assertEqual(static_output.shape, (32, N))
+        self.assertEqual(dynamic_output.dtype, dtype)
+        self.assertEqual(static_output.dtype, dtype)
 
     @unittest.skip("granularity parameter not supported in current API")
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@@ -190,6 +198,8 @@ def test_slice(self, config, device, dtype):
             # PerTensor: scale unchanged by slicing
             self.assertEqual(weight1.scale, dummy.weight.scale)
             self.assertEqual(weight2.scale, dummy.weight.scale)
+        with self.assertRaises(NotImplementedError):
+            _ = dummy.weight[::2]
 
     def test_index_select(self):
         """test that `x_0 = x[0]` works when `x` is a 2D `Int8Tensor`."""
@@ -212,7 +222,7 @@ def test_error_handling_and_dequant(self):
         test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16)
         tensor = Int8Tensor.from_hp(test_data, [1, 2])
 
-        dequantized = torch.ops.aten.dequantize.self(tensor)
+        dequantized = tensor.dequantize()
         self.assertEqual(dequantized.shape, test_data.shape)
         self.assertLess(torch.abs(dequantized - test_data).max().item(), 0.1)
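The renamed test now exercises both quantization paths. Below is a minimal standalone sketch of the same flow, assuming the Int8Tensor API shown in this diff; the per-tensor activation scale is an illustrative stand-in for the choose_qparams_affine result the test computes, and CUDA is assumed as in the test:

# Sketch, not part of the commit: dynamic vs. static int8 quantization.
import torch
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    Int8Tensor,
    QuantizeTensorToInt8Kwargs,
)

K, N = 128, 64
weight = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")

# Dynamic: the activation scale is computed at runtime from each input.
w_dyn = Int8Tensor.from_hp(weight, block_size=[N, K])

# Static: a pre-computed activation scale rides along in act_quant_kwargs.
act_scale = (x.abs().amax() / 127.0).float()  # illustrative per-tensor scale
w_static = Int8Tensor.from_hp(
    weight,
    block_size=[N, K],
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(
        block_size=list(x.shape), static_scale=act_scale
    ),
)

y_dyn = torch.nn.functional.linear(x, w_dyn)        # (32, N), bfloat16
y_static = torch.nn.functional.linear(x, w_static)  # (32, N), bfloat16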

torchao/quantization/quant_api.py

Lines changed: 8 additions & 4 deletions
@@ -1522,6 +1522,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     layout: Optional[Layout] = PlainLayout()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     weight_only_decode: bool = False
+    granularity: Optional[Union[PerRow, PerTensor]] = PerRow()
     set_inductor_config: bool = True
     version: int = 2
 
@@ -1555,9 +1556,6 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     weight_zero_point_domain = ZeroPointDomain.NONE
 
-    def get_weight_block_size(x):
-        return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]])
-
     target_dtype = torch.int8
     eps = torch.finfo(torch.float32).eps
     zero_point_dtype = torch.int64
@@ -1571,7 +1569,13 @@ def get_weight_block_size(x):
     else:
         input_quant_func = _int8_asymm_per_token_quant
 
-    block_size = get_weight_block_size(weight)
+    if isinstance(config.granularity, PerTensor):
+        # Tensor granularity
+        block_size = weight.shape
+    else:
+        # Per row granularity
+        block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
+
     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 37 additions & 34 deletions
@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Optional
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
+from torchao.float8.inference import _slice_scale_for_dimension
 from torchao.quantization.quant_primitives import (
     MappingType,
     _maybe_expand_scale_to_tensor_shape,
@@ -20,7 +21,7 @@
     QuantizeTensorKwargs,
     _choose_quant_func_and_quantize_tensor,
 )
-from torchao.utils import TorchAOBaseTensor
+from torchao.utils import TorchAOBaseTensor, fill_defaults
 
 __all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"]
 
@@ -32,11 +33,11 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
     """Tensor kwargs for creating int8 tensor (either activation or weight)
 
     Args:
-        block_size (List[int]): block size for quantization granularity
+        block_size (list[int]): block size for quantization granularity
         static_scale (Optional[torch.Tensor]): pre-computed scale for static quantization
     """
 
-    block_size: List[int]
+    block_size: list[int]
     static_scale: Optional[torch.Tensor] = None
 
 
@@ -64,7 +65,7 @@ def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: List[int],
+        block_size: list[int],
         act_quant_kwargs=None,
         dtype=None,
     ):
@@ -73,13 +74,13 @@ def __new__(
             "dtype": dtype or scale.dtype,
             "requires_grad": False,
         }
-        return torch.Tensor._make_wrapper_subclass(cls, list(qdata.shape), **kwargs)
+        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
 
     def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: List[int],
+        block_size: list[int],
         act_quant_kwargs=None,
        dtype=None,
     ):
@@ -99,13 +100,13 @@ def __repr__(self):
     def from_hp(
         cls,
         w: torch.Tensor,
-        block_size: List[int],
+        block_size: list[int],
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         if w.dim() != 2 or len(block_size) != 2:
             raise ValueError("Expected 2D tensor and block_size length 2")
 
-        if act_quant_kwargs and act_quant_kwargs.static_scale is not None:
+        if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
             # INT8 × INT8 (static)
             scale = act_quant_kwargs.static_scale
             zero_point = torch.zeros_like(scale, dtype=torch.int8)
@@ -114,7 +115,7 @@ def from_hp(
             scale, zero_point = choose_qparams_affine(
                 input=w,
                 mapping_type=MappingType.SYMMETRIC,
-                block_size=tuple(block_size),
+                block_size=block_size,
                 target_dtype=torch.int8,
                 quant_min=-128,
                 quant_max=127,
@@ -124,12 +125,19 @@ def from_hp(
 
         int_data = quantize_affine(
             w,
-            block_size=tuple(block_size),
+            block_size=block_size,
             scale=scale,
             zero_point=zero_point,
             output_dtype=torch.int8,
         )
 
+        if tuple(block_size) == w.shape:
+            # per-tensor
+            scale = scale.expand(w.shape)
+        elif len(scale.shape) == 1:
+            # per-row, 1D -> 2D
+            scale = scale.unsqueeze(-1)
+
         return cls(
             int_data,
             scale,
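The new branch after quantize_affine normalizes the scale's shape. A plain-tensor sketch of the effect (shapes only, no torchao dependency):

# Sketch: per-tensor scalar scales expand to the weight's shape; per-row
# 1D scales of length N become (N, 1) so they broadcast over the K dim.
import torch

w = torch.randn(4, 8)
per_tensor_scale = torch.tensor(0.05).expand(w.shape)  # () -> (4, 8)
per_row_scale = torch.rand(4).unsqueeze(-1)            # (4,) -> (4, 1)
assert per_tensor_scale.shape == (4, 8)
assert per_row_scale.shape == (4, 1)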
@@ -208,37 +216,32 @@ def _(func, types, args, kwargs):
     return result + bias if bias is not None else result
 
 
-@implements([aten.slice.Tensor])
+@implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     """Slice operation for Int8Tensor"""
-    tensor, dim, start, end, step = (
-        args[0],
-        args[1],
-        args[2],
-        args[3],
-        args[4] if len(args) > 4 else 1,
-    )
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
 
-    assert dim in (0, 1), f"Only dim 0 or 1 supported, got {dim}"
+    if step != 1:
+        raise NotImplementedError("Slicing with step > 1 is not supported")
 
-    if end >= tensor.shape[dim]:
-        end = tensor.shape[dim]
+    if end >= self.shape[dim]:
+        end = self.shape[dim]
 
-    # Always slice the qdata
-    sliced_qdata = func(tensor.qdata, dim, start, end, step)
+    sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step)
 
-    if tensor.scale.numel() == 1:
+    if self.scale.numel() == 1:
         # Per-tensor quantization - scale doesn't change
-        sliced_scale = tensor.scale
-    elif dim < tensor.scale.ndim and tensor.scale.shape[dim] > 1:
+        sliced_scale = self.scale
+    elif dim < self.scale.ndim and self.scale.shape[dim] > 1:
         # Block-wise quantization - need to slice the scale appropriately
-        sliced_scale = func(tensor.scale, dim, start, end, step)
+        sliced_scale = aten.slice.Tensor(self.scale, dim, start, end, step)
     else:
-        sliced_scale = tensor.scale
-
-    # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i]
-    block_size = list(tensor.block_size)
+        # Block-wise quantization - need to slice the scale appropriately
+        sliced_scale = _slice_scale_for_dimension(
+            self.scale, self.qdata.shape, dim, start, end, step
+        )
 
+    block_size = list(self.block_size)
     for i in range(len(block_size)):
         block_size[i] = min(block_size[i], sliced_qdata.shape[i])
 
@@ -250,8 +253,8 @@ def _(func, types, args, kwargs):
             sliced_qdata,
             sliced_scale,
             block_size,
-            tensor.act_quant_kwargs,
-            tensor.dtype,
+            self.act_quant_kwargs,
+            dtype=self.dtype,
         ),
     )
 
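The rewritten slice handler keeps qdata and scale consistent. A plain-tensor sketch of the semantics (illustration only; the real code dispatches through aten.slice.Tensor and _slice_scale_for_dimension):

# Sketch: per-tensor scale (numel == 1) is unchanged by slicing; a per-row
# scale is sliced along the same dim as the data. step != 1 now raises.
import torch

qdata = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
row_scale = torch.rand(8, 1)                 # per-row scale, shape (N, 1)
per_tensor_scale = torch.tensor(0.05)        # scalar scale

dim, start, end = 0, 2, 6
sliced_q = qdata.narrow(dim, start, end - start)       # rows 2..5 of qdata
sliced_s = row_scale.narrow(dim, start, end - start)   # scale follows the rows
assert sliced_q.shape == (4, 16) and sliced_s.shape == (4, 1)
assert per_tensor_scale.numel() == 1                   # unchanged by slicing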
