 
 from torchao.float8.inference import (
     _slice_scale_for_dimension,
-    preprocess_scale,
 )
 from torchao.kernel import int_scaled_matmul
 from torchao.quantization.quant_primitives import (
@@ -169,21 +168,6 @@ def _(func, types, args, kwargs):
169168 f"Expected weight to be Int8Tensor, got { type (weight_tensor )} "
170169 )
171170
172- # Store original shape for reshaping result
173- original_weight_shape = weight_tensor .qdata .shape
174-
175- # Reshape 3D weights to 2D: (B, N, K) -> (B*N, K)
176- if weight_tensor .qdata .dim () == 3 :
177- w_q_2d = weight_tensor .qdata .reshape (- 1 , original_weight_shape [- 1 ])
178- w_scale_2d = (
179- weight_tensor .scale .reshape (- 1 )
180- if weight_tensor .scale .numel () > 1
181- else weight_tensor .scale
182- )
183- else :
184- w_q_2d = weight_tensor .qdata
185- w_scale_2d = weight_tensor .scale
186-
187171 if weight_tensor .act_quant_kwargs is not None :
188172 if not isinstance (activation_tensor , Int8Tensor ):
189173 # Dynamic activation quantization
@@ -202,35 +186,52 @@ def _(func, types, args, kwargs):
                 activation_tensor, act_kwargs
             )
 
-        x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1])
-        x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape)
-        w_vals_t = w_q_2d.contiguous().t()
+        # 1. do the matrix form of dot(X_i, W_j)
+        #
+        # 2. rescale the output
+        #
+        # in cases with large matrices, y_dot_int32 can grow sufficiently
+        # large that y_dot_int32 * a FP16 scale is greater than the maximum
+        # value of a FP16, (which results in a value of inf even if multiplying
+        # by the other scale would bring it within the expected range)
+
+        x_vals_int8 = activation_tensor.qdata
+        x_scales = activation_tensor.scale
+        w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+        w_scales = weight_tensor.scale
+        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+        x_scales_dtype = x_scales.dtype
+        # Cast FP16 scale to float to avoid overflow in int_scaled_matmul
         intermediate_dtype = (
-            torch.float if x_scales.dtype == torch.half else x_scales.dtype
+            torch.float if x_scales_dtype == torch.half else x_scales_dtype
         )
-
         y_dot_scaled = int_scaled_matmul(
-            x_vals, w_vals_t, x_scales.to(intermediate_dtype)
+            tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
         )
-        y_dot_scaled = y_dot_scaled.to(activation_tensor.scale.dtype)
+        y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
 
-        result = (y_dot_scaled * w_scale_2d).reshape(
-            *activation_tensor.shape[:-1], *original_weight_shape[:-1]
+        y = (y_dot_scaled * w_scales).reshape(
+            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
         )
-        result = result.to(activation_tensor.dtype)
+
+        # can downcast only at the very end
+        output_dtype = activation_tensor.dtype
+        y = y.to(output_dtype)
+        if bias is not None:
+            y += bias
+        return y
     else:
         # FP × INT8 (weight-only)
-        w_vals_int8_t = w_q_2d.t()
+        w_vals_int8_t = weight_tensor.qdata.t()
         m = torch.mm(
             activation_tensor.reshape(-1, activation_tensor.shape[-1]),
             w_vals_int8_t.to(activation_tensor.dtype),
         )
-        result = m * w_scale_2d.to(m.dtype)
-        result = result.reshape(
-            *activation_tensor.shape[:-1], *original_weight_shape[:-1]
-        )
-
-        return result + bias if bias is not None else result
+        y = m * weight_tensor.scale.to(m.dtype)
+        y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])
+        if bias is not None:
+            y += bias
+        return y
 
 
 @implements(aten.slice.Tensor)
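
For context on the comment added in this hunk (integer dot products first, then rescale, with an fp16 activation scale cast up to float32 so the per-row rescale cannot overflow), here is a minimal plain-PyTorch sketch of the dynamic-quant path. It is an illustration under stated assumptions, not the torchao kernel: the function name int8_dynamic_linear_sketch and its per-row symmetric activation quantization are made up for the example, and a plain int32 matmul (CPU only) stands in for the fused int_scaled_matmul.

import torch

def int8_dynamic_linear_sketch(x, w_int8, w_scale, bias=None):
    # x: (..., K) float activation; w_int8: (N, K) int8; w_scale: (N,) per-channel
    x_2d = x.reshape(-1, x.shape[-1])

    # dynamic per-row symmetric quantization of the activation
    x_scale = x_2d.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127.0
    x_int8 = torch.clamp(torch.round(x_2d / x_scale), -128, 127).to(torch.int8)

    # 1. matrix form of dot(X_i, W_j), accumulated in int32
    #    (stand-in for the fused int_scaled_matmul kernel)
    y_dot_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()

    # 2. rescale; cast an fp16 scale up to float32 first so that
    #    y_dot_int32 * x_scale cannot overflow fp16's max (~65504)
    intermediate = torch.float if x_scale.dtype == torch.half else x_scale.dtype
    y = y_dot_int32.to(intermediate) * x_scale.to(intermediate)
    y = y.to(x_scale.dtype) * w_scale

    # downcast to the activation dtype only at the very end
    y = y.reshape(*x.shape[:-1], w_int8.shape[0]).to(x.dtype)
    return y if bias is None else y + bias

With K around 4096 and int8 values near ±127, entries of y_dot_int32 can reach roughly 6.6e7; multiplying that by a typical activation scale (say 0.05) already exceeds fp16's maximum of 65504, even though the subsequent weight-scale multiply would bring the value back into range. That is why the intermediate float32 cast before int_scaled_matmul matters.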