 from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
     MappingType,
-    _maybe_expand_scale_to_tensor_shape,
     choose_qparams_affine,
+    dequantize_affine,
     quantize_affine,
 )
 from torchao.quantization.quantize_.common import (
@@ -38,13 +38,11 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
     """Tensor kwargs for creating int8 tensor (either activation or weight)

     Args:
-        block_size (list[int]): block size for quantization granularity
         granularity: the granularity for the Tensor, currently either PerRow() or PerTensor()
         # TODO: Static quantization support using `static_scale`, `static_zero_point`
     """

-    block_size: list[int]
-    granularity = PerRow()
+    granularity: object = PerRow()


 class Int8Tensor(TorchAOBaseTensor):
@@ -56,12 +54,12 @@ class Int8Tensor(TorchAOBaseTensor):
         scale: scale factors for dequantization

     Non-Tensor Attributes:
-        block_size: block size for quantization granularity
+        granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
         act_quant_kwargs: flags for dynamic activation quantization
     """

     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = ["block_size"]
+    tensor_attribute_names = ["granularity"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
         "dtype",
@@ -86,20 +84,20 @@ def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: list[int],
+        granularity,
         act_quant_kwargs=None,
         dtype=None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
-        self.block_size = block_size
+        self.granularity = granularity
         self.act_quant_kwargs = act_quant_kwargs

     def __repr__(self):
         return (
             f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.scale=}, "
-            f"{self.block_size=}, {self.shape=}, {self.device=}, {self.dtype=})"
+            f"{self.granularity=}, {self.shape=}, {self.device=}, {self.dtype=})"
         )

     @classmethod
@@ -109,7 +107,7 @@ def from_hp(
         granularity=PerRow(),
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
-        block_size = list(get_block_size(w_hp.shape, granularity))
+        block_size = get_block_size(w_hp.shape, granularity)

         if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim():
             raise ValueError("Expected 2D or 3D tensor with same block_size length")
@@ -136,7 +134,7 @@ def from_hp(
         return cls(
             int_data,
             scale,
-            block_size,
+            granularity,
             act_quant_kwargs=act_quant_kwargs,
             dtype=w_hp.dtype,
         )
@@ -147,13 +145,18 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         if output_dtype is None:
             output_dtype = self.dtype

-        qdata_fp = self.qdata.to(output_dtype)
-        scale = self.scale
-        while scale.ndim < qdata_fp.ndim:
-            scale = scale.unsqueeze(-1)
+        block_size = get_block_size(self.qdata.shape, self.granularity)

-        scale_expanded = _maybe_expand_scale_to_tensor_shape(scale, qdata_fp.shape)
-        return qdata_fp * scale_expanded.to(output_dtype)
+        return dequantize_affine(
+            input=self.qdata,
+            block_size=block_size,
+            scale=self.scale,
+            zero_point=None,
+            input_dtype=torch.int8,
+            quant_min=-128,
+            quant_max=127,
+            output_dtype=output_dtype,
+        )


 implements = Int8Tensor.implements
@@ -178,15 +181,6 @@ def _(func, types, args, kwargs):
     if not isinstance(activation_tensor, Int8Tensor):
         # Dynamic activation quantization
         act_kwargs = weight_tensor.act_quant_kwargs
-        input_ndim = activation_tensor.ndim
-
-        # Ensure block_size matches input tensor dimensions
-        if len(act_kwargs.block_size) != input_ndim:
-            if input_ndim == 3 and len(act_kwargs.block_size) == 2:
-                block_size_updated = [1] + list(act_kwargs.block_size)
-            else:
-                block_size_updated = list(act_kwargs.block_size)[-input_ndim:]
-            act_kwargs = QuantizeTensorToInt8Kwargs(block_size=block_size_updated)

         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor, act_kwargs
@@ -261,18 +255,14 @@ def _(func, types, args, kwargs):
         self.scale, self.qdata.shape, dim, start, end, step
     )

-    block_size = list(self.block_size)
-    for i in range(len(block_size)):
-        block_size[i] = min(block_size[i], sliced_qdata.shape[i])
-
     return return_and_correct_aliasing(
         func,
         args,
         kwargs,
         Int8Tensor(
             sliced_qdata,
             sliced_scale,
-            block_size,
+            self.granularity,
             self.act_quant_kwargs,
             dtype=self.dtype,
         ),
@@ -296,7 +286,7 @@ def _(func, types, args, kwargs):
         Int8Tensor(
             selected_qdata,
             selected_scale,
-            self.block_size[1:],
+            self.granularity,
             self.act_quant_kwargs,
             self.dtype,
         ),
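Taken together, these hunks drop block_size as stored state: it is recomputed from the stored granularity and the live qdata shape wherever it is needed (from_hp, dequantize), so the slice and select overloads no longer have to patch it by hand. Below is a minimal illustrative sketch of that mapping, using a hypothetical _block_size_for helper in place of torchao's get_block_size (whose exact signature is not shown in this diff); PerRow and PerTensor are the granularities referenced above.

import torch
from torchao.quantization.granularity import PerRow, PerTensor

def _block_size_for(shape, granularity):
    # Hypothetical stand-in for get_block_size(), shown only to illustrate how a
    # granularity maps onto a block_size for a given tensor shape.
    if isinstance(granularity, PerTensor):
        return tuple(shape)  # one scale for the whole tensor
    if isinstance(granularity, PerRow):
        return (1,) * (len(shape) - 1) + (shape[-1],)  # one scale per row
    raise NotImplementedError(f"unsupported granularity: {granularity}")

w = torch.randn(4, 8)
print(_block_size_for(w.shape, PerRow()))      # (1, 8)
print(_block_size_for(w.shape, PerTensor()))   # (4, 8)

# After slicing, the same granularity yields a block_size that already matches
# the new shape, e.g. (1, 8) -> (1, 4) for w[:, :4], which is why the slice and
# select overloads above can simply pass self.granularity through.
print(_block_size_for(w[:, :4].shape, PerRow()))  # (1, 4)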