From 06ae770d59cb29c23f8a9a871beff6b066d6ea3e Mon Sep 17 00:00:00 2001 From: Fredrik Knutsson Date: Wed, 17 Apr 2024 11:49:22 +0200 Subject: [PATCH 1/4] [Arm backend] Fix for TOSA BI clamp ops Min/max range values need to be on quantized form. Change-Id: I68d091306890f0a500d829ce20fc337e6cbe9dba Signed-off-by: Fredrik Knutsson --- backends/arm/operators/op_hardtanh.py | 31 +++++++++-- backends/arm/test/ops/test_conv_combos.py | 67 ++++++++++++++++++----- 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index eb9b0a18fba..0744578f77e 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -1,4 +1,4 @@ -# Copyright 2023 Arm Limited and/or its affiliates. +# Copyright 2023-2024 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -11,6 +11,8 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg + +from executorch.backends.arm.tosa_quant_utils import get_quant_node_args from serializer.tosa_serializer import TosaOp @@ -30,12 +32,31 @@ def define_node( is_quant_node: bool, ) -> None: attr = ts.TosaSerializerAttribute() + + if is_quant_node: + # Get quant parameters + scale, zp = get_quant_node_args(node.all_input_nodes[0]) + # Convert to quantized representation + clamp_min_qs = round((inputs[1].number / scale) + zp) + clamp_min_qs = max(clamp_min_qs, -128) + clamp_max_qs = round((inputs[2].number / scale) + zp) + clamp_max_qs = min(clamp_max_qs, 127) + # Set fp values to 0.0 since they are not used + clamp_min_fp = 0.0 + clamp_max_fp = 0.0 + else: + clamp_min_fp = inputs[1].number + clamp_max_fp = inputs[2].number + # Set qs values to 0 since they are not used + clamp_min_qs = 0 + clamp_max_qs = 0 + attr.ClampAttribute( tosa_graph.builder, - int(inputs[1].number), - int(inputs[2].number), - inputs[1].number, - inputs[2].number, + clamp_min_qs, + clamp_max_qs, + clamp_min_fp, + clamp_max_fp, ) tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr) diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 1fd68493790..7bbc7a12c8e 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -12,6 +12,7 @@ import torch from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.arm_tester import ArmTester +from parameterized import parameterized logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -126,6 +127,32 @@ def forward(self, x): return x +class ComboConvRelu6(torch.nn.Module): + edge_op_list = [ + "executorch_exir_dialects_edge__ops_aten_convolution_default", + "executorch_exir_dialects_edge__ops_aten_hardtanh_default", + ] + + test_data = [ + (20 * torch.randn(1, 3, 256, 256),), + (5 * torch.randn(1, 3, 256, 256),), + (torch.randn(1, 3, 256, 256),), + (-5 * torch.randn(1, 3, 256, 256),), + ] + + def __init__(self): + super().__init__() + self.conv2d = torch.nn.Conv2d( + in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1 + ) + self.relu6 = torch.nn.ReLU6() + + def forward(self, x): + x = self.conv2d(x) + x = self.relu6(x) + return x + + class TestConvCombos(unittest.TestCase): def _test_conv_combo_tosa_MI_pipeline( self, module: torch.nn.Module, test_data: Tuple[torch.Tensor] @@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self): model = ComboConvBatchnormRelu() self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs()) - # TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy - # testcase as well (and also not tested). For now, just increase the - # tolerance, such that we don't skip the test entirely (i.e. we maintain - # functionality). def test_conv_batchnorm_relu_tosa_BI(self): model = ComboConvBatchnormRelu() - self._test_conv_combo_tosa_BI_pipeline( - model, model.get_inputs(), atol=1.0, rtol=1.0 - ) + self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs()) @unittest.skipIf( not common.VELA_INSTALLED, @@ -240,6 +261,31 @@ def test_conv_batchnorm_relu_u55_BI(self): model = ComboConvBatchnormRelu() self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs()) + ################## + ## Conv + ReLU6 ## + ################## + @parameterized.expand(ComboConvRelu6.test_data) + def test_conv_relu6_tosa_MI(self, test_data: torch.Tensor): + model = ComboConvRelu6() + test_data = (test_data,) + self._test_conv_combo_tosa_MI_pipeline(model, test_data) + + @parameterized.expand(ComboConvRelu6.test_data) + def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor): + model = ComboConvRelu6() + test_data = (test_data,) + self._test_conv_combo_tosa_BI_pipeline(model, test_data) + + @unittest.skipIf( + not common.VELA_INSTALLED, + "There is no point in running U55 tests if the Vela tool is not installed", + ) + @parameterized.expand(ComboConvRelu6.test_data) + def test_conv_relu6_u55_BI(self, test_data: torch.Tensor): + model = ComboConvRelu6() + test_data = (test_data,) + self._test_conv_combo_u55_BI_pipeline(model, test_data) + ############################### ## Block bottleneck residual ## ############################### @@ -247,14 +293,9 @@ def test_block_bottleneck_residual_tosa_MI(self): model = ComboBlockBottleneckResidual() self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs()) - # TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy - # testcase as well. For now, just increase the tolerance, such that - # we don't skip the test entirely (i.e. we maintain functionality). def test_block_bottleneck_residual_tosa_BI(self): model = ComboBlockBottleneckResidual() - self._test_conv_combo_tosa_BI_pipeline( - model, model.get_inputs(), atol=1.0, rtol=1.0 - ) + self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs()) @unittest.skipIf( not common.VELA_INSTALLED, From 66b2ffa4b0cbc19931217111c5d51b28f0298f40 Mon Sep 17 00:00:00 2001 From: Fredrik Knutsson Date: Wed, 17 Apr 2024 15:01:00 +0200 Subject: [PATCH 2/4] Add mean square error to output check Change-Id: Ie59881824d76d5a9c30e95a8024dbbb11055577b --- backends/xnnpack/test/tester/tester.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index e0115a29eef..29fefedeaf4 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -555,10 +555,7 @@ def run_method_and_compare_outputs( print(f"Run {run_iteration} with input shapes: {input_shapes}") # Reference output (and quantization scale) - ( - reference_output, - quantization_scale, - ) = self._calculate_reference_output( + (reference_output, quantization_scale,) = self._calculate_reference_output( reference_stage.artifact, inputs_to_run ) @@ -586,16 +583,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): for i in range(len(model_output)): model = model_output[i] ref = ref_output[i] - assert torch.allclose( - model, - ref, - atol=atol, - rtol=rtol, - ), ( + assert torch.allclose(model, ref, atol=atol, rtol=rtol,), ( f"Output {i} does not match reference output.\n" f"\tGiven atol: {atol}, rtol: {rtol}.\n" f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" - f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n" + f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" f"\t-- Model vs. Reference --\n" f"\t Numel: {model.numel()}, {ref.numel()}\n" f"\tMedian: {model.median()}, {ref.median()}\n" From b95eb037c549f28d7fb919c44251a391ce6ac163 Mon Sep 17 00:00:00 2001 From: Fredrik Knutsson Date: Wed, 17 Apr 2024 18:23:21 +0200 Subject: [PATCH 3/4] Fixed CI failures Change-Id: Ia773c04d17f24bd155365a412b3e96c3b3d9aa63 --- backends/arm/test/ops/test_conv_combos.py | 2 +- backends/xnnpack/test/tester/tester.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/backends/arm/test/ops/test_conv_combos.py b/backends/arm/test/ops/test_conv_combos.py index 7bbc7a12c8e..2bde0688489 100644 --- a/backends/arm/test/ops/test_conv_combos.py +++ b/backends/arm/test/ops/test_conv_combos.py @@ -276,11 +276,11 @@ def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor): test_data = (test_data,) self._test_conv_combo_tosa_BI_pipeline(model, test_data) + @parameterized.expand(ComboConvRelu6.test_data) @unittest.skipIf( not common.VELA_INSTALLED, "There is no point in running U55 tests if the Vela tool is not installed", ) - @parameterized.expand(ComboConvRelu6.test_data) def test_conv_relu6_u55_BI(self, test_data: torch.Tensor): model = ComboConvRelu6() test_data = (test_data,) diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index 29fefedeaf4..8812d5e5019 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -555,7 +555,10 @@ def run_method_and_compare_outputs( print(f"Run {run_iteration} with input shapes: {input_shapes}") # Reference output (and quantization scale) - (reference_output, quantization_scale,) = self._calculate_reference_output( + ( + reference_output, + quantization_scale, + ) = self._calculate_reference_output( reference_stage.artifact, inputs_to_run ) @@ -583,7 +586,12 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03): for i in range(len(model_output)): model = model_output[i] ref = ref_output[i] - assert torch.allclose(model, ref, atol=atol, rtol=rtol,), ( + assert torch.allclose( + model, + ref, + atol=atol, + rtol=rtol, + ), ( f"Output {i} does not match reference output.\n" f"\tGiven atol: {atol}, rtol: {rtol}.\n" f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" From 86fca94e9a2c7e7bb5cd4168029bb588b5f9f617 Mon Sep 17 00:00:00 2001 From: Fredrik Knutsson Date: Fri, 19 Apr 2024 12:13:35 +0200 Subject: [PATCH 4/4] Addressing review comments Change-Id: I9dee9bec58e51fcef57ebb287dbad62016c221d1 --- backends/arm/operators/op_addmm.py | 23 ++++++++----------- backends/arm/operators/op_common.py | 4 ++-- backends/arm/operators/op_conv2d.py | 2 +- backends/arm/operators/op_hardtanh.py | 6 ++--- backends/arm/operators/op_placeholder.py | 12 ++++++---- backends/arm/tosa_quant_utils.py | 29 ++++++++++++++++++++---- 6 files changed, 48 insertions(+), 28 deletions(-) diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py index cc49e5c3821..444799d3536 100644 --- a/backends/arm/operators/op_addmm.py +++ b/backends/arm/operators/op_addmm.py @@ -73,7 +73,7 @@ def define_node( quant_node = input_node.all_input_nodes[0] else: quant_node = input_node - input_zp = get_quant_node_args(quant_node)[1] + input_zp = get_quant_node_args(quant_node).zp attr.ConvAttribute( pad=pad_attr, stride=stride_attr, @@ -111,24 +111,21 @@ def define_node( # rank > 2 linear layer if input_node.target == exir_ops.edge.aten.view_copy.default: quant_node = input_node.all_input_nodes[0] - input_scale, _ = get_quant_node_args(quant_node) + input_scale = get_quant_node_args(quant_node).scale consumer_node = list(node.users)[0] consumer_consumer_node = list(consumer_node.users)[0] - ( - consumer_node_scale, - consumer_node_node_zp, - ) = get_quant_node_args(consumer_consumer_node) - + quant_args = get_quant_node_args(consumer_consumer_node) + consumer_node_scale = quant_args.scale + consumer_node_node_zp = quant_args.zp else: - input_scale, _ = get_quant_node_args(input_node) + input_scale = get_quant_node_args(input_node).scale consumer_node = list(node.users)[0] - ( - consumer_node_scale, - consumer_node_node_zp, - ) = get_quant_node_args(consumer_node) + quant_args = get_quant_node_args(consumer_node) + consumer_node_scale = quant_args.scale + consumer_node_node_zp = quant_args.zp weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale, _ = get_quant_node_args(weight_node_q_node) + weight_scale = get_quant_node_args(weight_node_q_node).scale output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale ( diff --git a/backends/arm/operators/op_common.py b/backends/arm/operators/op_common.py index 4701343e8e8..eadf00c294d 100644 --- a/backends/arm/operators/op_common.py +++ b/backends/arm/operators/op_common.py @@ -31,8 +31,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - _, input_zp = get_quant_node_args(node.args[0]) - _, output_zp = get_quant_node_args(list(node.users)[0]) + input_zp = get_quant_node_args(node.args[0]).zp + output_zp = get_quant_node_args(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 45a9a6671d1..cc1c1f3c263 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -80,7 +80,7 @@ def define_node( ) input_zp = ( - get_quant_node_args(node.all_input_nodes[0])[1] if is_quant_node else 0 + get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 ) attr.ConvAttribute( diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 0744578f77e..3d58f6d628c 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -35,12 +35,12 @@ def define_node( if is_quant_node: # Get quant parameters - scale, zp = get_quant_node_args(node.all_input_nodes[0]) + scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) # Convert to quantized representation clamp_min_qs = round((inputs[1].number / scale) + zp) - clamp_min_qs = max(clamp_min_qs, -128) + clamp_min_qs = max(clamp_min_qs, qmin) clamp_max_qs = round((inputs[2].number / scale) + zp) - clamp_max_qs = min(clamp_max_qs, 127) + clamp_max_qs = min(clamp_max_qs, qmax) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 6a57d895ff0..05e02468d6d 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -50,11 +50,13 @@ def process_placeholder( weight_node = weight_node_permuted.all_input_nodes[0] if input_node.target == exir_ops.edge.aten.view_copy.default: - input_node_scale, _ = get_quant_node_args(input_node.all_input_nodes[0]) + input_node_scale = get_quant_node_args( + input_node.all_input_nodes[0] + ).scale else: - input_node_scale, _ = get_quant_node_args(input_node) + input_node_scale = get_quant_node_args(input_node).scale - weight_node_scale, _ = get_quant_node_args(weight_node) + weight_node_scale = get_quant_node_args(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) @@ -81,8 +83,8 @@ def process_placeholder( bias_node, ) = consumer_node.all_input_nodes - input_node_scale, _ = get_quant_node_args(input_node) - weight_node_scale, _ = get_quant_node_args(weight_node) + input_node_scale = get_quant_node_args(input_node).scale + weight_node_scale = get_quant_node_args(weight_node).scale bias_scales = input_node_scale * weight_node_scale parameter_values_quantized = ( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index 9e04ba68eef..25fba250395 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -6,8 +6,10 @@ # Utiliy functions for TOSA quantized lowerings import math +from typing import NamedTuple import serializer.tosa_serializer as ts +import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp, TosaSerializerTensor @@ -17,7 +19,14 @@ dq_q_ops = [q_op, dq_op] -def is_quant_node(node): +class QuantArgs(NamedTuple): + scale: float + zp: int + qmin: int + qmax: int + + +def is_quant_node(node: torch.fx.Node): consumer_node = list(node.users)[0] input = node.all_input_nodes[0] @@ -41,10 +50,22 @@ def is_quant_arg(arg): return consumer_node.target == q_op -def get_quant_node_args(node): +def get_quant_node_args(node: torch.fx.Node): + """ + Get the quantization parameters from a quant node. + + Args: + node: The quant node. + Returns: + QuantArgs: scale, zp, qmin, qmax + """ quant_args = [TosaArg(arg) for arg in node.args] - # Return the scale and zp - return quant_args[1].number, quant_args[2].number + return QuantArgs( + quant_args[1].number, + quant_args[2].number, + quant_args[3].number, + quant_args[4].number, + ) # Check if scale32 mode is used for given output element type