@@ -3913,10 +3913,17 @@ class QLinearMatMul(OnnxOpConverter):
39133913 - Only supports 2D input tensors.
39143914 - Not guaranteed to meet the integer-overflow behavior stipulated in the
39153915 ONNX documentation for this operator.
3916+
3917+ The QLinearMatMul converter is re-used for MatMulInteger and is adapted for
3918+ the latter with the optional `expected_out_dtypes` argument.
39163919 """
39173920
39183921 @classmethod
3919- def _impl_v10 (cls , inputs , attr , params ):
3922+ def _impl_v10 (cls , inputs , attr , params , expected_out_dtypes = None ):
3923+ if expected_out_dtypes is None :
3924+ # The default QLinearMatMul converter is expected to have one of
3925+ # these output dtypes.
3926+ expected_out_dtypes = ["int8" , "uint8" ]
39203927
39213928 # Some of the ops used below take scalar-like inputs, and may require either
39223929 # of the following:
@@ -3966,7 +3973,7 @@ def try_resolve_to_const(x, dtype_override=None):
39663973 assert b_zp_type .dtype == b_type .dtype
39673974
39683975 assert y_scale_type .dtype == "float32"
3969- assert y_zp_type .dtype in [ "int8" , "uint8" ]
3976+ assert y_zp_type .dtype in expected_out_dtypes
39703977
39713978 # TODO: relax this limitation in a future version of this importer.
39723979 a_rank = len (a_shape )
@@ -4028,6 +4035,11 @@ def try_resolve_to_const(x, dtype_override=None):
40284035 matmul_result_scale_scalar = fold_constant (_op .multiply (a_scale_scalar , b_scale_scalar ))
40294036 matmul_result_zp_scalar = _op .const (0 , dtype = "int32" )
40304037
4038+ if "int32" in expected_out_dtypes :
4039+ # This is the adaptation of the QLinearMatMul converter for MatMulInteger,
4040+ # in the MatMulInteger case we skip the unnecessary requantization step.
4041+ return matmul_result
4042+
40314043 # requantize requires y_scale to be constant,
40324044 # if y_scale is not constant, doing dequantize -> quantize
40334045 if isinstance (y_scale_scalar , _expr .Constant ):
@@ -4053,6 +4065,58 @@ def try_resolve_to_const(x, dtype_override=None):
40534065 return y
40544066
40554067
class MatMulInteger(OnnxOpConverter):
    """Operator converter for MatMulInteger.

    The conversion is delegated to the QLinearMatMul converter. Unit scales
    are supplied for both operands and the output, and
    ``expected_out_dtypes=["int32"]`` tells QLinearMatMul to return the raw
    int32 matmul result instead of requantizing it.
    """

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        """Convert an ONNX MatMulInteger node.

        Parameters
        ----------
        inputs : sequence
            ``[A, B]``, optionally followed by ``a_zero_point`` and then
            ``b_zero_point``. Each zero point is independently optional and
            defaults to a zero constant of the matching operand dtype.
        attr : dict
            Node attributes, forwarded unchanged to QLinearMatMul.
        params : dict
            Model parameters, forwarded unchanged to QLinearMatMul.

        Returns
        -------
        The expression produced by the QLinearMatMul converter.

        Raises
        ------
        AssertionError
            If the number of inputs is not between 2 and 4, if an operand is
            not int8/uint8, if the operand dtypes differ, or if a supplied
            zero point does not have the dtype of its operand.
        """
        num_inputs = len(inputs)
        if not 2 <= num_inputs <= 4:
            raise AssertionError(
                "MatMulInteger op takes 2 to 4 inputs, {} given".format(num_inputs)
            )

        a = inputs[0]
        b = inputs[1]

        a_dtype = infer_type(a).checked_type.dtype
        b_dtype = infer_type(b).checked_type.dtype

        assert a_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for first input"
        assert b_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for second input"

        assert a_dtype == b_dtype, "MatMulInteger: input dtypes must match"

        # The zero points are optional independently of each other, so a node
        # with three inputs (A, B, a_zero_point) is valid.
        # NOTE(review): an omitted optional input may also show up as a None
        # entry in `inputs` -- confirm against the importer; None is treated
        # as "not provided" here.
        a_zero_point = inputs[2] if num_inputs > 2 else None
        b_zero_point = inputs[3] if num_inputs > 3 else None

        if a_zero_point is None:
            a_zero_point = _op.const(0, dtype=a_dtype)
        else:
            a_zp_dtype = infer_type(a_zero_point).checked_type.dtype
            assert (
                a_zp_dtype == a_dtype
            ), "MatMulInteger: input dtype doesn't match zero point dtype"

        if b_zero_point is None:
            b_zero_point = _op.const(0, dtype=b_dtype)
        else:
            b_zp_dtype = infer_type(b_zero_point).checked_type.dtype
            assert (
                b_zp_dtype == b_dtype
            ), "MatMulInteger: input dtype doesn't match zero point dtype"

        # MatMulInteger carries no scales; unit scales make the QLinearMatMul
        # arithmetic reduce to a plain integer matmul with zero-point offsets.
        a_scale = _op.const(1.0, dtype="float32")
        b_scale = _op.const(1.0, dtype="float32")

        # Output quantization parameters are placeholders that satisfy
        # QLinearMatMul's dtype checks; requantization is skipped for int32.
        out_scale = _op.const(1.0, dtype="float32")
        out_zero_point = _op.const(0, dtype="int32")

        qlinear_inputs = [
            a,
            a_scale,
            a_zero_point,
            b,
            b_scale,
            b_zero_point,
            out_scale,
            out_zero_point,
        ]

        return QLinearMatMul.get_converter(10)(
            qlinear_inputs, attr, params, expected_out_dtypes=["int32"]
        )
4118+
4119+
40564120class QLinearMul (OnnxOpConverter ):
40574121 """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""
40584122
@@ -4781,6 +4845,7 @@ def _get_convert_map(opset):
47814845 "Softsign" : Softsign .get_converter (opset ),
47824846 "Gemm" : Gemm .get_converter (opset ),
47834847 "MatMul" : MatMul .get_converter (opset ),
4848+ "MatMulInteger" : MatMulInteger .get_converter (opset ),
47844849 "MatMulInteger16" : MatMulInteger16 .get_converter (opset ),
47854850 "Mod" : Mod .get_converter (opset ),
47864851 "Xor" : Renamer ("logical_xor" ),
0 commit comments