69 changes: 67 additions & 2 deletions python/tvm/relay/frontend/onnx.py
@@ -3913,10 +3913,17 @@ class QLinearMatMul(OnnxOpConverter):
    - Only supports 2D input tensors.
    - Not guaranteed to meet the integer-overflow behavior stipulated in the
      ONNX documentation for this operator.

    The QLinearMatMul converter is re-used for MatMulInteger and is adapted for
    the latter with the optional `expected_out_dtypes` argument.
    """

    @classmethod
-    def _impl_v10(cls, inputs, attr, params):
+    def _impl_v10(cls, inputs, attr, params, expected_out_dtypes=None):
        if expected_out_dtypes is None:
            # The default QLinearMatMul converter is expected to have one of
            # these output dtypes.
            expected_out_dtypes = ["int8", "uint8"]

        # Some of the ops used below take scalar-like inputs, and may require either
        # of the following:
@@ -3966,7 +3973,7 @@ def try_resolve_to_const(x, dtype_override=None):
        assert b_zp_type.dtype == b_type.dtype

        assert y_scale_type.dtype == "float32"
-        assert y_zp_type.dtype in ["int8", "uint8"]
+        assert y_zp_type.dtype in expected_out_dtypes

        # TODO: relax this limitation in a future version of this importer.
        a_rank = len(a_shape)
@@ -4028,6 +4035,11 @@ def try_resolve_to_const(x, dtype_override=None):
        matmul_result_scale_scalar = fold_constant(_op.multiply(a_scale_scalar, b_scale_scalar))
        matmul_result_zp_scalar = _op.const(0, dtype="int32")

        if "int32" in expected_out_dtypes:
            # This is the adaptation of the QLinearMatMul converter for MatMulInteger:
            # in that case the result is already int32, so we skip the unnecessary
            # requantization step.
            return matmul_result

        # requantize requires y_scale to be constant;
        # if y_scale is not constant, fall back to dequantize -> quantize
        if isinstance(y_scale_scalar, _expr.Constant):
@@ -4053,6 +4065,58 @@ def try_resolve_to_const(x, dtype_override=None):
        return y


class MatMulInteger(OnnxOpConverter):
    """Operator converter for MatMulInteger."""

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
        a = inputs[0]
Contributor:

Hmm so pretty sure you can just do relay.nn.matmul(..., out_dtype='int32'). You don't need all of this QLinearMatMul stuff to handle accumulation without overflow imo.

Contributor:

If scale is always 1 and zero point is always 0, I think that's true.

Member:

My impression was that MatMulInteger exactly corresponds to our qnn.dense. We can use QLinearMatMul converter but we want to skip requantize.

Contributor:

https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmulinteger <-- I was fixated on the line where it said it was identical to numpy's. It appears the zero points are not always fixed, so using QLinearMatMul is probably the best choice.

Contributor Author:

Modified the QLinearMatMul converter to skip requantize in the MatMulInteger case.

        b = inputs[1]

        a_dtype = infer_type(a).checked_type.dtype
        b_dtype = infer_type(b).checked_type.dtype

        assert a_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for first input"
        assert b_dtype in ("int8", "uint8"), "MatMulInteger: invalid dtype for second input"

        assert a_dtype == b_dtype, "MatMulInteger: input dtypes must match"

        a_scale = _op.const(1.0, dtype="float32")
        b_scale = _op.const(1.0, dtype="float32")
        out_scale = _op.const(1.0, dtype="float32")

        a_zero_point = _op.const(0.0, dtype=a_dtype)
Contributor:

zero point does not seem to be fixed to 0? The scales appear to be not present, though, which is weird. You might have to read the codebase and see what the intention really is, since the documentation is confusing (and my guess has a mistake in it): https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmulinteger

Contributor Author:

Yeah, zp isn't necessarily fixed to 0, but the docs say 0 should be the default. I do this a few lines down:

        if len(inputs) == 4:
            a_zero_point = inputs[2]
            b_zero_point = inputs[3]

From the example it looks like the scale factors are 1 by default.

        b_zero_point = _op.const(0.0, dtype=b_dtype)
        out_zero_point = _op.const(0.0, dtype="int32")

        if len(inputs) == 4:
            a_zero_point = inputs[2]
            b_zero_point = inputs[3]

            a_zp_dtype = infer_type(a_zero_point).checked_type.dtype
            b_zp_dtype = infer_type(b_zero_point).checked_type.dtype
            assert (
                a_zp_dtype == a_dtype and b_zp_dtype == b_dtype
            ), "MatMulInteger: input dtype doesn't match zero point dtype"
        elif len(inputs) != 2:
            raise AssertionError(
                "MatMulInteger op takes 2 or 4 inputs, {} given".format(len(inputs))
            )

        inputs = [
            a,
            a_scale,
            a_zero_point,
            b,
            b_scale,
            b_zero_point,
            out_scale,
            out_zero_point,
        ]

        return QLinearMatMul.get_converter(10)(inputs, attr, params, expected_out_dtypes=["int32"])
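
For reference, the MatMulInteger semantics discussed in the review thread above can be sketched in NumPy: subtract the zero points, then accumulate the product in int32, with no requantize step. This is a minimal sketch, not the converter's code path; `matmul_integer_ref` is a name invented here for illustration, and the sample values follow the example in the ONNX operator docs.

```python
import numpy as np

def matmul_integer_ref(a, b, a_zero_point=0, b_zero_point=0):
    # Widen to int32 before subtracting zero points so neither the
    # subtraction nor the accumulation can overflow the 8-bit inputs.
    a_adj = a.astype(np.int32) - np.int32(a_zero_point)
    b_adj = b.astype(np.int32) - np.int32(b_zero_point)
    return a_adj @ b_adj  # int32 result; no requantization

a = np.array([[11, 7, 3], [10, 6, 2], [9, 5, 1], [8, 4, 0]], dtype=np.uint8)
b = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8)
print(matmul_integer_ref(a, b, a_zero_point=12))
# [[ -38  -83]
#  [ -44  -98]
#  [ -50 -113]
#  [ -56 -128]]
```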


class QLinearMul(OnnxOpConverter):
    """Operator converter for QLinearMul from Microsoft onnxruntime contrib opset."""

@@ -4781,6 +4845,7 @@ def _get_convert_map(opset):
"Softsign": Softsign.get_converter(opset),
"Gemm": Gemm.get_converter(opset),
"MatMul": MatMul.get_converter(opset),
"MatMulInteger": MatMulInteger.get_converter(opset),
"MatMulInteger16": MatMulInteger16.get_converter(opset),
"Mod": Mod.get_converter(opset),
"Xor": Renamer("logical_xor"),
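As a usage sketch (not part of the diff): with "MatMulInteger" registered in the convert map, a minimal two-input model built with onnx.helper should now import through the Relay frontend. The graph name, tensor names, and shapes below are made up for illustration.

```python
import onnx
from onnx import TensorProto, helper
from tvm import relay

# Two-input form; per the ONNX spec, zero points default to 0 when omitted.
node = helper.make_node("MatMulInteger", ["A", "B"], ["Y"])
graph = helper.make_graph(
    [node],
    "matmulinteger_example",
    inputs=[
        helper.make_tensor_value_info("A", TensorProto.UINT8, [4, 3]),
        helper.make_tensor_value_info("B", TensorProto.UINT8, [3, 2]),
    ],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.INT32, [4, 2])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])

# The new converter routes this through QLinearMatMul with unit scales,
# default zero points, and the requantize step skipped.
mod, params = relay.frontend.from_onnx(model)
print(mod)
```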
22 changes: 20 additions & 2 deletions python/tvm/topi/cuda/tensorcore_alter_op.py
@@ -167,8 +167,22 @@ def _dense_legalize(attrs, inputs, arg_types):
        return None

    (dm, dk, dn), extra_flops_ratio = pad_to_tensorcore(M, K, N, candidates)
    skip_pad = extra_flops_ratio > 2

    if skip_pad and dtype in ["int8", "uint8"]:
        skip_pad = False
        # If tensorcore schedule padding fails, pad to nearest upward 4x4x4 as long as
        # the additional flops ratio isn't double or more.
        # Note that 4x4x4 is invalid for tensorcore scheduling, but padding upwards to 4x4x4
        # doesn't hurt if tensorcore padding has already failed.
        if M % 4 == 0 and K % 4 == 0 and N % 4 == 0:
            # No need to pad
            return None
        (dm, dk, dn) = _pad_to(M, K, N, (4, 4, 4))
        extra_flops_ratio = _extra_flops(M, K, N, dm, dk, dn) / (M * K * N)
        skip_pad = extra_flops_ratio > 2

-    if extra_flops_ratio > 2:
+    if skip_pad:
        logger.info("dense pad_to_tensorcore skipped, extra_flops_ratio %s", extra_flops_ratio)
        return None

@@ -198,14 +212,18 @@ def pad_to_tensorcore(M, K, N, candidates):
    best_pad = (0, 0, 0)
    for padding in candidates:
        dm, dk, dn = _pad_to(M, K, N, padding)
-        e = (M + dm) * (N + dn) * (K + dk) - M * N * K
+        e = _extra_flops(M, K, N, dm, dk, dn)
        # print(dm, dk, dn, e, flops)
        if e < extra_flops:
            extra_flops = e
            best_pad = (dm, dk, dn)
    return best_pad, extra_flops / flops


def _extra_flops(M, K, N, dm, dk, dn):
    return (M + dm) * (N + dn) * (K + dk) - M * N * K


def _pad_to(M, K, N, PADDING):
    dm, dk, dn = 0, 0, 0

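To make the padding arithmetic concrete, here is a standalone sketch of the two helpers, with a worked int8 example. It re-implements rather than imports them, under the assumption that `_pad_to` rounds each dimension up to the next multiple of its padding factor, which matches how it is used above.

```python
def _pad_to(M, K, N, padding):
    # Distance from each dimension up to the next multiple of its factor.
    dm = (-M) % padding[0]
    dk = (-K) % padding[1]
    dn = (-N) % padding[2]
    return dm, dk, dn

def _extra_flops(M, K, N, dm, dk, dn):
    # Extra multiply-accumulates introduced by padding the GEMM.
    return (M + dm) * (N + dn) * (K + dk) - M * N * K

# An int8 dense with M=7, K=15, N=31, padded up to multiples of (4, 4, 4):
dm, dk, dn = _pad_to(7, 15, 31, (4, 4, 4))            # -> (1, 1, 1)
ratio = _extra_flops(7, 15, 31, dm, dk, dn) / (7 * 15 * 31)
print(dm, dk, dn, round(ratio, 3))                    # ~0.258, well under the 2x cutoff
```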
1 change: 0 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
@@ -5053,7 +5053,6 @@ def verify_eyelike(indata):
"test_loop11",
"test_loop13_seq",
"test_lstm_batchwise",
"test_matmulinteger",
"test_maxpool_with_argmax_2d_precomputed_pads",
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
3 changes: 2 additions & 1 deletion tests/python/relay/test_pass_legalize_tensorcore.py
@@ -249,6 +249,7 @@ def expected():
    a = before()
    a = run_opt_pass(a, transform.Legalize())
    b = run_opt_pass(expected(), transform.InferType())

    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "Expected = \n" + str(b)

    # dense

@@ -259,7 +260,7 @@ def expected():
        _test_legalize_dense((8, 16), (31, 16), (0, 0, 1), dtype)
        _test_legalize_dense((7, 15), (31, 15), (1, 1, 1), dtype)
        _test_legalize_dense((3, 16), (32, 16), (5, 0, 0), dtype)
-        _test_legalize_dense((2, 16), (32, 16), (0, 0, 0), dtype, False)
+        _test_legalize_dense((1, 16), (32, 16), (0, 0, 0), dtype, False)

    # Test if units parameter is correctly updated
    _test_legalize_dense((8, 16), (30, 16), (0, 0, 2), "float16", units=30)