
Commit 3ab38ba

fix quantization ops
1 parent 305c3a9 · commit 3ab38ba

File tree

1 file changed: +34 −33 lines changed


torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 34 additions & 33 deletions
@@ -12,7 +12,6 @@
 
 from torchao.float8.inference import (
     _slice_scale_for_dimension,
-    preprocess_scale,
 )
 from torchao.kernel import int_scaled_matmul
 from torchao.quantization.quant_primitives import (
@@ -169,21 +168,6 @@ def _(func, types, args, kwargs):
         f"Expected weight to be Int8Tensor, got {type(weight_tensor)}"
     )
 
-    # Store original shape for reshaping result
-    original_weight_shape = weight_tensor.qdata.shape
-
-    # Reshape 3D weights to 2D: (B, N, K) -> (B*N, K)
-    if weight_tensor.qdata.dim() == 3:
-        w_q_2d = weight_tensor.qdata.reshape(-1, original_weight_shape[-1])
-        w_scale_2d = (
-            weight_tensor.scale.reshape(-1)
-            if weight_tensor.scale.numel() > 1
-            else weight_tensor.scale
-        )
-    else:
-        w_q_2d = weight_tensor.qdata
-        w_scale_2d = weight_tensor.scale
-
     if weight_tensor.act_quant_kwargs is not None:
         if not isinstance(activation_tensor, Int8Tensor):
             # Dynamic activation quantization
@@ -202,35 +186,52 @@ def _(func, types, args, kwargs):
                 activation_tensor, act_kwargs
             )
 
-        x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1])
-        x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape)
-        w_vals_t = w_q_2d.contiguous().t()
+        # 1. do the matrix form of dot(X_i, W_j)
+        #
+        # 2. rescale the output
+        #
+        # in cases with large matrices, y_dot_int32 can grow sufficiently
+        # large that y_dot_int32 * a FP16 scale is greater than the maximum
+        # value of a FP16, (which results in a value of inf even if multiplying
+        # by the other scale would bring it within the expected range)
+
+        x_vals_int8 = activation_tensor.qdata
+        x_scales = activation_tensor.scale
+        w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+        w_scales = weight_tensor.scale
+        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+        x_scales_dtype = x_scales.dtype
+        # Cast FP16 scale to float to avoid overflow in int_scaled_matmul
         intermediate_dtype = (
-            torch.float if x_scales.dtype == torch.half else x_scales.dtype
+            torch.float if x_scales_dtype == torch.half else x_scales_dtype
         )
-
         y_dot_scaled = int_scaled_matmul(
-            x_vals, w_vals_t, x_scales.to(intermediate_dtype)
+            tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
         )
-        y_dot_scaled = y_dot_scaled.to(activation_tensor.scale.dtype)
+        y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
 
-        result = (y_dot_scaled * w_scale_2d).reshape(
-            *activation_tensor.shape[:-1], *original_weight_shape[:-1]
+        y = (y_dot_scaled * w_scales).reshape(
+            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
         )
-        result = result.to(activation_tensor.dtype)
+
+        # can downcast only at the very end
+        output_dtype = activation_tensor.dtype
+        y = y.to(output_dtype)
+        if bias is not None:
+            y += bias
+        return y
     else:
         # FP × INT8 (weight-only)
-        w_vals_int8_t = w_q_2d.t()
+        w_vals_int8_t = weight_tensor.qdata.t()
         m = torch.mm(
             activation_tensor.reshape(-1, activation_tensor.shape[-1]),
             w_vals_int8_t.to(activation_tensor.dtype),
         )
-        result = m * w_scale_2d.to(m.dtype)
-        result = result.reshape(
-            *activation_tensor.shape[:-1], *original_weight_shape[:-1]
-        )
-
-        return result + bias if bias is not None else result
+        y = m * weight_tensor.scale.to(m.dtype)
+        y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])
+        if bias is not None:
+            y += bias
+        return y
 
 
 @implements(aten.slice.Tensor)
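For orientation, the dynamic branch above computes an int8 × int8 matmul, rescales it by the per-row activation scale and the per-channel weight scale, restores the leading activation dimensions, downcasts, and only then adds the bias. The following is a minimal pure-PyTorch sketch of that flow; the function name, shapes, and the float32 emulation of the int8 accumulation are assumptions for illustration, not the torchao kernel.

import torch

# Hedged reference for the dynamic path above (assumed shapes/names, not torchao's
# int_scaled_matmul kernel): the int8 matmul is emulated with a float32 matmul so
# the sketch runs anywhere, then per-row activation scales and per-channel weight
# scales are applied before downcasting and adding the bias.
def int8_dynamic_linear_reference(x_q, x_scale, w_q, w_scale, bias=None,
                                  out_dtype=torch.bfloat16):
    # x_q: (..., K) int8 activations, x_scale: (...,) or (..., 1) per-row scales
    # w_q: (N, K) int8 weights,       w_scale: (N,) per-channel scales
    tmp = x_q.reshape(-1, x_q.shape[-1])                         # flatten to (M, K)
    y_dot = torch.matmul(tmp.float(), w_q.float().t())           # (M, N) accumulation
    y = y_dot * x_scale.reshape(-1, 1).float()                   # rescale rows
    y = y * w_scale.float()                                      # rescale columns
    y = y.reshape(*x_q.shape[:-1], w_q.shape[0]).to(out_dtype)   # restore batch dims
    if bias is not None:
        y += bias                                                # bias after downcast
    return y

# Example: a (2, 8, 64) int8 activation against a (32, 64) int8 weight.
x_q = torch.randint(-128, 127, (2, 8, 64), dtype=torch.int8)
x_scale = torch.rand(2, 8, 1) * 0.01
w_q = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
w_scale = torch.rand(32) * 0.01
out = int8_dynamic_linear_reference(x_q, x_scale, w_q, w_scale)
print(out.shape)  # torch.Size([2, 8, 32])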

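The inline comment about FP16 overflow explains why the activation scale is cast to float32 (intermediate_dtype) before int_scaled_matmul. A small illustration with assumed magnitudes, not torchao code:

import torch

# Illustration with assumed magnitudes: for K ~ 4096, the int32 accumulation of an
# int8 matmul can reach ~127 * 127 * 4096 ≈ 6.6e7. Multiplying that by the FP16
# activation scale alone exceeds the FP16 max (65504) and becomes inf, even though
# also applying the weight scale would bring the result back into range.
y_dot_int32 = torch.tensor([60_000_000], dtype=torch.int32)
x_scale_fp16 = torch.tensor([0.01], dtype=torch.float16)   # per-row activation scale
w_scale_fp16 = torch.tensor([1e-4], dtype=torch.float16)   # per-channel weight scale

# Rescaling directly in FP16: 6e7 * 0.01 = 6e5 > 65504, so the intermediate is inf.
fp16_intermediate = y_dot_int32 * x_scale_fp16
print(fp16_intermediate)  # tensor([inf], dtype=torch.float16)

# Casting the scale to float32 first (as intermediate_dtype does above) keeps the
# intermediate finite; the result is downcast only at the very end.
fp32_intermediate = y_dot_int32 * x_scale_fp16.to(torch.float32)
final = (fp32_intermediate * w_scale_fp16.to(torch.float32)).to(torch.float16)
print(final)  # tensor([60.], dtype=torch.float16)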