Commit 664ae4d

use Granularity for slicing logic instead of block_size
1 parent 027afd8 commit 664ae4d
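
The refactor hinges on the fact that a block size can always be recovered from a granularity plus a concrete tensor shape. A minimal sketch of the mapping the new code (and the updated tests below) relies on, assuming a 2D weight of shape (N, K) with arbitrary example sizes:

import torch
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import get_block_size

# PerRow():    one scale per row          -> block size [1, K]
# PerTensor(): one scale for whole tensor -> block size [N, K]
N, K = 32, 64
w = torch.randn(N, K)
assert list(get_block_size(w.shape, PerRow())) == [1, K]
assert list(get_block_size(w.shape, PerTensor())) == [N, K]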

File tree

3 files changed: +30 -38 lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 7 additions & 3 deletions

@@ -18,7 +18,7 @@
     quantize_,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
-from torchao.quantization.utils import compute_error
+from torchao.quantization.utils import compute_error, get_block_size
 from torchao.testing.utils import TorchAOIntegrationTestCase
 
 
@@ -186,9 +186,13 @@ def test_index_select(self, config, granularity):
 
         # Test block_size granularity
         if isinstance(granularity, PerRow):
-            self.assertEqual(x_int8.block_size, [1, K])
+            self.assertEqual(
+                list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
+            )
         elif isinstance(granularity, PerTensor):
-            self.assertEqual(x_int8.block_size, [N, K])
+            self.assertEqual(
+                list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
+            )
 
     @common_utils.parametrize(
         "config",
torchao/quantization/quant_api.py

Lines changed: 1 addition & 3 deletions

@@ -1587,13 +1587,11 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     )
 
     assert config.version == 2, f"Unexpected version: {config.version}"
-    # Compute block_size from granularity for activation quantization kwargs
-    block_size = get_block_size(weight.shape, config.granularity)
 
     quantized_weight = Int8Tensor.from_hp(
         weight,
         granularity=config.granularity,
-        act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=list(block_size)),
+        act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity),
     )
 
     return quantized_weight
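
Because QuantizeTensorToInt8Kwargs now carries the granularity rather than a 2D block_size computed from the weight, the block size for an incoming activation can be derived from the activation's own shape at quantize time. A hedged sketch of that idea, assuming PerRow quantization and a hypothetical 3D activation of shape (B, S, K):

import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.utils import get_block_size

# Sketch only: the same PerRow() granularity yields a shape-appropriate
# block size for both the 2D weight and the 3D activation, so no manual
# padding of a stored 2D block_size is needed anymore.
B, S, N, K = 2, 8, 32, 64
weight_block = get_block_size(torch.Size((N, K)), PerRow())    # (1, K)
act_block = get_block_size(torch.Size((B, S, K)), PerRow())    # (1, 1, K) -- assumed PerRow behavior for 3D inputs
print(weight_block, act_block)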

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 22 additions & 32 deletions

@@ -17,8 +17,8 @@
 from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
-    _maybe_expand_scale_to_tensor_shape,
     choose_qparams_affine,
+    dequantize_affine,
     quantize_affine,
 )
 from torchao.quantization.quantize_.common import (
@@ -38,13 +38,11 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
     """Tensor kwargs for creating int8 tensor (either activation or weight)
 
     Args:
-        block_size (list[int]): block size for quantization granularity
         granularity: the granularity for the Tensor, currently either PerRow() or PerTensor()
         # TODO: Static quantization support using `static_scale`, `static_zero_point`
     """
 
-    block_size: list[int]
-    granularity = PerRow()
+    granularity: object = PerRow()
 
 
 class Int8Tensor(TorchAOBaseTensor):
@@ -56,12 +54,12 @@ class Int8Tensor(TorchAOBaseTensor):
         scale: scale factors for dequantization
 
     Non-Tensor Attributes:
-        block_size: block size for quantization granularity
+        granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
         act_quant_kwargs: flags for dynamic activation quantization
     """
 
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = ["block_size"]
+    tensor_attribute_names = ["granularity"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
         "dtype",
@@ -86,20 +84,20 @@ def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: list[int],
+        granularity,
         act_quant_kwargs=None,
         dtype=None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
-        self.block_size = block_size
+        self.granularity = granularity
         self.act_quant_kwargs = act_quant_kwargs
 
     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, "
-            f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})"
+            f"{self.granularity=}, {self.shape=}, {self.device=}, {self.dtype=})"
         )
 
     @classmethod
@@ -109,7 +107,7 @@ def from_hp(
         granularity=PerRow(),
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
-        block_size = list(get_block_size(w_hp.shape, granularity))
+        block_size = get_block_size(w_hp.shape, granularity)
 
         if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim():
             raise ValueError("Expected 2D or 3D tensor with same block_size length")
@@ -136,7 +134,7 @@ def from_hp(
         return cls(
             int_data,
             scale,
-            block_size,
+            granularity,
             act_quant_kwargs=act_quant_kwargs,
             dtype=w_hp.dtype,
         )
@@ -147,13 +145,18 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         if output_dtype is None:
             output_dtype = self.dtype
 
-        qdata_fp = self.qdata.to(output_dtype)
-        scale = self.scale
-        while scale.ndim < qdata_fp.ndim:
-            scale = scale.unsqueeze(-1)
+        block_size = get_block_size(self.qdata.shape, self.granularity)
 
-        scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape)
-        return qdata_fp * scale_expanded.to(output_dtype)
+        return dequantize_affine(
+            input=self.qdata,
+            block_size=block_size,
+            scale=self.scale,
+            zero_point=None,
+            input_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            output_dtype=output_dtype,
+        )
 
 
 implements = Int8Tensor.implements
@@ -178,15 +181,6 @@ def _(func, types, args, kwargs):
     if not isinstance(activation_tensor, Int8Tensor):
         # Dynamic activation quantization
        act_kwargs = weight_tensor.act_quant_kwargs
-        input_ndim = activation_tensor.ndim
-
-        # Ensure block_size matches input tensor dimensions
-        if len(act_kwargs.block_size) != input_ndim:
-            if input_ndim == 3 and len(act_kwargs.block_size) == 2:
-                block_size_updated = [1] + list(act_kwargs.block_size)
-            else:
-                block_size_updated = list(act_kwargs.block_size)[-input_ndim:]
-            act_kwargs = QuantizeTensorToInt8Kwargs(block_size=block_size_updated)
 
         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor, act_kwargs
@@ -261,18 +255,14 @@ def _(func, types, args, kwargs):
         self.scale, self.qdata.shape, dim, start, end, step
     )
 
-    block_size = list(self.block_size)
-    for i in range(len(block_size)):
-        block_size[i] = min(block_size[i], sliced_qdata.shape[i])
-
     return return_and_correct_aliasing(
         func,
         args,
         kwargs,
         Int8Tensor(
             sliced_qdata,
             sliced_scale,
-            block_size,
+            self.granularity,
             self.act_quant_kwargs,
             dtype=self.dtype,
         ),
@@ -296,7 +286,7 @@ def _(func, types, args, kwargs):
         Int8Tensor(
             selected_qdata,
             selected_scale,
-            self.block_size[1:],
+            self.granularity,
             self.act_quant_kwargs,
             self.dtype,
         ),
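
Storing granularity instead of block_size is what simplifies the slicing path above: block_size depends on the tensor's shape, so the old code clamped it after every slice, while a granularity is shape-independent and the right block size for the sliced tensor falls out of get_block_size. A small sketch of that equivalence, assuming a slice along dim 0 that halves the rows:

import torch
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import get_block_size

N, K = 32, 64
full, sliced = torch.Size((N, K)), torch.Size((N // 2, K))

for gran in (PerRow(), PerTensor()):
    # What the removed clamping loop produced from the stored block_size...
    clamped = [min(b, s) for b, s in zip(get_block_size(full, gran), sliced)]
    # ...matches what the new code derives on demand from the sliced shape.
    derived = list(get_block_size(sliced, gran))
    assert clamped == derived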
