Skip to content

Commit a3ff326

Browse files
committed
Revert "Arm backend: support mean.default (#15363)"
This reverts commit 9075855.
1 parent 008a014 commit a3ff326

File tree

4 files changed

+8
-57
lines changed

4 files changed

+8
-57
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919

2020

2121
def get_meandim_decomposition(op) -> tuple:
22-
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
22+
if op == exir_ops.edge.aten.mean.dim:
2323
return (
2424
exir_ops.edge.aten.sum.dim_IntList,
2525
exir_ops.edge.aten.full.default,
2626
exir_ops.edge.aten.mul.Tensor,
2727
)
28-
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
28+
if op == torch.ops.aten.mean.dim:
2929
return (
3030
torch.ops.aten.sum.dim_IntList,
3131
torch.ops.aten.full.default,
@@ -35,17 +35,17 @@ def get_meandim_decomposition(op) -> tuple:
3535

3636

3737
def get_avgpool(op):
38-
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
38+
if op == exir_ops.edge.aten.mean.dim:
3939
return exir_ops.edge.aten.avg_pool2d.default
40-
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
40+
if op == torch.ops.aten.mean.dim:
4141
return torch.ops.aten.avg_pool2d.default
4242
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
4343

4444

4545
def get_view(op):
46-
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
46+
if op == exir_ops.edge.aten.mean.dim:
4747
return exir_ops.edge.aten.view_copy.default
48-
if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
48+
if op == torch.ops.aten.mean.dim:
4949
return torch.ops.aten.view_copy.default
5050
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5151

@@ -87,18 +87,13 @@ def __init__(self, graph_module, tosa_spec):
8787
)
8888

8989
def call_operator(self, op, args, kwargs, meta):
90-
if op not in (
91-
exir_ops.edge.aten.mean.dim,
92-
torch.ops.aten.mean.dim,
93-
exir_ops.edge.aten.mean.default,
94-
torch.ops.aten.mean.default,
95-
):
90+
if op not in (exir_ops.edge.aten.mean.dim, torch.ops.aten.mean.dim):
9691
return super().call_operator(op, args, kwargs, meta)
9792

9893
x = get_node_arg(args, 0)
9994
input_shape = list(x.data.shape)
10095
output_shape = list(meta["val"].shape)
101-
dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
96+
dims_to_reduce = get_node_arg(args, 1)
10297
if dims_to_reduce is None:
10398
dims_to_reduce = range(len(input_shape))
10499
dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]

backends/arm/operator_support/tosa_profile_supported_op_lists.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@
178178
exir_ops.edge.aten.native_group_norm.default,
179179
exir_ops.edge.aten.sigmoid.default,
180180
exir_ops.edge.aten.mean.dim,
181-
exir_ops.edge.aten.mean.default,
182181
exir_ops.edge.aten.mm.default,
183182
exir_ops.edge.aten.minimum.default,
184183
exir_ops.edge.aten.maximum.default,

backends/arm/scripts/parse_test_names.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"hardswish.default",
1515
"linear.default",
1616
"maximum.default",
17-
"mean.default",
1817
"multihead_attention.default",
1918
"adaptive_avg_pool2d.default",
2019
"bitwise_right_shift.Tensor",

backends/arm/test/ops/test_mean_dim.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
7-
from typing import Callable
8-
97
import torch
108
from executorch.backends.arm.test import common
119
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -346,43 +344,3 @@ def test_mean_dim_vgf_INT(test_data):
346344
tosa_version="TOSA-1.0+INT",
347345
)
348346
pipeline.run()
349-
350-
351-
mean_input_t = tuple[torch.Tensor, bool]
352-
353-
354-
class MeanDefault(torch.nn.Module):
355-
def forward(self, tensor: torch.Tensor, keepdim: bool):
356-
return tensor.mean()
357-
358-
test_data_suite: dict[str, Callable[[], mean_input_t]] = {
359-
"rank1": lambda: (
360-
torch.rand(
361-
1,
362-
),
363-
False,
364-
),
365-
"rank2": lambda: (torch.rand(5, 5), True),
366-
"rank4": lambda: (torch.rand(5, 1, 10, 1), False),
367-
}
368-
369-
370-
@common.parametrize("test_data", MeanDefault.test_data_suite)
371-
def test_mean_tosa_FP(test_data):
372-
pipeline = TosaPipelineFP[mean_input_t](
373-
MeanDefault(),
374-
test_data(),
375-
[], # Might be sum, avgpool, or both
376-
)
377-
pipeline.run()
378-
379-
380-
@common.parametrize("test_data", MeanDefault.test_data_suite)
381-
def test_mean_tosa_INT(test_data):
382-
pipeline = TosaPipelineINT[mean_input_t](
383-
MeanDefault(),
384-
test_data(),
385-
[], # Might be sum, avgpool, or both
386-
symmetric_io_quantization=True,
387-
)
388-
pipeline.run()

0 commit comments

Comments
 (0)