
Commit 9383550

update static/dynamic quantization workflows
1 parent b861dbc commit 9383550


torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 60 additions & 17 deletions
@@ -27,6 +27,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
 
     Args:
         kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc.
+        TODO: Implement flags for kernel preference, same as QuantizeTensorToFloat8Kwargs
         block_size (Optional[list[int]]): block size for quantization granularity
     """
 
@@ -165,26 +166,68 @@ def _(func, types, args, kwargs):
         f"Expected weight to be Int8Tensor, got {type(weight_tensor)}"
     )
 
-    # Dynamic activation quantization if enabled
-    if weight_tensor.act_quant_kwargs is not None:
-        input_tensor = _choose_quant_func_and_quantize_tensor(
-            input_tensor, weight_tensor.act_quant_kwargs
+    if isinstance(input_tensor, Int8Tensor):
+        # INT8 × INT8 (static)
+        x_vals_int8 = input_tensor.qdata
+        x_scales = input_tensor.scale
+        w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+        w_scales = weight_tensor.scale
+
+        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+        x_scales_dtype = x_scales.dtype
+
+        # Cast fp16 scale to float to avoid overflow in y_dot_int32
+        intermediate_dtype = (
+            torch.float if x_scales_dtype == torch.half else x_scales_dtype
         )
 
-    if isinstance(input_tensor, Int8Tensor):
-        # INT8 × INT8 (dynamic)
-        x_int32 = input_tensor.qdata.to(torch.int32)
-        w_int32 = weight_tensor.qdata.to(torch.int32).t()
-
-        result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
-        scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
-        result = result.to(scale.dtype) * scale
-        result = result.view(*input_tensor.shape[:-1], -1)
-    else:
-        # FP × INT8 (static)
-        result = torch.nn.functional.linear(
-            input_tensor, weight_tensor.dequantize(), None
+        # First apply input scaling to avoid overflow
+        y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32))
+        y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(-1, 1).to(
+            intermediate_dtype
         )
+        y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
+
+        # Then apply weight scaling
+        result = (y_dot_scaled * w_scales).reshape(
+            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+        )
+        result = result.to(input_tensor.dtype)
+
+    else:
+        if weight_tensor.act_quant_kwargs is not None:
+            # INT8 × INT8 (dynamic)
+            input_tensor = _choose_quant_func_and_quantize_tensor(
+                input_tensor, weight_tensor.act_quant_kwargs
+            )
+
+            x_vals_int8 = input_tensor.qdata
+            x_scales = input_tensor.scale
+            w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+            w_scales = weight_tensor.scale
+
+            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+            x_scales_dtype = x_scales.dtype
+
+            # Cast fp16 scale to float to avoid overflow in y_dot_int32
+            intermediate_dtype = (
+                torch.float if x_scales_dtype == torch.half else x_scales_dtype
+            )
+            y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32))
+            y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(
+                -1, 1
+            ).to(intermediate_dtype)
+            y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
+
+            result = (y_dot_scaled * w_scales).reshape(
+                *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+            )
+            result = result.to(input_tensor.dtype)
+        else:
+            # FP × INT8 (weight-only)
+            result = torch.nn.functional.linear(
+                input_tensor, weight_tensor.dequantize(), None
+            )
 
     return result + bias if bias is not None else result
 
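For reference, here is a minimal, self-contained sketch (plain PyTorch on CPU, no torchao imports) of the rescaling order used in the new INT8 × INT8 paths above: accumulate in int32, apply the per-row activation scale in float32 first to avoid fp16 overflow, then apply the per-output-channel weight scale. The quantize_symmetric_int8 helper is a hypothetical stand-in for torchao's quantization primitives, and the shapes and tolerances are illustrative only.

import torch

def quantize_symmetric_int8(t: torch.Tensor, dim: int):
    # Per-row / per-output-channel symmetric int8 quantization.
    # Illustrative stand-in, not a torchao API.
    t_fp32 = t.float()
    scale = t_fp32.abs().amax(dim=dim, keepdim=True) / 127.0
    q = torch.clamp(torch.round(t_fp32 / scale), -128, 127).to(torch.int8)
    return q, scale.squeeze(dim).to(t.dtype)

x = torch.randn(4, 64, dtype=torch.float16)    # activations
w = torch.randn(32, 64, dtype=torch.float16)   # weight, shape (out_features, in_features)

x_vals_int8, x_scales = quantize_symmetric_int8(x, dim=1)   # per-row activation scales
w_vals_int8, w_scales = quantize_symmetric_int8(w, dim=1)   # per-output-channel weight scales

# int32 accumulation, mirroring the torch.mm call in the diff
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
w_vals_int8_t = w_vals_int8.contiguous().t()
y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32))

# fp16 scales are promoted to float32 before touching the int32 accumulator
x_scales_dtype = x_scales.dtype
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(-1, 1).to(intermediate_dtype)
y_dot_scaled = y_dot_scaled.to(x_scales_dtype)

# then the per-output-channel weight scale is applied
result = (y_dot_scaled * w_scales).reshape(*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1])
result = result.to(x.dtype)

# compare against an unquantized float32 reference
ref = (x.float() @ w.float().t()).to(x.dtype)
print((result.float() - ref.float()).abs().max())   # small quantization error expected

The same math backs both the static path (input already an Int8Tensor) and the dynamic path (activation quantized on the fly via act_quant_kwargs); only where the activation quantization happens differs.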