Skip to content

Commit 28e5228

Browse files
Arm backend: Decompose sub/add with alpha!=1 (#14939)
This was previously not supported, causing crashes in quantization, and incorrect output in floating point. cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai Signed-off-by: Erik Lundell <[email protected]> Co-authored-by: Erik Lundell <[email protected]>
1 parent 1121cba commit 28e5228

File tree

5 files changed

+125
-1
lines changed

5 files changed

+125
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .convert_to_clamp import ConvertToClampPass # noqa
2828
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
2929
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
30+
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
3031
from .decompose_addmm_pass import DecomposeAddmmPass # noqa
3132
from .decompose_asin_and_acos_pass import DecomposeAsinAndAcosPass # noqa
3233
from .decompose_asinh_pass import DecomposeAsinhPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DecomposeAcoshPass,
3434
DecomposeAdaptiveAvgPool2dPass,
3535
DecomposeAddmmPass,
36+
DecomposeAddSubAlphaPass,
3637
DecomposeAsinAndAcosPass,
3738
DecomposeAsinhPass,
3839
DecomposeAtanhPass,
@@ -217,6 +218,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
217218
)
218219
self.add_pass(DecomposeNotEqualPass())
219220
self.add_pass(DecomposeDivPass())
221+
self.add_pass(DecomposeAddSubAlphaPass())
220222
self.add_pass(DecomposeSoftmaxPass())
221223
self.add_pass(DecomposeGeluPass())
222224
self.add_pass(ConvertFullLikeToFullPass())
@@ -286,6 +288,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
286288
self.add_pass(DecomposeSignPass())
287289
self.add_pass(DecomposeAddmmPass())
288290
self.add_pass(DecomposeDivTensorModePass())
291+
self.add_pass(DecomposeAddSubAlphaPass())
289292
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
290293
self.add_pass(ScalarsToAttributePass())
291294
self.add_pass(DecomposeGroupNormPass())
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import numbers
9+
from typing import Set, Type
10+
11+
import torch
12+
from executorch.backends.arm._passes import ArmPass
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass
15+
16+
17+
_ADD_OPS = (
18+
exir_ops.edge.aten.add.Tensor,
19+
torch.ops.aten.add.Tensor,
20+
)
21+
22+
_SUB_OPS = (
23+
exir_ops.edge.aten.sub.Tensor,
24+
torch.ops.aten.sub.Tensor,
25+
)
26+
27+
28+
def _get_ops(op):
29+
if op in _ADD_OPS:
30+
if op is exir_ops.edge.aten.add.Tensor:
31+
return (
32+
exir_ops.edge.aten.mul.Tensor,
33+
exir_ops.edge.aten.full.default,
34+
exir_ops.edge.aten.add.Tensor,
35+
)
36+
return (
37+
torch.ops.aten.mul.Tensor,
38+
torch.ops.aten.full.default,
39+
torch.ops.aten.add.Tensor,
40+
)
41+
if op in _SUB_OPS:
42+
if op is exir_ops.edge.aten.sub.Tensor:
43+
return (
44+
exir_ops.edge.aten.mul.Tensor,
45+
exir_ops.edge.aten.full.default,
46+
exir_ops.edge.aten.sub.Tensor,
47+
)
48+
return (
49+
torch.ops.aten.mul.Tensor,
50+
torch.ops.aten.full.default,
51+
torch.ops.aten.sub.Tensor,
52+
)
53+
raise RuntimeError(f"Unsupported operator {op}")
54+
55+
56+
def _should_decompose(alpha) -> bool:
57+
if isinstance(alpha, numbers.Number):
58+
return alpha != 1
59+
return False
60+
61+
62+
class DecomposeAddSubAlphaPass(ArmPass):
63+
"""Rewrite add/sub with alpha into a mul followed by add/sub."""
64+
65+
_passes_required_after: Set[Type[ExportPass]] = set()
66+
67+
def call_operator(self, op, args, kwargs, meta, updated: bool | None = False):
68+
if op not in _ADD_OPS + _SUB_OPS:
69+
return super().call_operator(op, args, kwargs, meta, updated)
70+
71+
alpha = kwargs.get("alpha", 1)
72+
if not _should_decompose(alpha):
73+
return super().call_operator(op, args, kwargs, meta, updated)
74+
75+
mul_op, full_op, binary_op = _get_ops(op)
76+
lhs, rhs = args
77+
78+
alpha_full = super().call_operator(
79+
full_op, ((1,), float(alpha)), {}, meta, updated=True
80+
)
81+
scaled_rhs = super().call_operator(
82+
mul_op,
83+
(rhs, alpha_full),
84+
{},
85+
meta,
86+
updated=True,
87+
)
88+
return super().call_operator(
89+
binary_op,
90+
(lhs, scaled_rhs),
91+
{},
92+
meta,
93+
updated=True,
94+
)

backends/arm/test/ops/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7878

7979
class Add3(torch.nn.Module):
8080
def forward(self, x: torch.Tensor, y: torch.Tensor):
81-
return x + y
81+
return torch.add(x, y, alpha=1.5)
8282

8383
test_data: list[input_t2] = {
8484
"3d_randn_diff_rank": lambda: (torch.randn(1, 4, 5), torch.randn(4, 1)),

backends/arm/test/ops/test_sub.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
7373
return x - y
7474

7575

76+
class SubAlpha(torch.nn.Module):
77+
def forward(self, x: torch.Tensor, y: torch.Tensor):
78+
return torch.sub(x, y, alpha=5)
79+
80+
7681
class SubTan(torch.nn.Module):
7782

7883
def forward(self, x: torch.Tensor, y: torch.Tensor):
@@ -109,6 +114,18 @@ def test_sub_tensor_tosa_FP_2(test_data: Tuple[torch.Tensor, torch.Tensor]):
109114
pipeline.run()
110115

111116

117+
@common.parametrize("test_data", sub_tan_test_data)
118+
def test_sub_tensor_tosa_FP_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
119+
"""Test Two-Operand Subtraction with alpha (TOSA FP)"""
120+
pipeline = TosaPipelineFP[input_t2](
121+
SubAlpha(),
122+
test_data(),
123+
aten_op,
124+
exir_op,
125+
)
126+
pipeline.run()
127+
128+
112129
@common.parametrize("test_data", sub_test_data)
113130
def test_sub_tensor_tosa_INT(test_data):
114131
"""Test Subtraction (TOSA INT)"""
@@ -132,6 +149,15 @@ def test_sub_tensor_tosa_INT_3(test_data: Tuple[torch.Tensor, torch.Tensor]):
132149
pipeline.run()
133150

134151

152+
@common.parametrize("test_data", sub_tan_test_data)
153+
def test_sub_tensor_tosa_INT_alpha(test_data: Tuple[torch.Tensor, torch.Tensor]):
154+
"""Test Two-Operand Subtraction with alpha (TOSA INT)"""
155+
pipeline = TosaPipelineINT[input_t2](
156+
SubAlpha(), test_data(), aten_op, exir_op, qtol=0
157+
)
158+
pipeline.run()
159+
160+
135161
@common.parametrize("test_data", sub_test_data)
136162
@common.XfailIfNoCorstone300
137163
def test_sub_tensor_u55_INT(test_data):

0 commit comments

Comments
 (0)