From 243bb6c32e4336bb9d41911485b5748fe3186346 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 5 Nov 2025 09:48:40 +0100 Subject: [PATCH] Arm backend: Improve dtype validation Improve dtype validation in NodeVisitors. Signed-off-by: Oscar Andersson Change-Id: Ieb9ced1ae8d2db916e6c8bc0b45773a640d330db --- backends/arm/operators/op_avg_pool2d.py | 8 +++- backends/arm/operators/op_cat.py | 19 ++++++++- backends/arm/operators/op_clamp.py | 8 +++- backends/arm/operators/op_eq.py | 2 +- backends/arm/operators/op_ge.py | 2 +- backends/arm/operators/op_gt.py | 2 +- backends/arm/operators/op_index_select.py | 17 ++++++-- backends/arm/operators/op_le.py | 2 +- backends/arm/operators/op_lt.py | 2 +- backends/arm/operators/op_max_pool2d.py | 8 +++- backends/arm/operators/op_mul.py | 2 +- backends/arm/operators/op_permute.py | 8 +++- backends/arm/operators/op_repeat.py | 8 +++- backends/arm/operators/op_slice.py | 8 +++- backends/arm/operators/op_sum.py | 7 ++++ backends/arm/operators/op_tosa_matmul.py | 15 ++++++- backends/arm/operators/op_tosa_resize.py | 43 ++++++++++++--------- backends/arm/operators/op_tosa_table.py | 18 ++++++--- backends/arm/operators/op_tosa_transpose.py | 4 +- backends/arm/operators/op_where.py | 5 +-- backends/arm/operators/ops_identity.py | 20 ++++++++++ backends/arm/test/ops/test_expand.py | 7 +++- backends/arm/test/ops/test_permute.py | 17 +++++--- backends/arm/test/ops/test_repeat.py | 8 +++- 24 files changed, 185 insertions(+), 55 deletions(-) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 693f3f1155a..96d2a2c984f 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -25,6 +25,7 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 @register_node_visitor @@ -115,10 +116,15 @@ def define_node( ) -> None: 
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], + supported_dtypes, output.tosa_spec, ) diff --git a/backends/arm/operators/op_cat.py b/backends/arm/operators/op_cat.py index 2cfa4720c3c..cc39c24fba4 100644 --- a/backends/arm/operators/op_cat.py +++ b/backends/arm/operators/op_cat.py @@ -14,8 +14,11 @@ ) from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 from torch.fx import Node @@ -35,9 +38,21 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_dtypes.append(ts.DType.INT16) validate_num_inputs(self.target, inputs, [1, 2]) + input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special] + validate_same_dtype(self.target, [*input_tosa_args, output], ts) + validate_valid_dtype( + self.target, + [*input_tosa_args, output], + supported_dtypes, + output.tosa_spec, + ) - tensors = inputs[0].special dim = 0 if len(inputs) < 2 else inputs[1].number rank = len(output.shape) dim = (dim + rank) % rank @@ -50,7 +65,7 @@ def define_node( node, tosa_graph, ts.Op.CONCAT, - [tensor.name for tensor in tensors], + [tensor.name for tensor in input_tosa_args], [output.name], attr, ) diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 
76aa75cd9fd..ab9af7a6ce2 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -22,6 +22,7 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 from torch.fx import Node @@ -88,10 +89,15 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32], + supported_dtypes, output.tosa_spec, ) diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 7cfd497b1fe..bd72c9491ca 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 5d6eeb75275..754778487e9 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 92879d549b1..2a483f735a7 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -47,7 
+47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_index_select.py b/backends/arm/operators/op_index_select.py index 5b73b5e91ae..a4f541e65d9 100644 --- a/backends/arm/operators/op_index_select.py +++ b/backends/arm/operators/op_index_select.py @@ -12,6 +12,11 @@ NodeVisitor, register_node_visitor, ) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.utils import build_reshape_tosa_1_0 @@ -45,10 +50,16 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: - if len(inputs) != 3: - raise ValueError(f"Number of inputs are not 3: {len(inputs)}") + validate_num_inputs(self.target, inputs, 3) + validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, + [inputs[0], output], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + output.tosa_spec, + ) - weights, index, indices = inputs + weights, _, indices = inputs if len(weights.shape) == 2: weights_new_shape = [1, weights.shape[0], weights.shape[1]] diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index 2b1a023d624..aa6b52b9982 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index 4f3e1163c69..4b2b1a1960b 
100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 5690b82d97b..e62a21c7e34 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -21,6 +21,7 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 @register_node_visitor @@ -44,10 +45,15 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, [3, 4, 5, 6]) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + if isinstance( + output.tosa_spec, Tosa_1_00 + ) and output.tosa_spec.support_extension("int16"): + supported_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], + supported_dtypes, output.tosa_spec, ) diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index f1cd5de6fd6..0e10443e523 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -44,7 +44,7 @@ def define_node( validate_valid_dtype( self.target, [*inputs, output], - [ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 80ccfae04e6..fea0aea9298 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -116,7 +116,13 @@ def define_node( validate_valid_dtype( 
self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) diff --git a/backends/arm/operators/op_repeat.py b/backends/arm/operators/op_repeat.py index 99c0ecce0b2..e44fede736d 100644 --- a/backends/arm/operators/op_repeat.py +++ b/backends/arm/operators/op_repeat.py @@ -43,7 +43,13 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.INT16, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py index 7366703083c..21c86e5f7c4 100644 --- a/backends/arm/operators/op_slice.py +++ b/backends/arm/operators/op_slice.py @@ -73,7 +73,13 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP32, + ], output.tosa_spec, ) diff --git a/backends/arm/operators/op_sum.py b/backends/arm/operators/op_sum.py index 91c37e25f43..e956359736c 100644 --- a/backends/arm/operators/op_sum.py +++ b/backends/arm/operators/op_sum.py @@ -15,6 +15,7 @@ from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg @@ -39,6 +40,12 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, 3) validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, + [inputs[0], output], + [ts.DType.INT32, ts.DType.FP32], + output.tosa_spec, + ) tensor = inputs[0] input_shape = list(tensor.shape) diff --git 
a/backends/arm/operators/op_tosa_matmul.py b/backends/arm/operators/op_tosa_matmul.py index be73a60f7c7..abcb28ad05b 100644 --- a/backends/arm/operators/op_tosa_matmul.py +++ b/backends/arm/operators/op_tosa_matmul.py @@ -25,6 +25,7 @@ ) from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 @register_node_visitor @@ -51,16 +52,26 @@ def define_node( """Define the TOSA ``MATMUL`` operator.""" validate_num_inputs(self.target, inputs, 2) validate_same_dtype(self.target, [*inputs], ts) + supported_input_dtypes = [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_input_dtypes.append(ts.DType.INT16) validate_valid_dtype( self.target, [*inputs], - [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], + supported_input_dtypes, output.tosa_spec, ) + supported_output_dtypes = [ts.DType.INT32, ts.DType.FP32] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_output_dtypes.append(ts.DType.INT48) validate_valid_dtype( self.target, [output], - [ts.DType.INT32, ts.DType.INT48, ts.DType.FP32], + supported_output_dtypes, output.tosa_spec, ) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index 6e6edf4fd41..de5df983789 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -19,6 +19,7 @@ validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 from executorch.backends.arm.tosa.utils import get_resize_parameters @@ -39,33 +40,37 @@ def define_node( output: TosaArg, ) -> None: validate_num_inputs(self.target, inputs, [3, 4]) + supported_input_dtypes = [ts.DType.INT8, ts.DType.FP32] + if isinstance(self.tosa_spec, 
Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_input_dtypes.append(ts.DType.INT16) + validate_valid_dtype( + self.target, + [inputs[0]], + supported_input_dtypes, + output.tosa_spec, + ) + supported_output_dtypes = [ts.DType.FP32] if node.kwargs.get("resize_mode") == "bilinear": resize_mode = ts.ResizeMode.BILINEAR align_corners = bool(node.args[2]) + supported_output_dtypes.append(ts.DType.INT32) + if isinstance( + self.tosa_spec, Tosa_1_00 + ) and self.tosa_spec.support_extension("int16"): + supported_output_dtypes.append(ts.DType.INT48) else: resize_mode = ts.ResizeMode.NEAREST align_corners = False validate_same_dtype(self.target, [inputs[0], output], ts) - - valid_dtypes = [] - if self.tosa_spec.support_integer(): - valid_dtypes.extend( - [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.INT48] - ) - - if self.tosa_spec.support_float(): - valid_dtypes.extend( - [ - ts.DType.FP16, - ts.DType.FP32, - ] - ) - + supported_output_dtypes.append(ts.DType.INT8) + if isinstance( + self.tosa_spec, Tosa_1_00 + ) and self.tosa_spec.support_extension("int16"): + supported_output_dtypes.append(ts.DType.INT16) validate_valid_dtype( - self.target, - [inputs[0], output], - valid_dtypes, - output.tosa_spec, + self.target, [output], supported_output_dtypes, output.tosa_spec ) # tosa_shape output is NHWC, take HW input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ diff --git a/backends/arm/operators/op_tosa_table.py b/backends/arm/operators/op_tosa_table.py index d867b5efd7b..0235f61b14a 100644 --- a/backends/arm/operators/op_tosa_table.py +++ b/backends/arm/operators/op_tosa_table.py @@ -20,6 +20,7 @@ from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 @register_node_visitor @@ -36,13 +37,20 @@ def define_node( output: TosaArg, ) -> None: validate_num_inputs(self.target, inputs, 2) + 
supported_input_dtypes = [ts.DType.INT8] + supported_output_dtypes = [ts.DType.INT8] + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + supported_input_dtypes.append(ts.DType.INT16) + supported_output_dtypes.append(ts.DType.INT32) + + validate_valid_dtype( + self.target, inputs, supported_input_dtypes, output.tosa_spec + ) validate_valid_dtype( - self.target, inputs, [ts.DType.INT8, ts.DType.INT16], output.tosa_spec + self.target, output, supported_output_dtypes, output.tosa_spec ) - if inputs[0].dtype == ts.DType.INT8: - validate_valid_dtype(self.target, output, ts.DType.INT8, output.tosa_spec) - if inputs[0].dtype == ts.DType.INT16: - validate_valid_dtype(self.target, output, ts.DType.INT32, output.tosa_spec) # The name of the table constant is a bit complex. # The name of the pytorch buffer will be the target of last node argument. diff --git a/backends/arm/operators/op_tosa_transpose.py b/backends/arm/operators/op_tosa_transpose.py index bbd9252f8f8..c5aa66a85fd 100644 --- a/backends/arm/operators/op_tosa_transpose.py +++ b/backends/arm/operators/op_tosa_transpose.py @@ -47,12 +47,12 @@ def define_node( self.target, [inputs[0], output], [ + ts.DType.BOOL, ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, - ts.DType.FP32, - ts.DType.BOOL, ts.DType.FP16, + ts.DType.FP32, ], output.tosa_spec, ) diff --git a/backends/arm/operators/op_where.py b/backends/arm/operators/op_where.py index 53ce2e0fc22..f0b6538ac27 100644 --- a/backends/arm/operators/op_where.py +++ b/backends/arm/operators/op_where.py @@ -42,16 +42,15 @@ def define_node( output: TosaArg, ) -> None: - supported_dtypes = [] + supported_dtypes = [ts.DType.BOOL] if output.tosa_spec.support_integer(): supported_dtypes += [ - ts.DType.BOOL, ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ] if output.tosa_spec.support_float(): - supported_dtypes += [ts.DType.BOOL, ts.DType.FP16, ts.DType.FP32] + supported_dtypes += [ts.DType.FP16, ts.DType.FP32] 
validate_num_inputs(self.target, inputs, 3) # Not first input, which is condition tensor. diff --git a/backends/arm/operators/ops_identity.py b/backends/arm/operators/ops_identity.py index c4a6d78fef4..153b611f0af 100644 --- a/backends/arm/operators/ops_identity.py +++ b/backends/arm/operators/ops_identity.py @@ -18,8 +18,10 @@ from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, validate_same_dtype, + validate_valid_dtype, ) from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.specification import Tosa_1_00 def identity_operator_factory(identity_target: str): @@ -42,6 +44,24 @@ def define_node( ) -> None: validate_num_inputs(self.target, inputs, 1) validate_same_dtype(self.target, [inputs[0], output], ts) + supported_dtypes = [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ] + if output.tosa_spec.support_float(): + supported_dtypes += [ts.DType.FP32] + if isinstance( + self.tosa_spec, Tosa_1_00 + ) and self.tosa_spec.support_extension("int16"): + supported_dtypes += [ts.DType.INT48] + validate_valid_dtype( + self.target, + [inputs[0], output], + supported_dtypes, + output.tosa_spec, + ) # Simply add an identityOp attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/test/ops/test_expand.py b/backends/arm/test/ops/test_expand.py index 34694469bbf..f8d9fe132a9 100644 --- a/backends/arm/test/ops/test_expand.py +++ b/backends/arm/test/ops/test_expand.py @@ -31,6 +31,7 @@ class Expand(torch.nn.Module): # (input tensor, multiples) test_parameters = { + "randbool_1d": lambda: (torch.randint(0, 1, (1,), dtype=torch.bool), (5,)), "rand_1d_both": lambda: (torch.rand(1), (2,)), "rand_1d": lambda: (torch.randn(1), (2, 2, 4)), "rand_4d": lambda: (torch.randn(1, 1, 1, 5), (1, 4, -1, -1)), @@ -71,7 +72,11 @@ def test_expand_tosa_INT(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", Expand.test_parameters) +@common.parametrize( + "test_data", + 
Expand.test_parameters, + xfails={"randbool_1d": "Bool not supported on U55"}, +) @common.XfailIfNoCorstone300 def test_expand_u55_INT(test_data: Tuple): pipeline = EthosU55PipelineINT[input_t1]( diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 8938ebcc27e..a2b322d42cf 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -33,6 +33,7 @@ test_data_suite = { # (test_name,test_data,dims) "rank_2": lambda: (torch.rand(10, 10), [1, 0]), + "rank2_bool": lambda: (torch.randint(0, 2, (5, 5), dtype=torch.bool), [1, 0]), "rank_3": lambda: (torch.rand(10, 10, 10), [2, 0, 1]), "rank_3_2": lambda: (torch.rand(10, 10, 10), [1, 2, 0]), "rank_4": lambda: (torch.rand(1, 5, 1, 10), [0, 2, 3, 1]), @@ -80,7 +81,9 @@ def test_permute_tosa_INT(test_data: torch.Tensor): pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize( + "test_data", test_data_suite, xfails={"rank2_bool": "Bool not supported on U55"} +) @common.XfailIfNoCorstone300 def test_permute_u55_INT(test_data): test_data, dims = test_data() @@ -156,7 +159,7 @@ def get_symmetric_a16w8_permute_quantizer( @common.parametrize("test_data", test_data_suite) -def test_permute_int16_tosa_INT(test_data: torch.Tensor): +def test_permute_16a8w_tosa_INT(test_data: torch.Tensor): """Test permute operation with int16 quantization""" test_data, dims = test_data() pipeline = TosaPipelineINT[input_t1]( @@ -182,9 +185,13 @@ def test_permute_int16_tosa_INT(test_data: torch.Tensor): } -@common.parametrize("test_data", test_data_suite_exact) +@common.parametrize( + "test_data", + test_data_suite_exact, + xfails={"rank2_bool": "Bool not supported on U55"}, +) @common.XfailIfNoCorstone300 -def test_permute_int16_u55_INT16(test_data: torch.Tensor): +def test_permute_16a8w_u55_INT16(test_data: torch.Tensor): """Test permute operation with int16 quantization on U55""" test_data, dims = test_data() pipeline = 
EthosU55PipelineINT[input_t1]( @@ -208,7 +215,7 @@ def test_permute_int16_u55_INT16(test_data: torch.Tensor): @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 -def test_permute_int16_u85_INT16(test_data: torch.Tensor): +def test_permute_16a8w_u85_INT16(test_data: torch.Tensor): """Test permute operation with int16 quantization on U85""" test_data, dims = test_data() pipeline = EthosU85PipelineINT[input_t1]( diff --git a/backends/arm/test/ops/test_repeat.py b/backends/arm/test/ops/test_repeat.py index 56986a54781..2c7583ab77f 100644 --- a/backends/arm/test/ops/test_repeat.py +++ b/backends/arm/test/ops/test_repeat.py @@ -52,6 +52,10 @@ def forward(self, x: torch.Tensor): test_data_suite = { # test_name : lambda: (module, test_data) + "1_x_1_bool": lambda: ( + Repeat((2,)), + (torch.randint(0, 2, (3,), dtype=torch.bool),), + ), "1_x_1": lambda: (Repeat((2,)), (torch.randn(3),)), "2_x_2": lambda: (Repeat((2, 1)), (torch.randn(3, 4),)), "4_x_4": lambda: (Repeat((1, 2, 3, 4)), (torch.randn(1, 1, 2, 2),)), @@ -87,7 +91,9 @@ def test_repeat_tosa_INT(test_data: Tuple): pipeline.run() -@common.parametrize("test_data", test_data_suite) +@common.parametrize( + "test_data", test_data_suite, xfails={"1_x_1_bool": "Bool not supported on U55"} +) @common.XfailIfNoCorstone300 def test_repeat_u55_INT(test_data: Tuple): module, test_data = test_data()