@@ -27,6 +27,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
 
     Args:
         kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc.
+            TODO: Implement flags for kernel preference, same as QuantizeTensorToFloat8Kwargs
         block_size (Optional[list[int]]): block size for quantization granularity
     """
 
@@ -165,26 +166,68 @@ def _(func, types, args, kwargs):
         f"Expected weight to be Int8Tensor, got {type(weight_tensor)}"
     )
 
-    # Dynamic activation quantization if enabled
-    if weight_tensor.act_quant_kwargs is not None:
-        input_tensor = _choose_quant_func_and_quantize_tensor(
-            input_tensor, weight_tensor.act_quant_kwargs
+    if isinstance(input_tensor, Int8Tensor):
+        # INT8 × INT8 (static)
+        x_vals_int8 = input_tensor.qdata
+        x_scales = input_tensor.scale
+        w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+        w_scales = weight_tensor.scale
+
+        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+        x_scales_dtype = x_scales.dtype
+
+        # Cast fp16 scale to float to avoid overflow when scaling y_dot_int32
+        intermediate_dtype = (
+            torch.float if x_scales_dtype == torch.half else x_scales_dtype
         )
 
-    if isinstance(input_tensor, Int8Tensor):
-        # INT8 × INT8 (dynamic)
-        x_int32 = input_tensor.qdata.to(torch.int32)
-        w_int32 = weight_tensor.qdata.to(torch.int32).t()
-
-        result = torch.mm(x_int32.view(-1, x_int32.size(-1)), w_int32)
-        scale = input_tensor.scale.view(-1, 1) * weight_tensor.scale.unsqueeze(0)
-        result = result.to(scale.dtype) * scale
-        result = result.view(*input_tensor.shape[:-1], -1)
-    else:
-        # FP × INT8 (static)
-        result = torch.nn.functional.linear(
-            input_tensor, weight_tensor.dequantize(), None
+        # First apply input scaling to avoid overflow
+        y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32))
+        y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(-1, 1).to(
+            intermediate_dtype
         )
+        y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
+
+        # Then apply weight scaling
+        result = (y_dot_scaled * w_scales).reshape(
+            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+        )
+        result = result.to(input_tensor.dtype)
+
+    else:
+        if weight_tensor.act_quant_kwargs is not None:
+            # INT8 × INT8 (dynamic)
+            input_tensor = _choose_quant_func_and_quantize_tensor(
+                input_tensor, weight_tensor.act_quant_kwargs
+            )
+
+            x_vals_int8 = input_tensor.qdata
+            x_scales = input_tensor.scale
+            w_vals_int8_t = weight_tensor.qdata.contiguous().t()
+            w_scales = weight_tensor.scale
+
+            tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+            x_scales_dtype = x_scales.dtype
+
+            # Cast fp16 scale to float to avoid overflow when scaling y_dot_int32
+            intermediate_dtype = (
+                torch.float if x_scales_dtype == torch.half else x_scales_dtype
+            )
+            y_dot_int32 = torch.mm(tmp.to(torch.int32), w_vals_int8_t.to(torch.int32))
+            y_dot_scaled = y_dot_int32.to(intermediate_dtype) * x_scales.reshape(
+                -1, 1
+            ).to(intermediate_dtype)
+            y_dot_scaled = y_dot_scaled.to(x_scales_dtype)
+
+            result = (y_dot_scaled * w_scales).reshape(
+                *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
+            )
+            result = result.to(input_tensor.dtype)
+        else:
+            # FP × INT8 (weight-only)
+            result = torch.nn.functional.linear(
+                input_tensor, weight_tensor.dequantize(), None
+            )
 
     return result + bias if bias is not None else result
 
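For context, here is a minimal standalone sketch (not part of the diff) of the scaling order used in the INT8 × INT8 paths above: accumulate the int8 matmul in int32, apply the per-row activation scale in float32 first so a fp16 intermediate does not overflow, then apply the per-channel weight scale. Shapes, scale values, and tolerances are illustrative.

```python
import torch

# Illustrative shapes/scales: per-row activation scales, per-output-channel weight scales.
x_int8 = torch.randint(-128, 127, (4, 16), dtype=torch.int8)
w_int8 = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
x_scales = torch.rand(4, dtype=torch.half) * 0.1
w_scales = torch.rand(8, dtype=torch.half) * 0.1

# Accumulate in int32, apply the activation scale in float32 (not fp16) first,
# then cast down and apply the weight scale.
y_int32 = torch.mm(x_int8.to(torch.int32), w_int8.t().to(torch.int32))
y = (y_int32.to(torch.float) * x_scales.reshape(-1, 1).to(torch.float)).to(torch.half)
y = y * w_scales  # broadcasts over output channels

# Reference: dequantize, then matmul in float32.
ref = (x_int8.float() * x_scales.float().reshape(-1, 1)) @ (
    w_int8.float() * w_scales.float().reshape(-1, 1)
).t()
torch.testing.assert_close(y.float(), ref, rtol=1e-2, atol=1e-1)
```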