Commit 70399da

Author: neildhickey

[TFLite] Support for BATCH_MATMUL tflite operator (#14423)

* Adds support for the BATCH_MATMUL operator in the TFLite frontend.
* Adds a test that checks supported TFLite types.
* Fixes linting issues.
* Fixes compare_tflite_with_tvm for models with fewer than two input tensors.

1 parent 41fb9f4 commit 70399da
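
The change can be exercised end to end outside the test suite. Below is a minimal sketch, assuming a TensorFlow build that exposes tf.raw_ops.BatchMatMulV3 and the TF1-style from_session converter, plus the tflite Python package; shapes and tensor names are illustrative:

import tensorflow as tf
import tflite
from tvm import relay

# Build a tiny TF1-style graph containing a single BatchMatMul op.
with tf.compat.v1.Graph().as_default():
    a = tf.compat.v1.placeholder(tf.float32, (3, 5, 4), name="A")
    b = tf.compat.v1.placeholder(tf.float32, (3, 4, 5), name="B")
    out = tf.raw_ops.BatchMatMulV3(x=a, y=b, Tout=tf.float32)
    with tf.compat.v1.Session() as sess:
        converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, [a, b], [out])
        # Keep BATCH_MATMUL as a single op so the new frontend path is hit.
        converter._experimental_disable_batchmatmul_unfold = True
        tflite_buf = converter.convert()

# Import the flatbuffer into Relay; BATCH_MATMUL now converts directly.
model = tflite.Model.GetRootAsModel(tflite_buf, 0)
mod, params = relay.frontend.from_tflite(
    model,
    shape_dict={"A": (3, 5, 4), "B": (3, 4, 5)},
    dtype_dict={"A": "float32", "B": "float32"},
)
print(mod)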

File tree

2 files changed: +212 -9 lines changed


python/tvm/relay/frontend/tflite.py

147 additions & 0 deletions

@@ -32,7 +32,9 @@
 from .. import op as _op
 from .. import qnn as _qnn
 from .common import ExprTable
+from .common import fold_constant as _fold_constant
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import lstm_cell, to_int_list, shape_of, try_infer_value
 from .common import set_span
 from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ def __init__(self, model, subgraph, exp_tab):
             "ARG_MIN": self.convert_arg_min,
             "AVERAGE_POOL_2D": self.convert_average_pool2d,
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+            "BATCH_MATMUL": self.convert_batch_matmul,
             "CAST": self.convert_cast,
             "CEIL": self.convert_ceil,
             "CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ def get_tensor_type_str(self, tensor_type):
                 "Tensor type {} is currently not supported".format(str(tensor_type))
             )

+    def flatten_to_nd(self, x, x_shape, nd=3):
+        """Flatten input tensor to nd rank"""
+        ndims = _infer_shape(x_shape)[0]
+        if ndims == nd:
+            return x
+        # Fold all leading dims into one and keep the trailing nd - 1 dims.
+        newshape = _op.concatenate(
+            [
+                _expr.const([-1], dtype=_infer_type(x_shape).checked_type.dtype),
+                _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+            ],
+            0,
+        )
+        out = _op.reshape(x, _fold_constant(newshape))
+        return out
+
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         lhs_scale = lhs_tensor.qnn_params["scale"]
         rhs_scale = rhs_tensor.qnn_params["scale"]
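
For intuition, flatten_to_nd collapses every leading dimension into a single batch dimension. A minimal numpy sketch of the same reshape on a static shape (illustrative only, not part of the commit):

import numpy as np

x = np.zeros((2, 3, 4, 5))
nd = 3
# Keep the trailing nd - 1 dims and fold everything before them into -1.
newshape = [-1] + list(x.shape[-(nd - 1):])   # [-1, 4, 5]
assert x.reshape(newshape).shape == (6, 4, 5)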
@@ -2959,6 +2977,135 @@ def convert_batch_to_space_nd(self, op):

         return out

+    def convert_batch_matmul(self, op):
+        """batch_matmul implementation."""
+        try:
+            from tflite.BatchMatMulOptions import BatchMatMulOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        batch_matmul_options = BatchMatMulOptions()
+        op_options = op.BuiltinOptions()
+        batch_matmul_options.Init(op_options.Bytes, op_options.Pos)
+
+        input_a = self.get_expr(input_tensors[0].tensor_idx)
+        input_b = self.get_expr(input_tensors[1].tensor_idx)
+
+        shape_a = shape_of(input_a)
+        shape_b = shape_of(input_b)
+        rank_a = _infer_shape(shape_a)[0]
+        rank_b = _infer_shape(shape_b)[0]
+
+        if rank_a > 2 or rank_b > 2:
+            # Determine the output batch dimensions: pad the lower-rank shape
+            # with leading 1s, then broadcast the batch dims elementwise.
+            new_a_shape = shape_a
+            new_b_shape = shape_b
+            if rank_a > rank_b:
+                rank_diff = rank_a - rank_b
+                new_b_shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(shape_b).checked_type.dtype
+                        ),
+                        shape_b,
+                    ],
+                    0,
+                )
+            elif rank_a < rank_b:
+                rank_diff = rank_b - rank_a
+                new_a_shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(shape_a).checked_type.dtype
+                        ),
+                        shape_a,
+                    ],
+                    0,
+                )
+
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(new_b_shape, [i], [i + 1]),
+                        _op.strided_slice(new_a_shape, [i], [i + 1]),
+                    )
+                    for i in range(max(rank_a, rank_b) - 2)
+                ],
+                0,
+            )
+
+            a_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
+                    ],
+                    0,
+                )
+            )
+            b_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
+                    ],
+                    0,
+                )
+            )
+            if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
+                input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
+            if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
+                input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
+
+        input_a = self.flatten_to_nd(input_a, shape_a, 3)
+        input_b = self.flatten_to_nd(input_b, shape_b, 3)
+
+        if batch_matmul_options.AdjX():
+            input_a = _op.transpose(input_a, [0, 2, 1])
+        # nn.batch_matmul expects the second operand in [batch, N, K]
+        # (transposed) layout, so transpose unless AdjY already has.
+        if not batch_matmul_options.AdjY():
+            input_b = _op.transpose(input_b, [0, 2, 1])
+
+        if self.is_quantized(op):
+            output = _qnn.op.batch_matmul(
+                input_a,
+                input_b,
+                relay.const(0, "int32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(1.0, "float32"),
+            )
+        else:
+            output = _op.nn.batch_matmul(input_a, input_b)
+
+        # Reshape the output back to the original batch dimensions.
+        output_shape = shape_of(output)
+        rank_out = _infer_shape(output_shape)[0]
+
+        final_shape = _op.concatenate(
+            [
+                _op.strided_slice(shape_a, [0], [rank_a - 2]),
+                _op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
+            ],
+            0,
+        )
+
+        reshape = _op.reshape(output, _fold_constant(final_shape))
+        # qnn.batch_matmul returns an int32 tensor, so requantize to int8.
+        if self.is_quantized(op):
+            return _qnn.op.requantize(
+                reshape,
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                out_dtype="int8",
+            )
+        return reshape
+
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
         input_tensors = self.get_input_tensors(op)
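
The broadcasting logic above mirrors numpy semantics on the batch dimensions: pad the lower-rank shape with leading 1s, take the elementwise maximum, broadcast, multiply at rank 3, and restore the batch dims. A numpy sketch with illustrative shapes:

import numpy as np

a = np.random.rand(5, 2, 3, 4).astype("float32")   # rank 4
b = np.random.rand(2, 4, 6).astype("float32")      # rank 3

# Pad the lower-rank shape with leading 1s: (2, 4, 6) -> (1, 2, 4, 6),
# then broadcast the batch dims: max((5, 2), (1, 2)) = (5, 2).
b_bc = np.broadcast_to(b, (5, 2, 4, 6))

# Flatten both operands to rank 3, matmul batch-wise, and restore the
# leading batch dims from the first operand's shape.
out = np.matmul(a.reshape(-1, 3, 4), b_bc.reshape(-1, 4, 6)).reshape(5, 2, 3, 6)
assert out.shape == (5, 2, 3, 6)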

tests/python/frontend/tflite/test_forward.py

65 additions & 9 deletions

@@ -61,6 +61,7 @@
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import variables
+from tensorflow import raw_ops

 try:
     from tensorflow import lite as interpreter_wrapper
@@ -319,6 +320,13 @@ def compare_tflite_with_tvm(
         sess.run(variables.global_variables_initializer())
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
+        # Keep BATCH_MATMUL as a single op (rather than unfolding it) when both
+        # operands are at most 4-D; guard models with fewer than two inputs.
+        if len(input_tensors) > 1:
+            if len(input_tensors[0].shape) <= 4 and len(input_tensors[1].shape) <= 4:
+                converter._experimental_disable_batchmatmul_unfold = True
+            else:
+                converter._experimental_disable_batchmatmul_unfold = False
+
         converter.experimental_new_converter = experimental_new_converter
         if quantized:
             if int_quant_dtype == tf.int16:
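
For reference, BatchMatMul's adj_x/adj_y flags transpose the trailing two dimensions of the corresponding operand before the multiply, which the adjoint_a/adjoint_b cases below exercise. A numpy sketch with illustrative shapes:

import numpy as np

a = np.random.rand(3, 5, 4)
b = np.random.rand(3, 5, 4)

# adj_x=True, adj_y=False: transpose only the first operand's last two dims.
out = np.matmul(a.transpose(0, 2, 1), b)
assert out.shape == (3, 4, 4)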
@@ -734,24 +742,72 @@ def test_forward_cast():
 #######################################################################
 # Batch Mat Mul
 # ----
-def _test_batch_matmul(a_shape, b_shape, dtype, adjoint_a=False, adjoint_b=False):
+def _test_batch_matmul(
+    a_shape, b_shape, dtype, out_dtype, adjoint_a=False, adjoint_b=False, quantized=False
+):
     with tf.Graph().as_default():
         a = array_ops.placeholder(shape=a_shape, dtype=dtype, name="A")
         b = array_ops.placeholder(shape=b_shape, dtype=dtype, name="B")
-        result = math_ops.matmul(a, b, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name="batchmatmul")
+        result = raw_ops.BatchMatMulV3(
+            x=a, y=b, Tout=out_dtype, adj_x=adjoint_a, adj_y=adjoint_b, name="batchmatmul"
+        )
+        input_range = {"A": (-100, 100), "B": (-100, 100)} if quantized else None

         a_np = np.random.uniform(high=5.0, size=a_shape).astype(dtype)
         b_np = np.random.uniform(high=5.0, size=b_shape).astype(dtype)
-        compare_tflite_with_tvm([a_np, b_np], [a.name, b.name], [a, b], [result])
+        compare_tflite_with_tvm(
+            [a_np, b_np],
+            [a.name, b.name],
+            [a, b],
+            [result],
+            experimental_new_converter=True,
+            quantized=quantized,
+            input_range=input_range,
+        )


-def test_forward_batch_matmul():
+@pytest.mark.parametrize("config", [("int8", "int32", True), ("float32", "float32", False)])
+def test_forward_batch_matmul(config):
     """BATCH_MAT_MUL"""
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32")
-    _test_batch_matmul((3, 5, 4), (3, 4, 5), "float32", True, True)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", True, False)
-    _test_batch_matmul((3, 5, 4), (3, 5, 4), "float32", False, True)
-    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), "float32")
+    _test_batch_matmul(
+        (3, 5, 4), (3, 4, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 4, 5),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=True,
+        adjoint_b=True,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=True,
+        adjoint_b=False,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 5, 4),
+        (3, 5, 4),
+        dtype=config[0],
+        out_dtype=config[1],
+        adjoint_a=False,
+        adjoint_b=True,
+        quantized=config[2],
+    )
+    _test_batch_matmul(
+        (3, 4, 5, 6), (3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    )
+    # BatchMatMul doesn't support tensors larger than 4-D:
+    # _test_batch_matmul(
+    #     (2, 3, 4, 5, 6), (2, 3, 4, 6, 5), dtype=config[0], out_dtype=config[1], quantized=config[2]
+    # )


#######################################################################
