From 2fbaa5c9ccc7ad7e9b2d07e685ee5cf7ced7d964 Mon Sep 17 00:00:00 2001
From: Yufeng Shi
Date: Mon, 3 Mar 2025 17:29:12 +0000
Subject: [PATCH] Arm backend: Add TOSA support for any.default, any.dim and any.dims

1. Implement a pass ConvertAnyDefaultDimDimsPass to decompose any.default,
   any.dim and any.dims into a sequence of any.dim with keepdim=True and a
   squeeze_copy.dims if needed
2. Implement a NodeVisitor to lower any.dim to REDUCE_ANY in TOSA
3. Fix the failures in https://github.com/pytorch/executorch/pull/9128

Change-Id: Ifb6672f2c017cd7365e76319795290a36909657c
Signed-off-by: Yufeng Shi
---
 backends/arm/_passes/arm_pass_manager.py     |   5 +
 .../convert_any_default_dim_dims_pass.py     | 106 ++++++++++
 .../keep_dims_false_to_squeeze_pass.py       |   5 +-
 .../tosa_supported_operators.py              |   6 +
 backends/arm/operators/__init__.py           |   1 +
 backends/arm/operators/op_any.py             |  53 +++++
 backends/arm/test/models/test_conformer.py   |   3 +-
 backends/arm/test/ops/test_any.py            | 185 ++++++++++++++++++
 8 files changed, 360 insertions(+), 4 deletions(-)
 create mode 100644 backends/arm/_passes/convert_any_default_dim_dims_pass.py
 create mode 100644 backends/arm/operators/op_any.py
 create mode 100644 backends/arm/test/ops/test_any.py

diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 26ff15db396..0d81cd50d8f 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -15,6 +15,9 @@
 )
 from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
 from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
+from executorch.backends.arm._passes.convert_any_default_dim_dims_pass import (
+    ConvertAnyDefaultDimDimsPass,
+)
 from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
     ConvertExpandCopyToRepeatPass,
 )
@@ -110,6 +113,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ConvertAnyDefaultDimDimsPass())
         self.add_pass(ReplaceScalarWithTensorArgPass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
@@ -155,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertFullLikeToFullPass())
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ConvertAnyDefaultDimDimsPass())
 
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
diff --git a/backends/arm/_passes/convert_any_default_dim_dims_pass.py b/backends/arm/_passes/convert_any_default_dim_dims_pass.py
new file mode 100644
index 00000000000..7085f17add0
--- /dev/null
+++ b/backends/arm/_passes/convert_any_default_dim_dims_pass.py
@@ -0,0 +1,106 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.exir.dialects._ops import (  # type: ignore[import-not-found]
+    ops as exir_ops,
+)
+from executorch.exir.pass_base import (  # type: ignore[import-not-found]
+    ExportPass,
+    PassResult,
+)
+
+
+class ConvertAnyDefaultDimDimsPass(ExportPass):
+    """
+    Converts any.default, any.dim and any.dims to a sequence of any.dim by unrolling multi-dimensional reduction.
+    Please refer to KeepDimsFalseToSqueezePass for an explanation of this conversion.
+
+    Example 1
+    Original:
+        any()  # x.shape: [dim1, dim2, ..., dimn]
+    After pass:
+        any.dim(dim1, keepdim = True)
+        any.dim(dim2, keepdim = True)
+        ...
+        any.dim(dimn, keepdim = True)
+        squeeze(dim = [dim1, dim2, ..., dimn])
+
+    Example 2
+    Original:
+        any.dim(dim1, keepdim = False)
+    After pass:
+        any.dim(dim1, keepdim = True)
+        squeeze(dim = [dim1])
+
+    Example 3
+    Original:
+        any.dims([dim1, dim2], keepdim = False)
+    After pass:
+        any.dim(dim1, keepdim = True)
+        any.dim(dim2, keepdim = True)
+        squeeze(dim = [dim1, dim2])
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target not in [
+                exir_ops.edge.aten.any.default,
+                exir_ops.edge.aten.any.dim,
+                exir_ops.edge.aten.any.dims,
+            ]:
+                continue
+
+            if len(node.args) == 1:
+                # any.default(input)
+                input_node = node.args[0]
+                dims = range(len(input_node.meta["val"].shape))
+                keepdim = False
+            elif len(node.args) == 2:
+                # any.dim/dims(input, dims=dims)
+                input_node, dims = node.args
+                keepdim = False
+            elif len(node.args) == 3:
+                # any.dim/dims(input, dims=dims, keepdim=keepdim)
+                input_node, dims, keepdim = node.args
+            else:
+                raise RuntimeError(
+                    f"Unexpected arg size {len(node.args)} in {node.name}"
+                )
+            try:
+                iter(dims)
+            except TypeError:
+                dims = [dims]  # type: ignore[assignment]
+            else:
+                dims = list(dims)  # type: ignore[assignment]
+
+            # Unroll multi-dimensional reduction and keep-dims arg
+            with graph_module.graph.inserting_before(node):
+                for dim in dims:
+                    args = (input_node, dim, True)
+                    input_node = graph_module.graph.create_node(
+                        "call_function", exir_ops.edge.aten.any.dim, args, node.kwargs
+                    )
+
+                if not keepdim:
+                    args = (input_node, dims)  # type: ignore[assignment]
+                    input_node = graph_module.graph.create_node(
+                        "call_function",
+                        exir_ops.edge.aten.squeeze_copy.dims,
+                        args,
+                    )
+
+            node.replace_all_uses_with(input_node)
+            modified = True
+
+        if modified:
+            graph_module.graph.eliminate_dead_code()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
index e3ed7e65a73..744436cba9e 100644
--- a/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
+++ b/backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
@@ -35,14 +35,15 @@ class KeepDimsFalseToSqueezePass(ExportPass):
     """
 
     # CURRENTLY NOT HANDLED OPS
-    # exir_ops.edge.aten.any.dim,
-    # exir_ops.edge.aten.any.dims,
     # exir_ops.edge.aten.argmax,
     # exir_ops.edge.aten.argmin,
     # exir_ops.edge.aten.prod.dim_int,
 
     # HANDLED OPS
     # exir_ops.edge.aten.sum.dim_IntList
+    # exir_ops.edge.aten.any.default (decomposed in convert_any_default_dim_dims_pass)
+    # exir_ops.edge.aten.any.dim (decomposed in convert_any_default_dim_dims_pass)
+    # exir_ops.edge.aten.any.dims (decomposed in convert_any_default_dim_dims_pass)
     # exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
     # exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
     # exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index 48bf07bc3a2..fee10fe74db 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -112,6 +112,9 @@ def is_node_supported(
"call_function" and node.target in [ exir_ops.edge.aten.abs.default, exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.any.default, + exir_ops.edge.aten.any.dim, + exir_ops.edge.aten.any.dims, exir_ops.edge.aten.logical_and.default, exir_ops.edge.aten.logical_or.default, exir_ops.edge.aten.logical_xor.default, @@ -194,6 +197,9 @@ def is_node_supported( ) -> bool: if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: unsupported_ops = [ + exir_ops.edge.aten.any.default, + exir_ops.edge.aten.any.dim, + exir_ops.edge.aten.any.dims, exir_ops.edge.aten.bitwise_and.Tensor, exir_ops.edge.aten.bitwise_or.Tensor, exir_ops.edge.aten.bitwise_xor.Tensor, diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index ad5a107f9da..81743f37b15 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -11,6 +11,7 @@ op_add, op_amax, op_amin, + op_any, op_avg_pool2d, op_bmm, op_cat, diff --git a/backends/arm/operators/op_any.py b/backends/arm/operators/op_any.py new file mode 100644 index 00000000000..ffb2e8a3c5d --- /dev/null +++ b/backends/arm/operators/op_any.py @@ -0,0 +1,53 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +from typing import cast, List + +import serializer.tosa_serializer as ts # type: ignore +from executorch.backends.arm.operators.node_visitor import ( # type: ignore + NodeVisitor, + register_node_visitor, +) + +from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore +from serializer.tosa_serializer import TosaOp +from torch.fx import Node + + +@register_node_visitor +class AnyVisitor(NodeVisitor): + target = "aten.any.dim" + + def define_node( + self, + node: Node, + tosa_graph: ts.TosaSerializer, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + + if not (inputs[0].dtype == output.dtype): + raise ValueError( + "All inputs and outputs need same dtype." + f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}." + ) + if not (inputs[0].dtype == ts.DType.BOOL): + raise ValueError("All inputs need to be BOOL." 
f"Got {inputs[0].dtype=}") + + input_shape = list(inputs[0].shape) + dim = cast(int, inputs[1].number) % len( + input_shape + ) # process the negative index + keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False) + if not keep_dim: + raise ValueError("This case should be handled by ConvertAnyDimDimsPass") + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(inputs[0].dim_order.index(dim)) + + tosa_graph.addOperator( + TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr + ) diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 0976b981f62..376285632ea 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -34,11 +34,10 @@ class TestConformer(unittest.TestCase): "executorch_exir_dialects_edge__ops_aten_max_default": 1, "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2, "executorch_exir_dialects_edge__ops_aten_where_self": 4, - "executorch_exir_dialects_edge__ops_aten_any_dim": 2, "torch.ops.aten._assert_scalar.default": 10, "torch.ops.aten._local_scalar_dense.default": 1, "torch.ops.aten.scalar_tensor.default": 2, - "torch.ops.higher_order.executorch_call_delegate": 4, + "torch.ops.higher_order.executorch_call_delegate": 6, } dim = 16 diff --git a/backends/arm/test/ops/test_any.py b/backends/arm/test/ops/test_any.py new file mode 100644 index 00000000000..d73ee1fda66 --- /dev/null +++ b/backends/arm/test/ops/test_any.py @@ -0,0 +1,185 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU85PipelineBI, + OpNotSupportedPipeline, + TosaPipelineBI, + TosaPipelineMI, +) + + +class AnyDim(torch.nn.Module): + aten_op = "torch.ops.aten.any.dim" + exir_op = "executorch_exir_dialects_edge__ops_aten_any_dim" + + def forward(self, x: torch.Tensor, dim: int, keepdim: bool): + return torch.any(x, dim=dim, keepdim=keepdim) + + +class AnyDims(torch.nn.Module): + aten_op = "torch.ops.aten.any.dims" + exir_op = "executorch_exir_dialects_edge__ops_aten_any_dims" + + def forward(self, x: torch.Tensor, dim: List[int], keepdim: bool): + return torch.any(x, dim=dim, keepdim=keepdim) + + +class AnyReduceAll(torch.nn.Module): + aten_op = "torch.ops.aten.any.default" + exir_op = "executorch_exir_dialects_edge__ops_aten_any_default" + + def forward(self, x: torch.Tensor): + return torch.any(x) + + +input_t1 = Tuple[torch.Tensor] # Input x + + +test_input: dict[input_t1] = { + "rank1": (torch.tensor([True, False, False], dtype=torch.bool), 0, True), + "rank1_squeeze": (torch.tensor([True, False, False], dtype=torch.bool), -1, False), + "rank2": ( + torch.randint(0, 2, (2, 3), dtype=torch.bool), + 0, + True, + ), + "rank2_squeeze": ( + torch.randint(0, 2, (2, 3), dtype=torch.bool), + 0, + False, + ), + "rank2_dims": ( + torch.randint(0, 2, (2, 3), dtype=torch.bool), + [0, 1], + True, + ), + "rank2_dims_squeeze": ( + torch.randint(0, 2, (2, 3), dtype=torch.bool), + [-2, 1], + False, + ), + "rank3_dims_squeeze": ( + torch.randint(0, 2, (6, 8, 10), dtype=torch.bool), + [1, 2], + False, + ), + "rank4": ( + torch.randint(0, 2, (1, 6, 8, 10), dtype=torch.bool), + 1, + True, + ), + "rank4_squeeze": ( + torch.randint(0, 2, (1, 6, 8, 10), dtype=torch.bool), + 1, + False, + ), + 
"rank4_dims": ( + torch.randint(0, 2, (1, 6, 8, 10), dtype=torch.bool), + [0, 2], + True, + ), + "rank4_dims_squeeze": ( + torch.randint(0, 2, (1, 6, 8, 10), dtype=torch.bool), + [1, -1], + False, + ), + "rank1_reduce_all": (torch.tensor([True, False, False], dtype=torch.bool),), + "rank2_reduce_all": (torch.randint(0, 2, (2, 3), dtype=torch.bool),), + "rank3_reduce_all": (torch.randint(0, 2, (6, 8, 10), dtype=torch.bool),), + "rank4_reduce_all": (torch.randint(0, 2, (1, 6, 8, 10), dtype=torch.bool),), +} + + +test_data = { + "any_rank1": (AnyDim(), test_input["rank1"]), + "any_rank1_squeeze": (AnyDim(), test_input["rank1_squeeze"]), + "any_rank2": (AnyDim(), test_input["rank2"]), + "any_rank2_squeeze": (AnyDim(), test_input["rank2_squeeze"]), + "any_rank2_dims": (AnyDims(), test_input["rank2_dims"]), + "any_rank2_dims_squeeze": (AnyDims(), test_input["rank2_dims_squeeze"]), + "any_rank3_dims_squeeze": (AnyDims(), test_input["rank3_dims_squeeze"]), + "any_rank4": (AnyDim(), test_input["rank4"]), + "any_rank4_squeeze": (AnyDim(), test_input["rank4_squeeze"]), + "any_rank4_dims": (AnyDims(), test_input["rank4_dims"]), + "any_rank4_dims_squeeze": (AnyDims(), test_input["rank4_dims_squeeze"]), + "any_rank1_reduce_all": (AnyReduceAll(), test_input["rank1_reduce_all"]), + "any_rank2_reduce_all": (AnyReduceAll(), test_input["rank2_reduce_all"]), + "any_rank3_reduce_all": (AnyReduceAll(), test_input["rank3_reduce_all"]), + "any_rank4_reduce_all": (AnyReduceAll(), test_input["rank4_reduce_all"]), +} + + +fvp_xfails = { + "any_rank1": "MLETORCH-706 Support ScalarType::Bool in EthosUBackend.", + "any_rank1_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank2": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank2_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank2_dims": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank2_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank3_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank4": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank4_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank4_dims": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank4_dims_squeeze": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank1_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank2_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank3_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", + "any_rank4_reduce_all": "MLETORCH-706: Support ScalarType::Bool in EthosUBackend.", +} + + +@common.parametrize("test_data", test_data) +def test_any_tosa_MI(test_data: input_t1): + op, test_input = test_data + pipeline = TosaPipelineMI[input_t1](op, test_input, op.aten_op, op.exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data) +def test_any_tosa_BI(test_data: input_t1): + op, test_input = test_data + pipeline = TosaPipelineBI[input_t1](op, test_input, op.aten_op, op.exir_op) + pipeline.pop_stage(pipeline.find_pos("quantize") + 1) + pipeline.pop_stage("quantize") + pipeline.run() + + +@common.parametrize("test_data", test_data) +def test_logical_u55_BI(test_data: input_t1): + # Tests that we don't delegate these ops since they are not supported on U55. 
+    op, test_input = test_data
+    pipeline = OpNotSupportedPipeline[input_t1](
+        op, test_input, "TOSA-0.80+BI+u55", {op.exir_op: 1}
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data)
+def test_any_u85_BI(test_data: input_t1):
+    op, test_input = test_data
+    pipeline = EthosU85PipelineBI[input_t1](
+        op, test_input, op.aten_op, op.exir_op, run_on_fvp=False
+    )
+    pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
+    pipeline.pop_stage("quantize")
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data, fvp_xfails)
+@common.SkipIfNoCorstone320
+def test_any_u85_BI_on_fvp(test_data: input_t1):
+    op, test_input = test_data
+    pipeline = EthosU85PipelineBI[input_t1](
+        op, test_input, op.aten_op, op.exir_op, run_on_fvp=True
+    )
+    pipeline.pop_stage(pipeline.find_pos("quantize") + 1)
+    pipeline.pop_stage("quantize")
+    pipeline.run()
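Reviewer note (not part of the patch): the decomposition performed by ConvertAnyDefaultDimDimsPass can be sanity-checked with a small standalone snippet. This is an illustrative sketch only; it assumes a recent PyTorch where torch.any accepts a list of dims and Tensor.squeeze accepts a tuple of dims, and the tensor shape is an arbitrary example.

    import torch

    x = torch.randint(0, 2, (2, 3, 4), dtype=torch.bool)

    # any over dims [0, 2] with keepdim=False ...
    expected = torch.any(x, dim=[0, 2], keepdim=False)

    # ... behaves like a chain of single-dim any with keepdim=True,
    # followed by squeezing the reduced dims (the squeeze_copy.dims step).
    decomposed = torch.any(torch.any(x, dim=0, keepdim=True), dim=2, keepdim=True)
    decomposed = decomposed.squeeze(dim=(0, 2))

    assert torch.equal(expected, decomposed)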