diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index c940abdeab5f..77ef51ef9c40 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -298,6 +298,83 @@ def calculate_lut_value(i):
         return identity
 
 
+class HardSwishRewriter(DFPatternCallback):
+    """Convert ethosu.hard_swish composite function to an identity operation with LUT."""
+
+    def __init__(self):
+        super().__init__(require_type=True, rewrite_once=True)
+        self.params_class = ethosu_patterns.HardSwishParams
+        self.pattern = wildcard().has_attr({"Composite": self.params_class.composite_name})(
+            wildcard()
+        )
+
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
+        params = self.params_class(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        # The calculation of the LUT values is similar to that in Vela
+        # convert_hardswish_to_lut(op, arch, nng)
+        # (https://review.mlplatform.org/plugins/gitiles/ml/ethos-u/ethos-u-vela/+/refs/tags/3.2.0/ethosu/vela/tflite_graph_optimiser.py#719)  # pylint: disable=line-too-long
+        input_scale = np.double(params.ifm.q_params.scale_f32)
+        input_zp = int(params.ifm.q_params.zero_point)
+        hires_input_scale = (1 / 128) * input_scale
+
+        output_scale = np.double(params.ofm.q_params.scale_f32)
+        output_zp = int(params.ofm.q_params.zero_point)
+        output_scale, output_shift = scaling.quantise_scale(hires_input_scale / output_scale)
+        output_scale_16 = fp_math.downscale_multiplier_int32_to_int16(output_scale)
+        output_shift = 31 - output_shift
+        output_shift = -output_shift if output_shift < 0 else 0
+
+        dtype = params.ifm.dtype
+        qmin, qmax = np.iinfo(dtype).min, np.iinfo(dtype).max
+
+        def calculate_relu_multiplier(inp, input_scale):
+            rmultiplier = np.double(3 / 32768)
+            rscale, rshift = scaling.quantise_scale(input_scale / rmultiplier)
+            rscale_16 = fp_math.downscale_multiplier_int32_to_int16(rscale)
+
+            rvalue = np.int16(inp)
+            if rshift < 31:
+                rvalue = fp_math.shift_left16(rvalue, 30 - rshift)
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.shift_left16(rvalue, 1)
+            elif rshift > 31:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+                rvalue = fp_math.rounding_divide_by_pot(rvalue, rshift - 31)
+            else:
+                rvalue = fp_math.saturating_rounding_mul16(rvalue, rscale_16)
+
+            rvalue = (rvalue + (1 << 15)) >> 1
+            return rvalue
+
+        def calculate_lut_values(i):
+            hires_input_value = (i - input_zp) * 128
+            preshift_input_value = fp_math.saturating_rounding_mul16(
+                hires_input_value, output_scale_16
+            )
+            relu_value = calculate_relu_multiplier(hires_input_value, hires_input_scale)
+            lut_result = fp_math.saturating_mul16(relu_value, preshift_input_value)
+            lut_result = fp_math.rounding_divide_by_pot(lut_result, output_shift) + output_zp
+            return min(qmax, max(qmin, lut_result))
+
+        values = list(map(calculate_lut_values, range(-128, 128)))
+        lut = relay.const(values, dtype=dtype)
+
+        # We baked the requantization into the LUT, so we don't requantize the identity operator
+        identity = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=input_scale,
+            ifm_zero_point=input_zp,
+            ofm_scale=input_scale,
+            ofm_zero_point=input_zp,
+            activation="LUT",
+        )
+
+        return identity
+
+
 class Conv2DRewriter(DFPatternCallback):
     """Convert conv2d related composite functions into ethosu_conv2d operators"""
 
@@ -1306,6 +1383,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             ShlRewriter(),
             AbsRewriter(),
             TanhRewriter(),
+            HardSwishRewriter(),
             LeakyReLURewriter(),
             MeanRewriter(),
             ConcatRewriter(),
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index 4c3dcc2fc45a..c0f8e5e9708e 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1724,6 +1724,54 @@ def qnn_fc_pattern():
     return optional_clip
 
 
+class HardSwishParams:
+    """
+    This class will parse a call to an ethos-u.hard_swish composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.hard_swish"
+
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QuantizeArgs
+        from tvm.relay.backend.contrib.ethosu.util import DequantizeArgs
+
+        quantize = func_body
+        divide = quantize.args[0]
+        multiply = divide.args[0]
+        clip = multiply.args[1]
+        add = clip.args[0]
+        dequantize = add.args[0]
+
+        self.ifm = TensorParams(
+            dequantize.args[0],
+            scale=dequantize.args[DequantizeArgs.IFM_SCALE.value],
+            zero_point=dequantize.args[DequantizeArgs.IFM_ZERO_POINT.value],
+        )
+        self.ofm = TensorParams(
+            quantize,
+            scale=quantize.args[QuantizeArgs.OFM_SCALE.value],
+            zero_point=quantize.args[QuantizeArgs.OFM_ZERO_POINT.value],
+        )
+
+    def is_valid(self):
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.int8]):
+            return False
+        return True
+
+
+def hard_swish_pattern():
+    """Create the pattern for hard swish."""
+    dequantize = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
+    add = is_op("add")(dequantize, is_constant())
+    clip = is_op("clip")(add)
+    multiply = is_op("multiply")(dequantize, clip)
+    divide = is_op("divide")(multiply, is_constant())
+    quantize = is_op("qnn.quantize")(divide, is_constant(), is_constant())
+    return quantize
+
+
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
@@ -1844,6 +1892,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
             squeeze_pattern(),
             lambda pat: SqueezeParams(pat).is_valid(),
         ),
+        (
+            HardSwishParams.composite_name,
+            hard_swish_pattern(),
+            lambda pat: HardSwishParams(pat).is_valid(),
+        ),
     ]
 
 
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index 2d3489889e8a..920cfff17884 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -819,6 +819,21 @@ def tanh_func(x):
     )
 
 
+@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(accel_type, ifm_shape):
+    np.random.seed(0)
+
+    @tf.function
+    def hard_swish_func(x):
+        op = tf.keras.layers.Lambda(
+            lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+        )(x)
+        return op
+
+    infra.compare_tvm_with_tflite(hard_swish_func, [ifm_shape], accel_type, ranges=[(-1, 1)])
+
+
 @pytest.mark.parametrize("accel_type", ACCEL_TYPES)
 @pytest.mark.parametrize(
     "shapes, axis",
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index 3f8b5f7d5b58..0f8fa4d84bf7 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -2751,5 +2751,60 @@ def verify(ext_func):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+@pytest.mark.parametrize("ifm_shape", [(1, 5, 5, 3), (1, 12, 9, 1)])
+def test_tflite_hard_swish(ifm_shape):
+    dtype = "int8"
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                op = tf.keras.layers.Lambda(
+                    lambda x: x * tf.keras.activations.relu(x + 3.0, max_value=6.0) / 6.0
+                )(x)
+                return op
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, tf.float32)
+        )
+
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+
+        return tflite_model
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod = ethosu.partition_for_ethosu(mod, params)
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.HardSwishRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    mod = relay.transform.InferType()(mod)
+
+    func_body = mod["tvmgen_default_ethos_u_main_0"].body
+    assert func_body.op.name == "contrib.ethosu.identity"
+    assert func_body.attrs.activation == "LUT"
+    assert tuple(func_body.args[0].checked_type.shape) == ifm_shape
+    assert tuple(func_body.args[1].checked_type.shape) == (256,)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
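
Reviewer note (not part of the patch): the rewriter deliberately builds the table with Vela's scaling.quantise_scale and the 16-bit fp_math helpers so the generated LUT stays consistent with Vela's convert_hardswish_to_lut linked in the code comment. As a quick plausibility check of what the table encodes, the standalone sketch below builds a 256-entry hard-swish LUT from a plain float reference instead. The quantization parameters are assumed example values (not taken from the patch), and the result is not bit-exact with the fixed-point pipeline above; it only bounds how far the table can drift from float hard swish.

import numpy as np

# Assumed example int8 quantization parameters, not values from the patch.
input_scale, input_zp = 1.0 / 128, 0
output_scale, output_zp = 1.0 / 128, 0


def hard_swish_float(x):
    # The float composite the pattern matches: x * relu6(x + 3) / 6.
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0


# Build a 256-entry LUT indexed by the int8 input value, as the rewriter does.
indices = np.arange(-128, 128)
dequantized = (indices - input_zp) * input_scale
requantized = np.round(hard_swish_float(dequantized) / output_scale) + output_zp
lut = np.clip(requantized, -128, 127).astype(np.int8)

assert lut.shape == (256,)  # same shape the legalization test asserts
# Rounding alone cannot drift more than half a quantization step from float.
max_err = np.abs(lut * output_scale - hard_swish_float(dequantized)).max()
assert max_err <= output_scale

The fixed-point route in the patch trades this simple construction for bit-exact agreement with Vela, which matters when comparing NPU output against the TFLite reference in the codegen test.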