 from .. import op as _op
 from .. import qnn as _qnn
 from .common import ExprTable
+from .common import fold_constant as _fold_constant
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import lstm_cell, to_int_list, shape_of, try_infer_value
 from .common import set_span
 from .tflite_flexbuffer import FlexBufferDecoder
@@ -80,6 +82,7 @@ def __init__(self, model, subgraph, exp_tab):
             "ARG_MIN": self.convert_arg_min,
             "AVERAGE_POOL_2D": self.convert_average_pool2d,
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
+            "BATCH_MATMUL": self.convert_batch_matmul,
             "CAST": self.convert_cast,
             "CEIL": self.convert_ceil,
             "CONCATENATION": self.convert_concatenation,
@@ -492,6 +495,21 @@ def get_tensor_type_str(self, tensor_type):
             "Tensor type {} is currently not supported".format(str(tensor_type))
         )
 
+    def flatten_to_nd(self, x, x_shape, nd=3):
+        """Flatten the input tensor to rank nd by collapsing its leading dimensions."""
+        ndims = _infer_shape(x_shape)[0]
+        if ndims == nd:
+            return x
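+        # Collapse the leading (ndims - nd + 1) dims into -1, keeping the trailing (nd - 1) dims.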
+        newshape = _op.concatenate(
+            [
+                _expr.const([-1], dtype=_infer_type(x_shape).checked_type.dtype),
+                _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+            ],
+            0,
+        )
+        out = _op.reshape(x, _fold_constant(newshape))
+        return out
+
     def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         lhs_scale = lhs_tensor.qnn_params["scale"]
         rhs_scale = rhs_tensor.qnn_params["scale"]
@@ -2959,6 +2977,135 @@ def convert_batch_to_space_nd(self, op):
 
         return out
 
+    def convert_batch_matmul(self, op):
+        """batch_matmul implementation."""
+        try:
+            from tflite.BatchMatMulOptions import BatchMatMulOptions
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        input_tensors = self.get_input_tensors(op)
+
+        assert len(input_tensors) == 2, "two input tensor arguments expected"
+
+        batch_matmul_options = BatchMatMulOptions()
+        op_options = op.BuiltinOptions()
+        batch_matmul_options.Init(op_options.Bytes, op_options.Pos)
+
+        input_a = self.get_expr(input_tensors[0].tensor_idx)
+        input_b = self.get_expr(input_tensors[1].tensor_idx)
+
+        shape_a = shape_of(input_a)
+        shape_b = shape_of(input_b)
+        rank_a = _infer_shape(shape_a)[0]
+        rank_b = _infer_shape(shape_b)[0]
+
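+        # TFLite batch_matmul broadcasts the leading (batch) dimensions, numpy-style.
+        # Pad the lower-rank shape with leading 1s so both shapes have the same rank.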
+        if rank_a > 2 or rank_b > 2:
+            # Determine the output batch dimensions.
+            new_a_shape = shape_a
+            new_b_shape = shape_b
+            if rank_a > rank_b:
+                rank_diff = rank_a - rank_b
+                new_b_shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(shape_b).checked_type.dtype
+                        ),
+                        shape_b,
+                    ],
+                    0,
+                )
+            elif rank_a < rank_b:
+                rank_diff = rank_b - rank_a
+                new_a_shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [1] * rank_diff, dtype=_infer_type(shape_a).checked_type.dtype
+                        ),
+                        shape_a,
+                    ],
+                    0,
+                )
+
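+            # Each broadcast batch dim is the element-wise maximum of the two padded shapes.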
+            out_batch = _op.concatenate(
+                [
+                    _op.maximum(
+                        _op.strided_slice(new_b_shape, [i], [i + 1]),
+                        _op.strided_slice(new_a_shape, [i], [i + 1]),
+                    )
+                    for i in range(max(rank_a, rank_b) - 2)
+                ],
+                0,
+            )
+
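+            # Full broadcast shape of each input: broadcast batch dims plus its own matrix dims.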
+            a_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_a, [rank_a - 2], [rank_a]),
+                    ],
+                    0,
+                )
+            )
+            b_broadcasted_shape = _fold_constant(
+                _op.concatenate(
+                    [
+                        out_batch,
+                        _op.strided_slice(shape_b, [rank_b - 2], [rank_b]),
+                    ],
+                    0,
+                )
+            )
+            if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
+                input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
+            if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
+                input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
+
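+        # nn.batch_matmul expects rank-3 inputs, so collapse all batch dims into one.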
+        input_a = self.flatten_to_nd(input_a, shape_of(input_a), 3)
+        input_b = self.flatten_to_nd(input_b, shape_of(input_b), 3)
+
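+        # Relay's batch_matmul computes A * B^T, so B needs a transpose exactly when
+        # TFLite's adj_y flag is NOT set (and A needs one when adj_x IS set).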
+        if batch_matmul_options.AdjX():
+            input_a = _op.transpose(input_a, [0, 2, 1])
+        if not batch_matmul_options.AdjY():
+            input_b = _op.transpose(input_b, [0, 2, 1])
+
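+        # Quantized models lower to qnn.batch_matmul, which accumulates in int32.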
+        if self.is_quantized(op):
+            output = _qnn.op.batch_matmul(
+                input_a,
+                input_b,
+                relay.const(0, "int32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(1.0, "float32"),
+            )
+        else:
+            output = _op.nn.batch_matmul(input_a, input_b)
+
+        # Reshape output to original dimensions.
+        output_shape = shape_of(output)
+
+        rank_out = _infer_shape(output_shape)[0]
+
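+        # Batch dims are taken from input A; the last two dims come from the matmul result.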
+        final_shape = _op.concatenate(
+            [
+                _op.strided_slice(shape_a, [0], [rank_a - 2]),
+                _op.strided_slice(output_shape, [rank_out - 2], [rank_out]),
+            ],
+            0,
+        )
+
+        reshape = _op.reshape(output, _fold_constant(final_shape))
+        # qnn.batch_matmul returns an int32 tensor, so requantize the result back to int8.
+        if self.is_quantized(op):
+            return _qnn.op.requantize(
+                reshape,
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                relay.const(1.0, "float32"),
+                relay.const(0, "int32"),
+                out_dtype="int8",
+            )
+        return reshape
+
     def convert_space_to_batch_nd(self, op):
         """space_to_batch_nd implementation."""
         input_tensors = self.get_input_tensors(op)
|