diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index fdd465529123..24dd9afd7bfa 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1121,6 +1121,87 @@ def callback(
         return reduced_op
 
 
+class SumRewriter(DFPatternCallback):
+    """
+    Convert ethos-u.sum composite functions to pooling operations.
+    """
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.SumParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+
+        params = ethosu_patterns.SumParams(post.op.body)
+
+        ifm_shape = params.ifm.shape
+        ofm_shape = params.ofm.shape
+        lut = relay.const([], "int8")
+        reduced_op = post.args[0]
+
+        # Enforce 4D input
+        if len(ifm_shape) == 3:
+            ifm_shape = [1, params.height, params.width, ifm_shape[2]]
+            reduced_op = relay.reshape(reduced_op, ifm_shape)
+
+        activation_map = {"clip": "CLIP"}
+        if params.activation:
+            activation = activation_map[params.activation.op.name]
+            clip_min = int(params.activation.attrs.a_min)
+            clip_max = int(params.activation.attrs.a_max)
+        else:
+            activation = "NONE"
+            clip_min = 0
+            clip_max = 0
+
+        reduced_op = ethosu_ops.ethosu_pooling(
+            ifm=reduced_op,
+            lut=lut,
+            pooling_type="SUM",
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=0,
+            pool_shape=(1, 1),
+            ofm_channels=1,
+            activation=activation,
+            clip_min=clip_min,
+            clip_max=clip_max,
+            ifm_layout=params.ifm.layout,
+            ofm_layout=params.ofm.layout,
+            rounding_mode="NATURAL",
+        )
+
+        # Convert tensor dtype from int32 to int8
+        scalar_tensor = relay.const(np.ones([1, 1, 1, 1], dtype="int32"), dtype="int32")
+        reduced_op = ethosu_ops.ethosu_binary_elementwise(
+            ifm=reduced_op,
+            ifm2=scalar_tensor,
+            lut=lut,
+            operator_type="MUL",
+            ifm_scale=0.0,
+            ifm_zero_point=0,
+            ifm2_scale=0.0,
+            ifm2_zero_point=0,
+            ofm_scale=0.0,
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+            ifm_channels=1,
+            ifm2_channels=1,
+            reversed_operands=False,
+            ofm_dtype="int8",
+        )
+
+        # Reshape to the original ofm shape
+        if len(ofm_shape) < 4:
+            reduced_op = relay.reshape(reduced_op, ofm_shape)
+
+        return reduced_op
+
+
 class ConcatRewriter(DFPatternCallback):
     """The newer versions of TFLite converters return a concatenate operator that concatenates
     tensors with same QNN params (if the QNN params of tensors were initially different,
@@ -1443,6 +1524,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             HardSwishRewriter(),
             LeakyReLURewriter(),
             MeanRewriter(),
+            SumRewriter(),
             ConcatRewriter(),
             SigmoidRewriter(),
             RequantizeRewriter(),
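The rewriter above is registered alongside the existing rewriters in transform_npu_function. It can also be exercised on its own against an already-partitioned module, mirroring what the legalization test later in this patch does; a minimal sketch, assuming mod holds a module already partitioned for the microNPU:

    from tvm.relay import dataflow_pattern
    from tvm.relay.backend.contrib.ethosu import legalize

    # Rewrite only the ethos-u.sum composites in the partitioned function.
    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
        legalize.SumRewriter(), mod["tvmgen_default_ethos_u_main_0"]
    )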
diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
index ca8c2ec9b395..6843046fd01e 100644
--- a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
+++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py
@@ -110,11 +110,12 @@ def pooling_compute(
     padding = [int(v) for v in padding]
     stride_h, stride_w = [int(v) for v in strides]
     pool_shape_h, pool_shape_w = [int(v) for v in pool_shape]
+    ifm_channels = ofm_channels if pooling_type != "SUM" else ifm.shape[-1]
     upscale_factor = 2 if upscale != "NONE" else 1
 
     # Compute operation for the IFM DMA pipeline
     dmaed_ifm = dma_ifm_compute(
-        ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding, upscale_factor
+        ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, padding, upscale_factor
     )
 
     # Pooling compute operation
@@ -122,6 +123,8 @@ def pooling_compute(
     ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1
     rh = te.reduce_axis((0, pool_shape_h), name="ry")
     rw = te.reduce_axis((0, pool_shape_w), name="rx")
+    rc = te.reduce_axis((0, 1 if pooling_type != "SUM" else ifm_channels), name="rc")
+    ofm_dtype = ifm.dtype if pooling_type != "SUM" else "int32"
 
     pooling_attrs = {
         "op": "ethosu_pooling",
@@ -149,10 +152,10 @@ def pooling_compute(
     pooling = te.compute(
         (1, ofm_height, ofm_width, ofm_channels),
         lambda nn, hh, ww, cc: te.max(
-            (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc) + lut_expr).astype(
-                ifm.dtype
+            (dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc + rc) + lut_expr).astype(
+                ofm_dtype
             ),
-            axis=[rh, rw],
+            axis=[rh, rw, rc],
         ),
         name="ethosu_pooling",
         attrs=pooling_attrs,
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
index 39d1c48a965b..ba2c6e209b72 100644
--- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
+++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
@@ -822,7 +822,10 @@ def _create_npu_quantization(
     """This is a helper function to capture a list of arguments
     to create Vela NpuQuantization object.
     """
-    return vapi.NpuQuantization(scale_f32=float(scale), zero_point=int(zero_point))
+    scale = float(scale)
+    if scale == 0.0:
+        scale = None
+    return vapi.NpuQuantization(scale_f32=scale, zero_point=int(zero_point))
 
 
 def _create_npu_weights_zero_point(
@@ -960,6 +963,8 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling):
         npu_pooling_op = vapi.NpuPoolingOp.AVERAGE
     elif pooling_type == "MAX":
         npu_pooling_op = vapi.NpuPoolingOp.MAX
+    elif pooling_type == "SUM":
+        npu_pooling_op = vapi.NpuPoolingOp.REDUCE_SUM
 
     npu_pooling_op = vapi.NpuPoolingOperation(npu_pooling_op)
     npu_pooling_op.ifm = _create_npu_feature_map(serial_pooling.ifm)
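The _create_npu_quantization change above is what lets the legalized requantizing MUL pass ofm_scale=0.0 through to Vela: a zero scale is now translated into a quantization with no scale rather than a literal scale of 0.0. A minimal sketch of the resulting object, assuming the ethos-u-vela package is installed and using an illustrative zero point:

    from ethosu.vela import api as vapi

    # A zero scale is now forwarded as scale_f32=None.
    quant = vapi.NpuQuantization(scale_f32=None, zero_point=-128)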
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 5d1e75b03043..a6d959c98b01 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1375,6 +1375,91 @@ def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern
 
 
+class SumParams:
+    """
+    This class will parse a call to an ethos-u.sum composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.sum"
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import RequantArgs
+
+        clip = None
+        if str(func_body.op.name) == "clip":
+            clip = func_body
+            requantize = clip.args[0]
+        else:
+            requantize = func_body
+
+        sum_op = requantize.args[0]
+        attrs = sum_op.attrs
+        cast = sum_op.args[0]
+
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            cast.args[0],
+            layout,
+            requantize.args[RequantArgs.IFM_SCALE.value],
+            requantize.args[RequantArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            requantize,
+            layout,
+            requantize.args[RequantArgs.OFM_SCALE.value],
+            requantize.args[RequantArgs.OFM_ZERO_POINT.value],
+        )
+
+        self.activation = clip
+
+        ifm_shape = self.ifm.shape
+        self.height = ifm_shape[0] if len(ifm_shape) in (2, 3) else ifm_shape[1]
+        self.width = ifm_shape[1] if len(ifm_shape) in (2, 3) else ifm_shape[2]
+        self.keepdims = attrs.keepdims
+
+        self.axis = list(sorted(attrs.axis))
+        if attrs.exclude:
+            self.axis = [i for i in range(len(self.ifm.shape)) if i not in self.axis]
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether Sum has compatible attributes with the hardware.
+        """
+
+        ifm_shape_len = len(self.ifm.shape)
+
+        if not check_valid_dtypes([self.ifm], [np.uint8, np.int8, np.int16, np.int32]):
+            return False
+        if not check_valid_dtypes([self.ofm], [np.int8]):
+            return False
+        if ifm_shape_len not in (3, 4):
+            return False
+        if ifm_shape_len == 3 and self.axis not in [[2]]:
+            return False
+        if ifm_shape_len == 4 and self.axis not in [[3]]:
+            return False
+
+        return True
+
+
+def sum_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
+    """
+    This function creates the pattern for sum.
+    """
+    pattern = is_op("cast")(wildcard())
+    pattern = is_op("sum")(pattern)
+    pattern = is_op("qnn.requantize")(
+        pattern,
+        is_constant(),
+        is_constant(),
+        is_constant(),
+        is_constant(),
+    )
+    pattern = pattern.optional(is_op("clip"))
+    return pattern
+
+
 class ConcatParams:
     """
     This class will parse a call to a ethos-u.concat composite function
@@ -1995,6 +2080,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             mean_pattern(),
             lambda pat: MeanParams(pat).is_valid(),
         ),
+        (
+            SumParams.composite_name,
+            sum_pattern(),
+            lambda pat: SumParams(pat).is_valid(),
+        ),
         (
             LeakyReLUParams.composite_name,
             leaky_relu_pattern(),
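For reference, sum_pattern above matches the cast -> sum -> qnn.requantize (optionally clip) chain that the TFLite frontend produces for a quantized REDUCE_SUM. A hand-built Relay sketch of a matching graph, with illustrative quantization parameters:

    from tvm import relay

    x = relay.var("x", shape=(1, 4, 2, 8), dtype="int8")
    acc = relay.sum(relay.cast(x, "int32"), axis=3, keepdims=False)
    out = relay.qnn.op.requantize(
        acc,
        relay.const(0.0039, "float32"),  # input scale (illustrative)
        relay.const(0, "int32"),         # input zero point
        relay.const(0.125, "float32"),   # output scale
        relay.const(-128, "int32"),      # output zero point
        out_dtype="int8",
    )
    # Optionally followed by relay.clip(out, a_min=-128, a_max=127) for the
    # fused ReLU case.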
diff --git a/src/relay/op/contrib/ethosu/op_attrs.h b/src/relay/op/contrib/ethosu/op_attrs.h
index 9eac004e185d..e4ba2cfb9bad 100644
--- a/src/relay/op/contrib/ethosu/op_attrs.h
+++ b/src/relay/op/contrib/ethosu/op_attrs.h
@@ -361,7 +361,9 @@ struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> {
   TVM_DECLARE_ATTRS(EthosuPoolingAttrs, "relay.attrs.EthosuPoolingAttrs") {
     TVM_ATTR_FIELD(pooling_type)
-        .describe("The type of the pooling. 'AVG' - average pool, 'MAX' - max pool.");
+        .describe(
+            "The type of the pooling. 'AVG' - average pool, 'MAX' - max pool, "
+            "'SUM' - reduce sum pool.");
     TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor.");
     TVM_ATTR_FIELD(ifm_zero_point)
         .describe("The quantization zero point for the Input Feature Map tensor.");
diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc
index 8ad5909f0c17..a9c072a01121 100644
--- a/src/relay/op/contrib/ethosu/pooling.cc
+++ b/src/relay/op/contrib/ethosu/pooling.cc
@@ -46,14 +46,28 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
   const String operator_name = "ethosu_pooling";
 
-  if (param->pooling_type != "AVG" && param->pooling_type != "MAX") {
+  if (param->pooling_type != "AVG" && param->pooling_type != "MAX" &&
+      param->pooling_type != "SUM") {
     reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan())
                                      << "Invalid operator: expected " << operator_name
-                                     << " type 'AVG' or 'MAX' but was " << param->pooling_type);
+                                     << " type 'AVG', 'MAX', or 'SUM' but was "
+                                     << param->pooling_type);
     return false;
   }
 
-  CheckDataType(reporter, ifm->dtype, {DataType::UInt(8), DataType::Int(8)}, operator_name, "ifm",
+  std::initializer_list<DataType> max_avg_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
+                                                                DataType::Int(16)};
+  std::initializer_list<DataType> sum_pooling_ifm_dtypes = {DataType::UInt(8), DataType::Int(8),
+                                                            DataType::Int(16), DataType::Int(32)};
+
+  std::initializer_list<DataType>& allowed_ifm_dtypes = max_avg_pooling_ifm_dtypes;
+  auto ofm_dtype = ifm->dtype;
+  if (param->pooling_type == "SUM") {
+    allowed_ifm_dtypes = sum_pooling_ifm_dtypes;
+    ofm_dtype = DataType::Int(32);
+  }
+
+  CheckDataType(reporter, ifm->dtype, allowed_ifm_dtypes, operator_name, "ifm",
                 param->pooling_type);
 
   CheckUpscaleMethod(reporter, param->upscale, {"NONE", "ZEROS", "NEAREST"}, operator_name);
@@ -67,7 +81,8 @@ bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& att
   auto ofm_shape = EthosuInferKernelOutput(
       ifm_shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels,
       Array<IndexExpr>({1, 1}), param->strides, param->padding);
-  reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype));
+
+  reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype));
   return true;
 }
 
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index f07fd0463b5a..6eb382d8f588 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -490,6 +490,36 @@ def create_mod_from_relay():
     infra.verify_source(compiled_models, test_runner)
 
 
+@pytest.mark.parametrize(
+    "accel_type",
+    ACCEL_TYPES,
+)
+@pytest.mark.parametrize(
+    "ifm_shape, axis, keepdims, relu",
+    [
+        [(1, 4, 2, 8), 3, False, False],
+        [(1, 4, 4, 1), 3, False, True],
+        [(3, 5, 7), 2, False, True],
+        [(1, 4, 2, 8), 3, True, False],
+        [(3, 5, 7), 2, True, False],
+    ],
+)
+def test_ethosu_sum(accel_type, ifm_shape, axis, keepdims, relu):
+    np.random.seed(0)
+
+    @tf.function
+    def sum_func(x):
+        op = tf.math.reduce_sum(x, axis=axis, keepdims=keepdims)
+        return tf.nn.relu(op) if relu else op
+
+    infra.compare_tvm_with_tflite(
+        sum_func,
+        [ifm_shape],
+        accel_type,
+        enable_cascader=is_u55_accel_type(accel_type),
+    )
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize("dtype", ["int8", "uint8"])
 @pytest.mark.parametrize("constant", [np.ones((1, 1, 1, 1)), np.array(1)])
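With the type relation change above, a SUM pooling now infers an int32 output even for an int8 input, matching the int32 accumulator in the TE compute. A quick type-inference sketch, using the same keyword arguments as the legalization and assuming defaults for the remaining attributes:

    import tvm
    from tvm import relay
    from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

    ifm = relay.var("ifm", shape=(1, 5, 9, 3), dtype="int8")
    pool = ethosu_ops.ethosu_pooling(
        ifm=ifm,
        lut=relay.const([], "int8"),
        pooling_type="SUM",
        ifm_scale=1.0,
        ifm_zero_point=0,
        ofm_scale=1.0,
        ofm_zero_point=0,
        pool_shape=(1, 1),
        ofm_channels=1,
    )
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(pool))
    # The checked type of the pooling call should now report dtype int32.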
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index c445ceb2f3e3..0bd9c1ac3bf4 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -1694,6 +1694,123 @@ def calculate_expected_output_shape():
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize(
+    "ifm_shape, axis, keepdims, relu",
+    [
+        [(1, 4, 2, 8), 3, False, False],
+        [(1, 4, 4, 1), 3, False, True],
+        [(3, 5, 7), 2, False, True],
+        [(1, 4, 2, 8), 3, True, False],
+        [(3, 5, 7), 2, True, False],
+    ],
+)
+def test_ethosu_sum(ifm_shape, axis, keepdims, relu):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.math.reduce_sum(x, axis=axis, keepdims=keepdims)
+                return tf.nn.relu(op) if relu else op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model, 0)
+
+        mod, _ = relay.frontend.from_tflite(
+            tflite_model,
+            shape_dict={"input": ifm_shape},
+            dtype_dict={"input": dtype},
+        )
+        return mod
+
+    def verify(ext_func):
+        out_var = ext_func.body
+
+        binary_elementwise_op = None
+        pooling_op = None
+        next_op = out_var
+        if (
+            isinstance(next_op, relay.expr.Call)
+            and isinstance(next_op.op, tvm.ir.op.Op)
+            and next_op.op.name == "reshape"
+        ):
+            next_op = next_op.args[0]
+        binary_elementwise_op = next_op
+        pooling_op = binary_elementwise_op.args[0]
+        next_op = pooling_op.args[0]
+        if (
+            isinstance(next_op, relay.expr.Call)
+            and isinstance(next_op.op, tvm.ir.op.Op)
+            and next_op.op.name == "reshape"
+        ):
+            next_op = next_op.args[0]
+        in_var = next_op
+
+        def calculate_expected_output_shape():
+            for i in range(len(ifm_shape)):
+                if i != axis:
+                    yield ifm_shape[i]
+                elif keepdims:
+                    yield 1
+
+        out_shape = tuple(calculate_expected_output_shape())
+
+        # check IFM
+        assert tuple(in_var.checked_type.shape) == ifm_shape
+        assert in_var.checked_type.dtype == dtype
+
+        # check OFM
+        assert tuple(out_var.checked_type.shape) == out_shape
+        assert out_var.checked_type.dtype == dtype
+
+        # check expected legalization case
+        assert pooling_op
+        attrs = pooling_op.attrs
+        assert attrs.pooling_type == "SUM"
+        if relu:
+            assert attrs.activation == "CLIP"
+
+        assert binary_elementwise_op
+        attrs = binary_elementwise_op.attrs
+        assert attrs.operator_type == "MUL"
+        assert attrs.ifm_channels == attrs.ifm2_channels == 1
+        assert attrs.ofm_dtype == "int8"
+
+    rewriter = legalize.SumRewriter()
+    pattern_table = [
+        (
+            ethosu.SumParams.composite_name,
+            ethosu.sum_pattern(),
+            lambda pat: ethosu.SumParams(pat).is_valid(),
+        ),
+    ]
+
+    mod = create_tflite_graph()
+    mod = partition_ethosu_by_table(mod, pattern_table)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        rewriter, mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize(
     "shapes, axis",
     [
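Both new tests can be run on their own while iterating on this change; a sketch, assuming an environment with the Ethos-U test dependencies (TensorFlow, Vela) installed:

    import pytest

    # Runs every parametrization of the new legalization and codegen tests.
    pytest.main([
        "tests/python/contrib/test_ethosu/test_legalize.py::test_ethosu_sum",
        "tests/python/contrib/test_ethosu/test_codegen.py::test_ethosu_sum",
    ])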
diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py
index 564701637856..1ef59e0b9b03 100644
--- a/tests/python/contrib/test_ethosu/test_replace_pooling.py
+++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py
@@ -38,6 +38,7 @@ def _create_serial_pooling(
     activation="NONE",
     rounding_mode="TFL",
     upscale="NONE",
+    ofm_dtype="int8",
 ):
     upscale_factor = 2 if upscale != "NONE" else 1
     if ifm_layout == "NHWC":
@@ -70,12 +71,14 @@ def _create_serial_pooling(
         ofm_stride_c = 16 * ofm_width if ofm_channels >= 16 else 1
         ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1)
 
+    ifm_channels = ofm_channels if pooling_type != "SUM" else ifm_shape[-1]
+
     return spec.SerialPooling(
         ifm=spec.SerialFeatureMap(
             data_type="int8",
             height=ifm_shape[1],
             width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
-            channels=ofm_channels,
+            channels=ifm_channels,
             tile_height_0=ifm_shape[1],
             tile_height_1=0,
             tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
@@ -91,7 +94,7 @@ def _create_serial_pooling(
             stride_c=ifm_stride_c,
         ),
         ofm=spec.SerialFeatureMap(
-            data_type="int8",
+            data_type=ofm_dtype,
            height=ofm_height,
            width=ofm_width,
            channels=ofm_channels,
@@ -145,7 +148,7 @@ def _create_serial_pooling(
 )
 @pytest.mark.parametrize("pooling_type", ["AVG", "MAX"])
 @pytest.mark.parametrize("activation", ["NONE", "CLIP"])
-def test_pooling_single(
+def test_avg_max_pooling_single(
     ifm_shape,
     ofm_channels,
     ifm_layout,
@@ -207,6 +210,61 @@ def _visit(stmt):
     assert data[0] == ["ethosu_pooling"] + list(serial_pooling)
 
 
+@pytest.mark.parametrize(
+    "ifm_shape, ofm_layout, rounding_mode",
+    [
+        ((1, 5, 9, 3), "NHWC", "TFL"),
+        ((1, 8, 9, 40), "NHCWB16", "TFL"),
+        ((1, 8, 9, 8), "NHCWB16", "TRUNCATE"),
+        ((1, 5, 9, 3), "NHWC", "NATURAL"),
+    ],
+)
+@pytest.mark.parametrize("activation", ["NONE", "CLIP"])
+def test_sum_pooling_single(
+    ifm_shape,
+    ofm_layout,
+    activation,
+    rounding_mode,
+):
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    pooling = make_ethosu_pooling(
+        ifm=ifm,
+        pooling_type="SUM",
+        pool_shape=(1, 1),
+        ofm_channels=1,
+        strides=(1, 1),
+        padding=(0, 0, 0, 0),
+        activation=activation,
+        ofm_layout=ofm_layout,
+        rounding_mode=rounding_mode,
+    )
+    func = relay.Function(relay.analysis.free_vars(pooling), pooling)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = _lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_pooling_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+
+    serial_pooling = _create_serial_pooling(
+        ifm_shape=ifm_shape,
+        ofm_channels=1,
+        ifm_layout="NHWC",
+        ofm_layout=ofm_layout,
+        pool_shape=(1, 1),
+        pooling_type="SUM",
+        strides=(1, 1),
+        padding=(0, 0, 0, 0),
+        activation=activation,
+        rounding_mode=rounding_mode,
+        ofm_dtype="int32",
+    )
+    assert data[0] == ["ethosu_pooling"] + list(serial_pooling)
+
+
 def test_correct_stride_with_multiple_pooling():
     """Testing a specific case of two pooling operations with NHWC inputs/outputs
     but a NHCWB16 intermediate tensor. This lead to elements being accessed in the