diff --git a/python/tvm/relay/backend/contrib/ethosu/codegen.py b/python/tvm/relay/backend/contrib/ethosu/codegen.py
index 5a5f1478e16e..02533e7a9b1d 100644
--- a/python/tvm/relay/backend/contrib/ethosu/codegen.py
+++ b/python/tvm/relay/backend/contrib/ethosu/codegen.py
@@ -32,7 +32,8 @@
 )
 from tvm.relay.backend.contrib.ethosu.legalize import LegalizeEthosU
 from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator, util, vela_api
-from tvm.relay.expr_functor import ExprMutator, ExprVisitor
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor, Call
+from tvm.relay import expr as _expr
 from tvm.relay.backend.contrib.ethosu.op import op_attrs
@@ -357,6 +358,92 @@ def __call__(self, *args, **kwargs):
         pass
 
 
+class PadsWithMultipleConsumersReplicator(ExprMutator):
+    """A pass to handle the situation when an nn.pad operator has
+    more than one qnn.conv2d consumer.
+
+           pad
+          /   \
+    Conv2D     Conv2D
+
+    In this case, because of the peculiarities of pattern matching,
+    the conv2d operators do not get merged into a composite for the NPU.
+    Therefore, the pad is replicated so that each copy has only one consumer.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # a set to record the hashes of pads which already have one qnn.conv2d consumer
+        self.hashes = set()
+
+    def visit_call(self, call: tvm.relay.expr.Call) -> tvm.relay.expr.Call:
+        if (
+            isinstance(call.op, tvm.ir.Op)
+            and isinstance(call.args[0], Call)
+            and isinstance(call.args[0].op, tvm.ir.Op)
+            and call.op == relay.op.get("qnn.conv2d")
+            and call.args[0].op == relay.op.get("nn.pad")
+        ):
+            if tvm.ir.structural_hash(call.args[0]) not in self.hashes:
+                # record the hash of the nn.pad in the set
+                self.hashes.add(tvm.ir.structural_hash(call.args[0]))
+            else:
+                # if this pad already has a conv2d consumer, duplicate the pad
+                # and make it an input for the current conv2d
+                used_pad = self.visit(call.args[0])
+                used_pad_args = [self.visit(arg) for arg in used_pad.args]
+                new_pad = Call(
+                    used_pad.op, used_pad_args, used_pad.attrs, used_pad.type_args, used_pad.span
+                )
+                new_conv2d_args = []
+                for i, arg in enumerate(call.args):
+                    if i == 0:
+                        new_conv2d_args.append(self.visit(new_pad))
+                    else:
+                        new_conv2d_args.append(self.visit(arg))
+                new_conv2d_op = self.visit(call.op)
+                expr__ = _expr.CallWithFields(
+                    call,
+                    new_conv2d_op,
+                    new_conv2d_args,
+                    call.attrs,
+                    call.type_args,
+                    None,
+                    call.span,
+                )
+                return expr__
+
+        new_args = [self.visit(arg) for arg in call.args]
+        new_op = self.visit(call.op)
+        expr__ = _expr.CallWithFields(
+            call, new_op, new_args, call.attrs, call.type_args, None, call.span
+        )
+        return expr__
+
+
+def replicate_pads(mod):
+    """Traverses the Relay graph to replicate nn.pad operators that have
+    multiple qnn.conv2d consumers. This removes the situation when,
+    e.g., pad+conv2d corresponds to qnn_conv2d_pattern but cannot be
+    grouped because several conv2d operators use the same pad operation.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The IRModule that gets generated from a relay frontend.
+
+    Returns
+    -------
+    tvm.ir.IRModule
+        The IRModule without nn.pad operators with multiple consumers.
+ """ + replicator = PadsWithMultipleConsumersReplicator() + for global_var, func in mod.functions.items(): + func = replicator.visit(func) + mod.update_func(global_var, func) + return mod + + def IdentityOptimizer(): # pylint: disable=invalid-name """Pass that removes redundant identities diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 0796ccf62a85..386ef9038e49 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -2341,13 +2341,15 @@ def partition_for_ethosu( mod : IRModule The partitioned IRModule with external global functions """ - from tvm.relay.backend.contrib.ethosu import preprocess + from tvm.relay.backend.contrib.ethosu import preprocess, codegen if params: mod["main"] = bind_params_by_name(mod["main"], params) pattern = relay.op.contrib.get_pattern_table("ethos-u") mod = relay.transform.InferType()(mod) + mod = codegen.replicate_pads(mod) + mod = relay.transform.InferType()(mod) mod = relay.transform.MergeComposite(pattern)(mod) mod = relay.transform.AnnotateTarget("ethos-u")(mod) mod = relay.transform.MergeCompilerRegions()(mod) diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index cb1592c041ec..d56b8b6ec943 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -157,6 +157,69 @@ def conv2d_double(x): infra.compare_tvm_with_tflite(conv2d_double, [ifm_shape], accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")] +) +def test_tflite_shared_pad( + accel_type, + op_pairs, +): + np.random.seed(0) + + ifm_shape = (1, 55, 32, 3) + kernel_shape = (3, 3) + strides = (3, 2) + dilation = (1, 1) + activation_function = "RELU" + op_padding = "SAME" + sep_padding = (0, 0, 1, 1) + + @tf.function + def tf_function(x): + def make_depthwise_or_conv2d(pair_idx, x): + # The input strides to the TensorFlow API needs to be of shape 1x4 + tf_strides = [1, strides[0], strides[1], 1] + if op_pairs[pair_idx] == "depthwise": + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.depthwise_conv2d( + x, weight, strides=tf_strides, padding=op_padding, dilations=dilation + ) + else: + weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3] + weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + weight, + strides=tf_strides, + padding=op_padding, + dilations=dilation, + ) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + x = tf.pad( + x, + [ + [0, 0], + [sep_padding[0], sep_padding[2]], + [sep_padding[1], sep_padding[3]], + [0, 0], + ], + "CONSTANT", + ) + + x1 = make_depthwise_or_conv2d(0, x) + x2 = make_depthwise_or_conv2d(1, x) + + x3 = tf.math.add(x1, x2) + return x3 + + infra.compare_tvm_with_tflite(tf_function, [ifm_shape], accel_type) + + @pytest.mark.parametrize("weight_min, weight_max", [(0.0, 1e-11), (-1e10, 1e10)]) def test_out_of_range_scaling(weight_min, weight_max): np.random.seed(0) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index 1b643f815721..05022321df64 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -31,7 +31,7 
 from tvm.relay.backend.contrib.ethosu import legalize, preprocess
 from tvm.relay import dataflow_pattern
 from tvm.relay.op.contrib import ethosu
-from tvm.relay.backend.contrib.ethosu import util
+from tvm.relay.backend.contrib.ethosu import util, codegen
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.frontend.tflite import get_pad_value
 from tvm.relay.expr_functor import ExprVisitor
@@ -44,6 +44,8 @@ def partition_ethosu_by_table(mod, pattern_table):
     want to add the operator's pattern to the pattern table so that the compiler
     wouldn't attempt to offload an operator without full stack support."""
     mod = relay.transform.InferType()(mod)
+    mod = codegen.replicate_pads(mod)
+    mod = relay.transform.InferType()(mod)
     mod = relay.transform.MergeComposite(pattern_table)(mod)
     mod = relay.transform.AnnotateTarget("ethos-u")(mod)
     mod = relay.transform.MergeCompilerRegions()(mod)
@@ -3646,5 +3648,133 @@ def _visit(stmt):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 55, 55, 3)])
+@pytest.mark.parametrize("kernel_shape", [(3, 3)])
+@pytest.mark.parametrize("strides, dilation", [((1, 1), (1, 1))])
+@pytest.mark.parametrize("op_padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("sep_padding", [(0, 0, 1, 1), (7, 5, 4, 5)])
+@pytest.mark.parametrize(
+    "op_pairs", [("conv2d", "conv2d"), ("depthwise", "depthwise"), ("conv2d", "depthwise")]
+)
+def test_tflite_shared_pad_legalize(
+    ifm_shape,
+    kernel_shape,
+    strides,
+    dilation,
+    op_padding,
+    sep_padding,
+    op_pairs,
+):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                def make_depthwise_or_conv2d(pair_idx):
+                    if op_pairs[pair_idx] == "depthwise":
+                        weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 1]
+                        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                        return tf.nn.depthwise_conv2d(
+                            x, weight, strides=tf_strides, padding=op_padding, dilations=dilation
+                        )
+                    weight_shape = [kernel_shape[0], kernel_shape[1], ifm_shape[3], 3]
+                    weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                    return tf.nn.conv2d(
+                        x,
+                        weight,
+                        strides=tf_strides,
+                        padding=op_padding,
+                        dilations=dilation,
+                    )
+
+                x = tf.pad(
+                    x,
+                    [
+                        [0, 0],
+                        [sep_padding[0], sep_padding[2]],
+                        [sep_padding[1], sep_padding[3]],
+                        [0, 0],
+                    ],
+                    "CONSTANT",
+                )
+
+                # The input strides to the TensorFlow API need to be of shape 1x4
+                tf_strides = [1, strides[0], strides[1], 1]
+
+                x1 = make_depthwise_or_conv2d(0)
+                x2 = make_depthwise_or_conv2d(1)
+
+                x3 = tf.math.add(x1, x2)
+                return x3
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    conv2d_pattern_table = [
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnDepthwiseConv2DParams.composite_name,
+            ethosu.qnn_depthwise_conv2d_pattern(),
+            lambda pat: ethosu.QnnDepthwiseConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+        mod["tvmgen_default_ethos_u_main_0"],
+    )
+    mod["tvmgen_default_ethos_u_main_1"] = dataflow_pattern.rewrite(
+        [legalize.Conv2DRewriter(), legalize.DepthwiseConv2DRewriter()],
+        mod["tvmgen_default_ethos_u_main_1"],
+    )
+
+    if op_pairs[0] == "depthwise":
+        assert (
+            mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.depthwise_conv2d"
+        )
+    else:
+        assert mod["tvmgen_default_ethos_u_main_0"].body.op.name == "contrib.ethosu.conv2d"
+
+    if op_pairs[1] == "depthwise":
+        assert (
+            mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.depthwise_conv2d"
+        )
+    else:
+        assert mod["tvmgen_default_ethos_u_main_1"].body.op.name == "contrib.ethosu.conv2d"
+
+
 if __name__ == "__main__":
     tvm.testing.main()
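
For reference, a minimal sketch of what replicate_pads is expected to do, separate from the diff above: one nn.pad shared by two qnn.conv2d consumers is split into one pad per consumer, so that MergeComposite can later group each pad+conv2d pair on its own. It assumes a TVM build with Ethos-U support; the helper name qnn_conv, the shapes, and the quantization parameters are illustrative assumptions, not part of this change.

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay.backend.contrib.ethosu import codegen

    # one nn.pad feeding two qnn.conv2d consumers (shapes are illustrative)
    x = relay.var("x", shape=(1, 8, 8, 3), dtype="int8")
    shared_pad = relay.nn.pad(
        x, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)), pad_value=relay.const(0, "int8")
    )
    weights = relay.const(np.zeros((1, 1, 3, 3), dtype="int8"))

    def qnn_conv(inp):
        # a 1x1 NHWC/HWIO convolution with dummy quantization parameters
        return relay.qnn.op.conv2d(
            inp,
            weights,
            input_zero_point=relay.const(0, "int32"),
            kernel_zero_point=relay.const(0, "int32"),
            input_scale=relay.const(1.0, "float32"),
            kernel_scale=relay.const(1.0, "float32"),
            kernel_size=(1, 1),
            channels=3,
            data_layout="NHWC",
            kernel_layout="HWIO",
        )

    func = relay.Function([x], relay.add(qnn_conv(shared_pad), qnn_conv(shared_pad)))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))
    mod = codegen.replicate_pads(mod)

    # after the pass each qnn.conv2d should consume its own copy of the pad
    pads = []
    relay.analysis.post_order_visit(
        mod["main"].body,
        lambda e: pads.append(e)
        if isinstance(e, relay.Call) and e.op == relay.op.get("nn.pad")
        else None,
    )
    assert len(pads) == 2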