
Commit 597000c

[ONNX] Add MatMulInteger importer (#10450)
* implement matmulinteger
* rm test
* rm outdated comments
* fix lint and review
* wip
* fixes
* fix
* alter tests
* extra 4x4x4 step
* comments
1 parent 6d9a111 commit 597000c

4 files changed: +89 additions, -6 deletions

python/tvm/relay/frontend/onnx.py

Lines changed: 67 additions & 2 deletions
@@ -3913,10 +3913,17 @@ class QLinearMatMul(OnnxOpConverter):
     - Only supports 2D input tensors.
     - Not guaranteed to meet the integer-overflow behavior stipulated in the
       ONNX documentation for this operator.
+
+    The QLinearMatMul converter is re-used for MatMulInteger and is adapted for
+    the latter with the optional `expected_out_dtypes` argument.
     """
 
     @classmethod
-    def _impl_v10(cls, inputs, attr, params):
+    def _impl_v10(cls, inputs, attr, params, expected_out_dtypes=None):
+        if expected_out_dtypes is None:
+            # The default QLinearMatMul converter is expected to have one of
+            # these output dtypes.
+            expected_out_dtypes = ["int8", "uint8"]
 
         # Some of the ops used below take scalar-like inputs, and may require either
         # of the following:
@@ -3966,7 +3973,7 @@ def try_resolve_to_const(x, dtype_override=None):
         assert b_zp_type.dtype == b_type.dtype
 
         assert y_scale_type.dtype == "float32"
-        assert y_zp_type.dtype in ["int8", "uint8"]
+        assert y_zp_type.dtype in expected_out_dtypes
 
         # TODO: relax this limitation in a future version of this importer.
         a_rank = len(a_shape)
@@ -4028,6 +4035,11 @@ def try_resolve_to_const(x, dtype_override=None):
         matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
         matmul_result_zp_scalar = _op.const(0, dtype="int32")
 
+        if "int32" in expected_out_dtypes:
+            # This is the adaptation of the QLinearMatMul converter for MatMulInteger,
+            # in the MatMulInteger case we skip the unnecessary requantization step.
+            return matmul_result
+
         # requantize requires y_scale to be constant,
         # if y_scale is not constant, doing dequantize -> quantize
         if isinstance(y_scale_scalar, _expr.Constant):
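The early return above means that, with unit scales, the MatMulInteger path simply produces the int32 accumulation of the zero-point-shifted operands and never requantizes. A minimal NumPy sketch of that semantics (values and shapes below are illustrative only, not taken from the commit):

# Illustrative only: what the int32 path computes for MatMulInteger, i.e.
# (A - a_zero_point) @ (B - b_zero_point) accumulated in int32, with no
# requantize step afterwards.
import numpy as np

a = np.array([[11, 7], [1, 0]], dtype=np.uint8)
b = np.array([[1, 4], [2, 5]], dtype=np.uint8)
a_zp, b_zp = np.int32(12), np.int32(0)  # hypothetical zero points

y = (a.astype(np.int32) - a_zp) @ (b.astype(np.int32) - b_zp)
# e.g. y[0, 0] == (11 - 12) * 1 + (7 - 12) * 2 == -11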
@@ -4053,6 +4065,58 @@ def try_resolve_to_const(x, dtype_override=None):
         return y
 
 
+class MatMulInteger(OnnxOpConverter):
+    """Operator converter for MatMulInteger."""
+
+    @classmethod
+    def _impl_v10(cls, inputs, attr, params):
+        a = inputs[0]
+        b = inputs[1]
+
+        a_dtype = infer_type(a).checked_type.dtype
+        b_dtype = infer_type(b).checked_type.dtype
+
+        assert a_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for first input"
+        assert b_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for second input"
+
+        assert a_dtype == b_dtype, "MatMulInteger: input dtypes must match"
+
+        a_scale = _op.const(1.0, dtype="float32")
+        b_scale = _op.const(1.0, dtype="float32")
+        out_scale = _op.const(1.0, dtype="float32")
+
+        a_zero_point = _op.const(0.0, dtype=a_dtype)
+        b_zero_point = _op.const(0.0, dtype=b_dtype)
+        out_zero_point = _op.const(0.0, dtype="int32")
+
+        if len(inputs) == 4:
+            a_zero_point = inputs[2]
+            b_zero_point = inputs[3]
+
+            a_zp_dtype = infer_type(a_zero_point).checked_type.dtype
+            b_zp_dtype = infer_type(b_zero_point).checked_type.dtype
+            assert (
+                a_zp_dtype == a_dtype and b_zp_dtype == b_dtype
+            ), "MatMulInteger: input dtype doesn't match zero point dtype"
+        elif len(inputs) != 2:
+            raise AssertionError(
+                "MatMulInteger op takes 2 or 4 inputs, {} given".format(len(inputs))
+            )
+
+        inputs = [
+            a,
+            a_scale,
+            a_zero_point,
+            b,
+            b_scale,
+            b_zero_point,
+            out_scale,
+            out_zero_point,
+        ]
+
+        return QLinearMatMul.get_converter(10)(inputs, attr, params, expected_out_dtypes=["int32"])
+
+
 class QLinearMul(OnnxOpConverter):
     """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""

@@ -4781,6 +4845,7 @@ def _get_convert_map(opset):
         "Softsign": Softsign.get_converter(opset),
         "Gemm": Gemm.get_converter(opset),
         "MatMul": MatMul.get_converter(opset),
+        "MatMulInteger": MatMulInteger.get_converter(opset),
         "MatMulInteger16": MatMulInteger16.get_converter(opset),
         "Mod": Mod.get_converter(opset),
         "Xor": Renamer("logical_xor"),

python/tvm/topi/cuda/tensorcore_alter_op.py

Lines changed: 20 additions & 2 deletions
@@ -167,8 +167,22 @@ def _dense_legalize(attrs, inputs, arg_types):
         return None
 
     (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates)
+    skip_pad = extra_flops_ratio > 2
+
+    if skip_pad and dtype in ["int8", "uint8"]:
+        skip_pad = False
+        # If tensorcore schedule padding fails, pad to nearest upward 4x4x4 as long as
+        # the additional flops ratio isn't double or more.
+        # Note that 4x4x4 is invalid for tensorcore scheduling, but padding upwards to 4x4x4
+        # doesn't hurt if tensorcore padding has already failed.
+        if M % 4 == 0 and K % 4 == 0 and N % 4 == 0:
+            # No need to pad
+            return None
+        (dm, dk, dn) = _pad_to(M, K, N, (4, 4, 4))
+        extra_flops_ratio = _extra_flops(M, K, N, dm, dk, dn) / (M * K * N)
+        skip_pad = extra_flops_ratio > 2
 
-    if extra_flops_ratio > 2:
+    if skip_pad:
         logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
         return None
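To see what the new int8/uint8 fallback does, here is an illustrative walk-through of the arithmetic for one hypothetical shape (my numbers, not from the commit), assuming the tensorcore candidate paddings were already rejected for costing more than 2x the flops:

# Illustrative only: the 4x4x4 fallback arithmetic for a hypothetical
# int8 dense workload with M, K, N = 2, 16, 32.
M, K, N = 2, 16, 32

# _pad_to(M, K, N, (4, 4, 4)) pads each dim up to the next multiple of 4:
dm, dk, dn = (-M) % 4, (-K) % 4, (-N) % 4            # -> (2, 0, 0)

# _extra_flops counts the padded-minus-original multiply-accumulate volume:
extra = (M + dm) * (N + dn) * (K + dk) - M * N * K   # -> 1024

# The pad is kept because the relative overhead does not exceed the 2x cutoff:
extra_flops_ratio = extra / (M * K * N)              # -> 1.0, so skip_pad stays False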

@@ -198,14 +212,18 @@ def pad_to_tensorcore(M, K, N, candidates):
     best_pad = (0, 0, 0)
     for padding in candidates:
         dm, dk, dn = _pad_to(M, K, N, padding)
-        e = (M + dm) * (N + dn) * (K + dk) - M * N * K
+        e = _extra_flops(M, K, N, dm, dk, dn)
         # print(dm, dk, dn, e, flops)
         if e < extra_flops:
             extra_flops = e
             best_pad = (dm, dk, dn)
     return best_pad, extra_flops / flops
 
 
+def _extra_flops(M, K, N, dm, dk, dn):
+    return (M + dm) * (N + dn) * (K + dk) - M * N * K
+
+
 def _pad_to(M, K, N, PADDING):
     dm, dk, dn = 0, 0, 0

tests/python/frontend/onnx/test_forward.py

Lines changed: 0 additions & 1 deletion
@@ -5053,7 +5053,6 @@ def verify_eyelike(indata):
     "test_loop11",
     "test_loop13_seq",
     "test_lstm_batchwise",
-    "test_matmulinteger",
     "test_maxpool_with_argmax_2d_precomputed_pads",
     "test_maxpool_with_argmax_2d_precomputed_strides",
     "test_maxunpool_export_with_output_shape",

tests/python/relay/test_pass_legalize_tensorcore.py

Lines changed: 2 additions & 1 deletion
@@ -249,6 +249,7 @@ def expected():
         a = before()
         a = run_opt_pass(a, transform.Legalize())
         b = run_opt_pass(expected(), transform.InferType())
+
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)
 
     # dense
@@ -259,7 +260,7 @@ def expected():
         _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
         _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
         _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
-        _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_dense((1, 16), (32, 16), (0, 0, 0), dtype, False)
 
     # Test if units parameter is correctly updated
     _test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30)
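The swap from (2, 16) to (1, 16) in the no-padding case is consistent with the new 4x4x4 step; a rough check of the ratio involved (my arithmetic, not from the commit, assuming do_pad=False means no legalization is expected):

# Rough check: after the 4x4x4 step, an int8 (1, 16) x (32, 16) dense would
# pad M from 1 up to 4, which exceeds the 2x flops budget, so it still skips
# legalization; the former (2, 16) case pads within budget (ratio 1.0, see
# the walk-through under _dense_legalize above) and so would now be padded.
M, K, N = 1, 16, 32
dm = (-M) % 4                                          # pad M up to 4
ratio = ((M + dm) * N * K - M * N * K) / (M * K * N)   # -> 3.0, above the 2x cutoff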
