55# LICENSE file in the root directory of this source tree.
66
77from dataclasses import dataclass
8- from typing import List , Optional
8+ from typing import Optional
99
1010import torch
1111from torch .utils ._python_dispatch import return_and_correct_aliasing
1212
13+ from torchao .float8 .inference import _slice_scale_for_dimension
1314from torchao .quantization .quant_primitives import (
1415 MappingType ,
1516 _maybe_expand_scale_to_tensor_shape ,
2021 QuantizeTensorKwargs ,
2122 _choose_quant_func_and_quantize_tensor ,
2223)
23- from torchao .utils import TorchAOBaseTensor
24+ from torchao .utils import TorchAOBaseTensor , fill_defaults
2425
2526__all__ = ["Int8Tensor" , "QuantizeTensorToInt8Kwargs" ]
2627
@@ -32,11 +33,11 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
3233 """Tensor kwargs for creating int8 tensor (either activation or weight)
3334
3435 Args:
35- block_size (List [int]): block size for quantization granularity
36+ block_size (list [int]): block size for quantization granularity
3637 static_scale (Optional[torch.Tensor]): pre-computed scale for static quantization
3738 """
3839
39- block_size : List [int ]
40+ block_size : list [int ]
4041 static_scale : Optional [torch .Tensor ] = None
4142
4243
@@ -64,7 +65,7 @@ def __new__(
6465 cls : type ,
6566 qdata : torch .Tensor ,
6667 scale : torch .Tensor ,
67- block_size : List [int ],
68+ block_size : list [int ],
6869 act_quant_kwargs = None ,
6970 dtype = None ,
7071 ):
@@ -73,13 +74,13 @@ def __new__(
7374 "dtype" : dtype or scale .dtype ,
7475 "requires_grad" : False ,
7576 }
76- return torch .Tensor ._make_wrapper_subclass (cls , List ( qdata .shape ) , ** kwargs )
77+ return torch .Tensor ._make_wrapper_subclass (cls , qdata .shape , ** kwargs )
7778
7879 def __init__ (
7980 self ,
8081 qdata : torch .Tensor ,
8182 scale : torch .Tensor ,
82- block_size : List [int ],
83+ block_size : list [int ],
8384 act_quant_kwargs = None ,
8485 dtype = None ,
8586 ):
@@ -99,13 +100,13 @@ def __repr__(self):
99100 def from_hp (
100101 cls ,
101102 w : torch .Tensor ,
102- block_size : List [int ],
103+ block_size : list [int ],
103104 act_quant_kwargs : Optional [QuantizeTensorToInt8Kwargs ] = None ,
104105 ):
105106 if w .dim () != 2 or len (block_size ) != 2 :
106107 raise ValueError ("Expected 2D tensor and block_size length 2" )
107108
108- if act_quant_kwargs and act_quant_kwargs .static_scale is not None :
109+ if act_quant_kwargs is not None and act_quant_kwargs .static_scale is not None :
109110 # INT8 × INT8 (static)
110111 scale = act_quant_kwargs .static_scale
111112 zero_point = torch .zeros_like (scale , dtype = torch .int8 )
@@ -114,7 +115,7 @@ def from_hp(
114115 scale , zero_point = choose_qparams_affine (
115116 input = w ,
116117 mapping_type = MappingType .SYMMETRIC ,
117- block_size = tuple ( block_size ) ,
118+ block_size = block_size ,
118119 target_dtype = torch .int8 ,
119120 quant_min = - 128 ,
120121 quant_max = 127 ,
@@ -124,12 +125,19 @@ def from_hp(
124125
125126 int_data = quantize_affine (
126127 w ,
127- block_size = tuple ( block_size ) ,
128+ block_size = block_size ,
128129 scale = scale ,
129130 zero_point = zero_point ,
130131 output_dtype = torch .int8 ,
131132 )
132133
134+ if tuple (block_size ) == w .shape :
135+ # per-tensor
136+ scale = scale .expand (w .shape )
137+ elif len (scale .shape ) == 1 :
138+ # per-row, 1D -> 2D
139+ scale = scale .unsqueeze (- 1 )
140+
133141 return cls (
134142 int_data ,
135143 scale ,
@@ -208,37 +216,32 @@ def _(func, types, args, kwargs):
208216 return result + bias if bias is not None else result
209217
210218
211- @implements ([ aten .slice .Tensor ] )
219+ @implements (aten .slice .Tensor )
212220def _ (func , types , args , kwargs ):
213221 """Slice operation for Int8Tensor"""
214- tensor , dim , start , end , step = (
215- args [0 ],
216- args [1 ],
217- args [2 ],
218- args [3 ],
219- args [4 ] if len (args ) > 4 else 1 ,
220- )
222+ self , dim , start , end , step = fill_defaults (args , 5 , [0 , None , None , 1 ])
221223
222- assert dim in (0 , 1 ), f"Only dim 0 or 1 supported, got { dim } "
224+ if step != 1 :
225+ raise NotImplementedError ("Slicing with step > 1 is not supported" )
223226
224- if end >= tensor .shape [dim ]:
225- end = tensor .shape [dim ]
227+ if end >= self .shape [dim ]:
228+ end = self .shape [dim ]
226229
227- # Always slice the qdata
228- sliced_qdata = func (tensor .qdata , dim , start , end , step )
230+ sliced_qdata = aten .slice .Tensor (self .qdata , dim , start , end , step )
229231
230- if tensor .scale .numel () == 1 :
232+ if self .scale .numel () == 1 :
231233 # Per-tensor quantization - scale doesn't change
232- sliced_scale = tensor .scale
233- elif dim < tensor .scale .ndim and tensor .scale .shape [dim ] > 1 :
234+ sliced_scale = self .scale
235+ elif dim < self .scale .ndim and self .scale .shape [dim ] > 1 :
234236 # Block-wise quantization - need to slice the scale appropriately
235- sliced_scale = func ( tensor .scale , dim , start , end , step )
237+ sliced_scale = aten . slice . Tensor ( self .scale , dim , start , end , step )
236238 else :
237- sliced_scale = tensor . scale
238-
239- # adjust block_size since the shape has changed, block_size[i] should not be greater than shape[i]
240- block_size = List ( tensor . block_size )
239 + # Scale has no axis `dim` or at most one entry along it - delegate slicing to _slice_scale_for_dimension
240+ sliced_scale = _slice_scale_for_dimension (
241+ self . scale , self . qdata . shape , dim , start , end , step
242+ )
241243
244+ block_size = list (self .block_size )
242245 for i in range (len (block_size )):
243246 block_size [i ] = min (block_size [i ], sliced_qdata .shape [i ])
244247
@@ -250,8 +253,8 @@ def _(func, types, args, kwargs):
250253 sliced_qdata ,
251254 sliced_scale ,
252255 block_size ,
253- tensor .act_quant_kwargs ,
254- tensor .dtype ,
256+ self .act_quant_kwargs ,
257+ dtype = self .dtype ,
255258 ),
256259 )
257260
0 commit comments