5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -15,6 +15,9 @@
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_any_default_dim_dims_pass import (
ConvertAnyDefaultDimDimsPass,
)
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
ConvertExpandCopyToRepeatPass,
)
@@ -110,6 +113,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
self.add_pass(ConvertAnyDefaultDimDimsPass())

self.add_pass(ReplaceScalarWithTensorArgPass())
self.add_pass(AnnotateDecomposedMatmulPass())
@@ -155,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
self.add_pass(ConvertAnyDefaultDimDimsPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
106 changes: 106 additions & 0 deletions backends/arm/_passes/convert_any_default_dim_dims_pass.py
@@ -0,0 +1,106 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ( # type: ignore[import-not-found]
ops as exir_ops,
)
from executorch.exir.pass_base import ( # type: ignore[import-not-found]
ExportPass,
PassResult,
)


class ConvertAnyDefaultDimDimsPass(ExportPass):
"""
Converts any.default, any.dim, and any.dims to a sequence of any.dim ops by unrolling the multi-dimensional reduction.
Please refer to KeepDimsFalseToSqueezePass for an explanation of this conversion.

Example 1
Original:
any() # x.shape: [dim1, dim2, ..., dimn]
After pass:
any.dim(dim1, keepdim = True)
any.dim(dim2, keepdim = True)
...
any.dim(dimn, keepdim = True)
squeeze(dim = [dim1, dim2, ..., dimn])

Example 2
Original:
any.dim(dim1, keepdim = False)
After pass:
any.dim(dim1, keepdim = True)
squeeze(dim = [dim1])

Example 3
Original:
any.dims([dim1, dim2], keepdim = False)
After pass:
any.dim(dim1, keepdim = True)
any.dim(dim2, keepdim = True)
squeeze(dim = [dim1, dim2])
"""

def call(self, graph_module: torch.fx.GraphModule):
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target not in [
exir_ops.edge.aten.any.default,
exir_ops.edge.aten.any.dim,
exir_ops.edge.aten.any.dims,
]:
continue

if len(node.args) == 1:
# any.default(input)
input_node = (node.args)[0]
dims = range(len(input_node.meta["val"].shape))
keepdim = False
elif len(node.args) == 2:
# any.dim/dims(input, dims=dims)
input_node, dims = node.args
keepdim = False
elif len(node.args) == 3:
# any.dim/dims(input, dims=dims, keepdim=keepdim)
input_node, dims, keepdim = node.args
else:
raise RuntimeError(
f"Unexpected arg size {len(node.args)} in {node.name}"
)
try:
iter(dims)
except TypeError:
# A single int dim was passed; wrap it so the unrolling below can iterate.
dims = [dims]  # type: ignore[assignment]
else:
dims = list(dims) # type: ignore[assignment]

# Unroll multi-dimensional reduction and keep-dims arg
with graph_module.graph.inserting_before(node):
for dim in dims:
args = (input_node, dim, True)
input_node = graph_module.graph.create_node(
"call_function", exir_ops.edge.aten.any.dim, args, node.kwargs
)

if not keepdim:
args = (input_node, dims) # type: ignore[assignment]
input_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten.squeeze_copy.dims,
args,
)

node.replace_all_uses_with(input_node)
modified = True

if modified:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
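
Editor's note: the decomposition described in the class docstring above can be sanity-checked with plain PyTorch ops. The snippet below is an illustrative sketch, not part of this PR; it assumes a recent PyTorch 2.x where torch.any and torch.squeeze accept a tuple of dims, and shows that unrolling a multi-dim reduction into per-dim any calls with keepdim=True, followed by a squeeze over the same dims, matches the original reduction.

import torch

# Illustrative check of the unrolling (assumes PyTorch 2.x, where torch.any
# and torch.squeeze both accept a tuple of dims).
x = torch.rand(2, 3, 4) > 0.5
dims = (0, 2)

reference = torch.any(x, dim=dims, keepdim=False)

unrolled = x
for d in dims:
    unrolled = torch.any(unrolled, dim=d, keepdim=True)  # one any.dim per reduced axis
unrolled = torch.squeeze(unrolled, dim=dims)              # drop the kept singleton dims

assert torch.equal(reference, unrolled)
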
5 changes: 3 additions & 2 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
@@ -35,14 +35,15 @@ class KeepDimsFalseToSqueezePass(ExportPass):
"""

# CURRENTLY NOT HANDLED OPS
# exir_ops.edge.aten.any.dim,
# exir_ops.edge.aten.any.dims,
# exir_ops.edge.aten.argmax,
# exir_ops.edge.aten.argmin,
# exir_ops.edge.aten.prod.dim_int,

# HANDLED OPS
# exir_ops.edge.aten.sum.dim_IntList
# exir_ops.edge.aten.any.default (decomposed in convert_any_default_dim_dims_pass)
# exir_ops.edge.aten.any.dim (decomposed in convert_any_default_dim_dims_pass)
# exir_ops.edge.aten.any.dims (decomposed in convert_any_default_dim_dims_pass)
# exir_ops.edge.aten.max.dim (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.min.dim (decomposed in convert_minmax_pass)
# exir_ops.edge.aten.amin (decomposed in convert_minmax_pass)
6 changes: 6 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -112,6 +112,9 @@ def is_node_supported(
supported = node.op == "call_function" and node.target in [
exir_ops.edge.aten.abs.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.any.default,
exir_ops.edge.aten.any.dim,
exir_ops.edge.aten.any.dims,
exir_ops.edge.aten.logical_and.default,
exir_ops.edge.aten.logical_or.default,
exir_ops.edge.aten.logical_xor.default,
@@ -194,6 +197,9 @@ def is_node_supported(
) -> bool:
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
unsupported_ops = [
exir_ops.edge.aten.any.default,
exir_ops.edge.aten.any.dim,
exir_ops.edge.aten.any.dims,
exir_ops.edge.aten.bitwise_and.Tensor,
exir_ops.edge.aten.bitwise_or.Tensor,
exir_ops.edge.aten.bitwise_xor.Tensor,
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -11,6 +11,7 @@
op_add,
op_amax,
op_amin,
op_any,
op_avg_pool2d,
op_bmm,
op_cat,
53 changes: 53 additions & 0 deletions backends/arm/operators/op_any.py
@@ -0,0 +1,53 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import cast, List

import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
)

from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AnyVisitor(NodeVisitor):
target = "aten.any.dim"

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

if not (inputs[0].dtype == output.dtype):
raise ValueError(
"All inputs and outputs need the same dtype. "
f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
)
if not (inputs[0].dtype == ts.DType.BOOL):
raise ValueError(f"All inputs need to be BOOL. Got {inputs[0].dtype=}")

input_shape = list(inputs[0].shape)
dim = cast(int, inputs[1].number) % len(
input_shape
)  # normalize a possibly negative dim index
keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
if not keep_dim:
raise ValueError(
"keepdim=False should have been decomposed by ConvertAnyDefaultDimDimsPass"
)

attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(inputs[0].dim_order.index(dim))

tosa_graph.addOperator(
TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
)
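
Editor's note: in AnyVisitor above, the reduction axis is derived in two steps: the aten dim is normalized against the tensor rank (so negative indices resolve to a valid axis), and that logical dim is then looked up in the argument's dim_order to find the axis REDUCE_ANY operates on in the serialized layout. A minimal sketch of that mapping follows; the helper name and the example dim_order are assumptions for illustration, not backend APIs.

from typing import Sequence

def resolve_reduce_axis(dim: int, rank: int, dim_order: Sequence[int]) -> int:
    # Normalize a negative dim (e.g. -3 on a rank-4 tensor -> 1), then map the
    # logical dim to its position in the serialized dim order.
    dim = dim % rank
    return list(dim_order).index(dim)

# Example: reducing dim=-3 of a rank-4 tensor serialized in NHWC order
# (dim_order = (0, 2, 3, 1)) yields reduction axis 3.
assert resolve_reduce_axis(-3, 4, (0, 2, 3, 1)) == 3
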
3 changes: 1 addition & 2 deletions backends/arm/test/models/test_conformer.py
@@ -34,11 +34,10 @@ class TestConformer(unittest.TestCase):
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
"executorch_exir_dialects_edge__ops_aten_any_dim": 2,
"torch.ops.aten._assert_scalar.default": 10,
"torch.ops.aten._local_scalar_dense.default": 1,
"torch.ops.aten.scalar_tensor.default": 2,
"torch.ops.higher_order.executorch_call_delegate": 4,
"torch.ops.higher_order.executorch_call_delegate": 6,
}

dim = 16