Skip to content

Commit 8fcd2b2

Browse files
committed
Move ReplaceScalarWithTensorArgPass to transforms
The pass is general and can be used by multiple backends. Use it in Arm backend and make small adjustments to make it work. Signed-off-by: Erik Lundell <[email protected]> Change-Id: I61863d9cefb1753c604d67a6e44845af46ef7c60
1 parent 5e4d6b6 commit 8fcd2b2

File tree

7 files changed

+123
-65
lines changed

7 files changed

+123
-65
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@
7777
)
7878
from executorch.backends.arm.tosa_specification import TosaSpecification
7979

80+
from executorch.backends.transforms.replace_scalar_with_tensor import (
81+
ReplaceScalarWithTensorArgPass,
82+
)
8083
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
8184
from executorch.exir import ExportedProgram
8285
from executorch.exir.pass_manager import PassManager
@@ -102,6 +105,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
102105
self.add_pass(ConvertMeanDimToAveragePoolPass())
103106
self.add_pass(ConvertFullLikeToFullPass())
104107

108+
self.add_pass(ReplaceScalarWithTensorArgPass())
105109
self.add_pass(AnnotateDecomposedMatmulPass())
106110
self.add_pass(QuantizeOperatorArguments())
107111
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -125,7 +129,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
125129
return self._transform(exported_program.graph_module)
126130

127131
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
128-
132+
self.add_pass(ReplaceScalarWithTensorArgPass())
129133
self.add_pass(FuseQuantizedActivationPass())
130134
self.add_pass(RemoveGetItemPass())
131135
self.add_pass(ConvertSplitToSlicePass())
@@ -176,6 +180,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
176180

177181
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
178182
self.add_pass(ScalarsToAttributePass())
183+
self.add_pass(ReplaceScalarWithTensorArgPass())
179184
self.add_pass(DecomposeLayerNormPass())
180185
self.add_pass(DecomposeVarPass())
181186
self.add_pass(DecomposeMeanDimPass())

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
111111
exir_ops.edge.aten.le.Tensor,
112112
exir_ops.edge.aten.lt.Tensor,
113113
exir_ops.edge.aten.mul.Tensor,
114+
exir_ops.edge.aten.add.Scalar,
115+
exir_ops.edge.aten.sub.Scalar,
116+
exir_ops.edge.aten.mul.Scalar,
117+
exir_ops.edge.aten.div.Scalar,
114118
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
115119
exir_ops.edge.aten.native_layer_norm.default,
116120
exir_ops.edge.aten.sigmoid.default,

backends/arm/test/models/test_conformer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,14 @@ class TestConformer(unittest.TestCase):
3232
ops_after_partitioner = {
3333
"executorch_exir_dialects_edge__ops_aten_arange_start_step": 1,
3434
"executorch_exir_dialects_edge__ops_aten_max_default": 1,
35-
"executorch_exir_dialects_edge__ops_aten_mul_Scalar": 4,
3635
"executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2,
3736
"executorch_exir_dialects_edge__ops_aten_where_self": 4,
3837
"executorch_exir_dialects_edge__ops_aten_logical_not_default": 4,
3938
"executorch_exir_dialects_edge__ops_aten_any_dim": 2,
4039
"torch.ops.aten._assert_scalar.default": 10,
4140
"torch.ops.aten._local_scalar_dense.default": 1,
4241
"torch.ops.aten.scalar_tensor.default": 2,
43-
"torch.ops.higher_order.executorch_call_delegate": 5,
42+
"torch.ops.higher_order.executorch_call_delegate": 4,
4443
}
4544

4645
dim = 16

backends/arm/test/ops/test_scalars.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
import unittest
27

38
import torch
@@ -50,6 +55,22 @@ class Mul(torch.nn.Module):
5055
def forward(self, x, y):
5156
return x * y
5257

58+
class MulScalar(torch.nn.Module):
59+
def forward(self, x, y):
60+
return torch.ops.aten.mul.Scalar(x, y)
61+
62+
class DivScalar(torch.nn.Module):
63+
def forward(self, x, y):
64+
return torch.ops.aten.div.Scalar(x, y)
65+
66+
class AddScalar(torch.nn.Module):
67+
def forward(self, x, y):
68+
return torch.ops.aten.add.Scalar(x, y)
69+
70+
class SubScalar(torch.nn.Module):
71+
def forward(self, x, y):
72+
return torch.ops.aten.sub.Scalar(x, y)
73+
5374
class AddInplace(torch.nn.Module):
5475
def forward(self, x, y):
5576
x += y
@@ -91,6 +112,10 @@ def forward(self, x):
91112
("Sub_", SubInplace()),
92113
("Mul_", MulInplace()),
93114
("Div_", DivInplace()),
115+
("MulScalar", MulScalar()),
116+
("DivScalar", DivScalar()),
117+
("AddScalar", AddScalar()),
118+
("SubScalar", SubScalar()),
94119
]
95120

96121
const_ops = [("Add", AddConst())]
@@ -108,8 +133,8 @@ def forward(self, x):
108133
scalar = dtype[1]
109134
tensor_scalar_tests.append((test_name + "_ts", op[1], tensor, scalar))
110135

111-
# Don't add (scalar, tensor) test case for inplace ops.
112-
if op[0][-1] == "_":
136+
# Don't add (scalar, tensor) test case for inplace and .Scalar ops.
137+
if op[0][-1] == "_" or op[0][-6:] == "Scalar":
113138
continue
114139

115140
# sub(scalar, tensor) does not work in any case.

backends/arm/tosa_mapping.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# the standardised TOSA representation.
1212
#
1313

14+
from typing import Sequence
15+
1416
import serializer.tosa_serializer as ts # type: ignore
1517
import torch
1618

@@ -95,7 +97,7 @@ def __init__(self, argument) -> None:
9597
if isinstance(argument, torch.fx.node.Node):
9698
self.__process_node(argument)
9799
return
98-
if isinstance(argument, list):
100+
if isinstance(argument, Sequence):
99101
self.__process_list(argument)
100102
return
101103
if isinstance(argument, int):

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -37,6 +38,9 @@
3738
)
3839
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
3940
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
41+
from executorch.backends.transforms.replace_scalar_with_tensor import (
42+
ReplaceScalarWithTensorArgPass,
43+
)
4044
from executorch.exir.dialects._ops import ops as exir_ops
4145
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
4246
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
@@ -1713,65 +1717,9 @@ def call_operator(self, op, args, kwargs, meta):
17131717
)
17141718

17151719

1716-
@register_cadence_pass(CadencePassAttribute(opt_level=0))
1717-
class ReplaceScalarWithTensorArgPass(ExportPass):
1718-
"""
1719-
For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar,
1720-
replace the scalar arg with Tensor arg.
1721-
"""
1722-
1723-
scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
1724-
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
1725-
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
1726-
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
1727-
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
1728-
}
1729-
1730-
def get_replacement(self, op, args, kwargs, meta):
1731-
return super().call_operator(
1732-
# Replace with .Tensor variant.
1733-
op=self.scalar_to_tensor_ops[op],
1734-
args=(
1735-
# Tensor arg.
1736-
args[0],
1737-
# Scalar arg - replace with aten.full tensor.
1738-
super().call_operator(
1739-
exir_ops.edge.aten.full.default,
1740-
args=(
1741-
(1,),
1742-
args[1],
1743-
),
1744-
kwargs={"dtype": args[0].to_tensor().dtype},
1745-
meta=meta,
1746-
),
1747-
# Other args.
1748-
*args[2:],
1749-
),
1750-
kwargs=kwargs,
1751-
meta=meta,
1752-
)
1753-
1754-
def call_operator(self, op, args, kwargs, meta):
1755-
if op not in self.scalar_to_tensor_ops:
1756-
return super().call_operator(op, args, kwargs, meta)
1757-
1758-
# There must be exactly 2 args (3 for add and sub containing alpha)
1759-
assert len(args) == 2 or len(args) == 3
1760-
1761-
# If there are two args, just replace the op.
1762-
if len(args) == 2:
1763-
return self.get_replacement(op, args, kwargs, meta)
1764-
1765-
# In case the op has three args, it must be scalar add/sub op.
1766-
if (
1767-
op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
1768-
or "alpha" in kwargs
1769-
):
1770-
return super().call_operator(op, args, kwargs, meta)
1771-
1772-
return self.get_replacement(op, args, kwargs, meta)
1773-
1774-
1720+
@register_cadence_pass(CadencePassAttribute(opt_level=0))(
1721+
ReplaceScalarWithTensorArgPass()
1722+
)
17751723
@register_cadence_pass(CadencePassAttribute(opt_level=0))
17761724
class ReplaceScalarTensorWithFullPass(ExportPass):
17771725
"""
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from typing import Dict
9+
10+
import torch
11+
from executorch.exir.dialects._ops import ops as exir_ops
12+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
13+
from executorch.exir.pass_base import ExportPass
14+
15+
16+
class ReplaceScalarWithTensorArgPass(ExportPass):
17+
"""
18+
For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar,
19+
replace the scalar arg with Tensor arg.
20+
"""
21+
22+
scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = {
23+
exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor,
24+
exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor,
25+
exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor,
26+
exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor,
27+
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
28+
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
29+
torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor,
30+
torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor,
31+
}
32+
33+
def get_replacement(self, op, args, kwargs, meta):
34+
return super().call_operator(
35+
# Replace with .Tensor variant.
36+
op=self.scalar_to_tensor_ops[op],
37+
args=(
38+
# Tensor arg.
39+
args[0],
40+
# Scalar arg - replace with aten.full tensor.
41+
super().call_operator(
42+
exir_ops.edge.aten.full.default,
43+
args=(
44+
(1,),
45+
args[1],
46+
),
47+
kwargs={"dtype": args[0].to_tensor().dtype},
48+
meta=meta,
49+
),
50+
# Other args.
51+
*args[2:],
52+
),
53+
kwargs=kwargs,
54+
meta=meta,
55+
)
56+
57+
def call_operator(self, op, args, kwargs, meta):
58+
if op not in self.scalar_to_tensor_ops:
59+
return super().call_operator(op, args, kwargs, meta)
60+
61+
# There must be exactly 2 args (3 for add and sub containing alpha)
62+
assert len(args) == 2 or len(args) == 3
63+
64+
# If there are two args, just replace the op.
65+
if len(args) == 2:
66+
return self.get_replacement(op, args, kwargs, meta)
67+
68+
# In case the op has three args, it must be scalar add/sub op.
69+
if (
70+
op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar}
71+
or "alpha" in kwargs
72+
):
73+
return super().call_operator(op, args, kwargs, meta)
74+
75+
return self.get_replacement(op, args, kwargs, meta)

0 commit comments

Comments
 (0)