
Commit c4cd274

Arm backend: Move rescales from SUM visitor to pass (#15299)
In the SUM node visitor, an INT8->INT32 RESCALE node is inserted before the SUM node and an INT32->INT8 RESCALE node after it. This patch moves that insertion into `InsertRescaleInt32Pass`. Since SUM is decomposed, the RESCALE nodes must be inserted before `DecomposeSumPass` runs (it decomposes a SUM into a chain of single-dim SUMs). The ordering matters: it avoids redundant INT8/INT32 RESCALE nodes being inserted between the SUM nodes of the chain after decomposition. Only one INT8->INT32 RESCALE is needed before the chain and one INT32->INT8 RESCALE after it; between the SUM nodes in the chain, the edges already carry the correct INT32 data type.

### Test plan

Tests exercising the modified pass have been added to backends/arm/test/passes/test_insert_rescale_i32_pass.py.

Signed-off-by: Martin Lindström <[email protected]>
1 parent b87ef5d commit c4cd274
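A minimal sketch (not part of the commit) of the intended graph shape and of the equivalence that the decomposition relies on; the dims and shapes below are purely illustrative:

```python
import torch

# Intended quantized lowering with the new pass order (illustrative):
#
#   x (INT8) -> RESCALE(INT8->INT32) -> SUM(dims=[1, 3]) -> RESCALE(INT32->INT8)
#
# DecomposeSumPass then splits the SUM while the edges inside the chain stay INT32:
#
#   x (INT8) -> RESCALE(INT8->INT32) -> SUM(dim=1) -> SUM(dim=3) -> RESCALE(INT32->INT8)
#
# The chain of single-dim sums is numerically equivalent to the multi-dim sum:
x = torch.rand(1, 2, 3, 4)
full = torch.sum(x, dim=[1, 3], keepdim=True)
chained = torch.sum(torch.sum(x, dim=1, keepdim=True), dim=3, keepdim=True)
assert torch.allclose(full, chained)
```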

File tree

7 files changed, +63 -83 lines changed


backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -194,7 +194,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
         self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
-        self.add_pass(DecomposeSumPass())
         self.add_pass(DecomposeCumsumPass(exported_program))
         self.add_pass(Conv1dUnsqueezePass())
         self.add_pass(DecomposeMaxPool2DPass())
@@ -215,10 +214,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(RewriteMatmulPass())
         self.add_pass(RewriteUpsamplePass())
         self.add_pass(FuseEqualPlaceholdersPass(exported_program))
+        self.add_pass(InsertRescaleInt32Pass())
+        self.add_pass(DecomposeSumPass())
         self.add_pass(ToTosaMemoryFormatPass(exported_program))
         self.add_pass(RemoveNoopPass())
         self.add_pass(InsertRescalePass())
-        self.add_pass(InsertRescaleInt32Pass())
 
         self.validate_constraints_mandatory()
         return self._transform(exported_program.graph_module)
@@ -361,7 +361,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
 
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ReplaceInfValues())
-        self.add_pass(DecomposeSumPass())
 
         if not self.tosa_spec.is_U55_subset:
             # Uses where which is not supported on Ethos-U55
```
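The reordering only works if something enforces it. The pipeline above ends with `validate_constraints_mandatory()`, and the `insert_rescales_pass.py` change below declares `DecomposeSumPass` in `_passes_required_after`. A simplified sketch of the kind of ordering check such a constraint implies; the helper name and exact mechanism here are assumptions, not the real implementation:

```python
from typing import List


def check_pass_ordering(passes: List[object]) -> None:
    """Hypothetical helper: every class in a pass's `_passes_required_after`
    set must appear later in the pipeline than the pass that declares it."""
    for i, current in enumerate(passes):
        later = {type(p) for p in passes[i + 1:]}
        missing = set(getattr(current, "_passes_required_after", set())) - later
        if missing:
            names = ", ".join(cls.__name__ for cls in missing)
            raise RuntimeError(f"{type(current).__name__} must run before: {names}")
```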

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -83,7 +83,7 @@ def call_operator(self, op, args, kwargs, meta):
         if not keepdims:
             shape = list(meta["val"].size())
             input_node = super().call_operator(
-                view_op, (input_node, shape), kwargs, meta, updated=True
+                view_op, (input_node, shape), {}, meta, updated=True
             )
 
         return input_node
```

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -10,6 +10,7 @@
 import torch
 from executorch.backends.arm._passes.arm_pass import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
+from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
 from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
     get_output_qparams,
 )
@@ -84,7 +85,11 @@ class InsertRescaleInt32Pass(ArmPass):
     parameters.
     """
 
-    _passes_required_after: Set[Type[ExportPass]] = set()
+    # SUM must be decomposed after this pass to prevent insertion of RESCALE
+    # nodes between each subsequent SUM node after decomposition. RESCALE nodes
+    # should only be inserted before and after the SUM node prior to its
+    # decomposition.
+    _passes_required_after: Set[Type[ExportPass]] = {DecomposeSumPass}
 
     included_targets = [
         exir_ops.edge.aten.abs.default,
@@ -96,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
         exir_ops.edge.aten.maximum.default,
         exir_ops.edge.aten.minimum.default,
         exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.sum.dim_IntList,
     ]
 
     def _int32_qargs(self, s):
@@ -138,6 +144,7 @@ def _get_inputs_rescaled_qparams(
             }
         elif target in [
             exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.sum.dim_IntList,
         ]:
             # The input scales do not need to be adjusted for these ops; they
             # can remain the same.
@@ -160,6 +167,7 @@ def _get_output_qparams(
             exir_ops.edge.aten.abs.default,
             exir_ops.edge.aten.maximum.default,
             exir_ops.edge.aten.minimum.default,
+            exir_ops.edge.aten.sum.dim_IntList,
         ]:
             # The op has not altered the scale; the output scale is equal to
             # the operands' scales.
```
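The pass leaves the SUM input scales unchanged and reuses them for the output because summing quantized values in a wider integer type preserves the INT8 scale. A toy numeric check, separate from the pass itself, with an arbitrary example scale:

```python
import torch

scale = 0.02
q_int8 = torch.tensor([10, 20, 30], dtype=torch.int8)   # quantized inputs with scale 0.02
q_int32 = q_int8.to(torch.int32)                          # INT8 -> INT32, scale unchanged
acc = torch.sum(q_int32, dtype=torch.int32)               # integer REDUCE_SUM
# Dequantizing the accumulator with the same scale matches the float-domain sum.
assert torch.isclose(acc.to(torch.float32) * scale,
                     (q_int8.to(torch.float32) * scale).sum())
```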

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -29,8 +29,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
 
         # U55 case, Vela 4.2.0 (25.02 release)
         input_shape = node.all_input_nodes[0].meta["val"].shape
-        dim_list = cast(list[int], node.args[1])
-        dim_list = [dim % len(input_shape) for dim in dim_list]
+
+        if node.args[1] is None:
+            # Dim is allowed to be None, which means to sum all dimensions
+            dim_list = list(range(len(input_shape)))
+        else:
+            dim_list = cast(list[int], node.args[1])
+            dim_list = [dim % len(input_shape) for dim in dim_list]
 
         for dim in dim_list:
             if not 1 <= input_shape[dim] <= 65536:
```
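A quick standalone illustration of the equivalence the new `None` branch assumes, plus the negative-dim normalization done by the modulo:

```python
import torch

x = torch.rand(1, 2, 3, 4)
# dim=None reduces over every dimension, i.e. the same as listing all dims.
assert torch.allclose(torch.sum(x, dim=None), torch.sum(x, dim=[0, 1, 2, 3]))
# Negative dims normalize to positive indices via `dim % rank`.
assert [d % x.dim() for d in [-1, 2]] == [3, 2]
```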

backends/arm/operators/op_sum.py

Lines changed: 2 additions & 62 deletions
```diff
@@ -7,8 +7,6 @@
 
 from typing import Any, List
 
-import executorch.backends.arm.tosa.quant_utils as tqutils
-import executorch.backends.arm.tosa.utils as tutils
 import tosa_serializer as ts
 
 from executorch.backends.arm.operators.node_visitor import (
@@ -25,69 +23,14 @@
 
 
 @register_node_visitor
-class SumVisitor_INT(NodeVisitor):
+class SumVisitor(NodeVisitor):
     target = "aten.sum.dim_IntList"
 
     tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+FP"),
         TosaSpecification.create_from_string("TOSA-1.0+INT"),
     ]
 
-    def __init__(self, *args):
-        super().__init__(*args)
-
-    def define_node(
-        self,
-        node: Node,
-        tosa_graph: Any,
-        inputs: List[TosaArg],
-        output: TosaArg,
-    ) -> None:
-        validate_num_inputs(self.target, inputs, 3)
-        validate_same_dtype(self.target, [inputs[0], output], ts)
-
-        tensor = inputs[0]
-        input_shape = list(tensor.shape)
-        dim = int(inputs[1].number % len(input_shape))
-
-        output_shape = input_shape
-        output_shape[dim] = 1  # Output shape is input shape with dim reduced
-
-        # Rescale input to 32 bit
-        rescaled_inputs, scale = tqutils.insert_rescale_ops_to_int32(
-            tosa_graph, [tensor], node, self.tosa_spec
-        )
-
-        attr = ts.TosaSerializerAttribute()
-        attr.ReduceSumAttribute(tensor.dim_order.index(dim))
-
-        intermediate = tosa_graph.addIntermediate(
-            tutils.tosa_shape(output_shape, tensor.dim_order),
-            dtype=ts.DType.INT32,
-        )
-
-        self._serialize_operator(
-            node,
-            tosa_graph,
-            ts.Op.REDUCE_SUM,
-            [rescaled_inputs[0].name],
-            [intermediate.name],
-            attr,
-        )
-
-        tqutils.insert_rescale_op_to_int8(
-            tosa_graph, intermediate, scale, node, self.tosa_spec
-        )
-
-
-@register_node_visitor
-class SumVisitor_FP(SumVisitor_INT):
-    # inheriting 'target' from INT class
-
-    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
-
-    def __init__(self, *args):
-        super().__init__(*args)
-
     def define_node(
         self,
         node: Node,
@@ -102,9 +45,6 @@ def define_node(
         input_shape = list(tensor.shape)
         dim = int(inputs[1].number % len(input_shape))
 
-        output_shape = input_shape
-        output_shape[dim] = 1  # Output shape is input shape with dim reduced
-
         attr = ts.TosaSerializerAttribute()
         attr.ReduceSumAttribute(tensor.dim_order.index(dim))
 
```

backends/arm/test/ops/test_sum.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -35,6 +35,7 @@ class Sum(torch.nn.Module):
         "4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True),
         "4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True),
         "dim_None": lambda: (torch.rand(10), None, True),
+        "dim_None_4d_tensor": lambda: (torch.rand(10, 3, 2, 1), None, True),
     }
 
     def forward(self, x: torch.Tensor, dim: int, keepdim: bool):
```

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 41 additions & 14 deletions
```diff
@@ -13,14 +13,11 @@
 from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
 
 
-class NeedsRescaleOps(torch.nn.Module):
+class MultipleOpsModel(torch.nn.Module):
     """A module containing ops that require INT32 inputs/outputs."""
 
     input_t = Tuple[torch.Tensor, torch.Tensor]
 
-    def __init__(self):
-        super().__init__()
-
     def forward(self, x, y):
         a = x * y
         b = torch.maximum(a, y)
@@ -39,19 +36,41 @@ def get_inputs(self, dtype) -> input_t:
         else:
             raise ValueError("Not a valid input dtype for model")
 
+    def get_num_expected_rescales(self):
+        # "number of op nodes with i8 output" + "number of i8 node inputs"
+        return 3 + 7
 
-def test_insert_rescales():
-    module = NeedsRescaleOps()
-    input_t = Tuple[torch.Tensor, torch.Tensor]
+
+class SumModel(torch.nn.Module):
+    input_t = Tuple[torch.Tensor]
+
+    def forward(self, x):
+        a = torch.sum(x, 2, keepdim=True)  # (1, 2, 1, 4)
+        b = torch.sum(a, [1, 3], keepdim=True)  # (1, 1, 1, 1)
+        c = torch.sum(b, [0, 2], keepdim=False)  # (1, 1)
+        return c
+
+    def get_inputs(self, dtype) -> input_t:
+        if dtype == torch.float32:
+            return (torch.rand(1, 2, 3, 4),)
+        elif dtype == torch.int32:
+            return (torch.randint(0, 10, (1, 2, 3, 4), dtype=torch.int32),)
+        else:
+            raise ValueError("Not a valid input dtype for model")
+
+    def get_num_expected_rescales(self):
+        # Two RESCALE nodes per SUM node
+        return 6
+
+
+def _test_model_with_f32_data(model):
     ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
     ops_after = {
-        # "number of op nodes with i8 output" + "number of i8 node inputs"
-        "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
-        + 7,
+        "executorch_exir_dialects_backend__ops_tosa_RESCALE_default": model.get_num_expected_rescales(),
     }
-    pipeline = PassPipeline[input_t](
-        module,
-        module.get_inputs(torch.float32),
+    pipeline = PassPipeline[model.input_t](
+        model,
+        model.get_inputs(torch.float32),
         quantize=True,
         ops_not_before_pass=ops_not_before,
         ops_after_pass=ops_after,
@@ -61,8 +80,16 @@ def test_insert_rescales():
     pipeline.run()
 
 
+def test_insert_rescales_sum_model():
+    _test_model_with_f32_data(SumModel())
+
+
+def test_insert_rescales_multiple_ops_model():
+    _test_model_with_f32_data(MultipleOpsModel())
+
+
 def test_dont_insert_rescales():
-    module = NeedsRescaleOps()
+    module = MultipleOpsModel()
     input_t = Tuple[torch.Tensor, torch.Tensor]
     ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
     # All inputs are already i32. Rescales should not be added.
```
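The expected RESCALE counts in the new test follow the commit's reasoning: each user-level sum gets exactly one INT8->INT32 RESCALE before its (possibly decomposed) chain and one INT32->INT8 RESCALE after it, regardless of how many single-dim SUMs it splits into. A quick tally for `SumModel`:

```python
# SumModel contains three torch.sum calls (over dim 2, dims [1, 3], dims [0, 2]).
sums_in_model = 3
rescales_per_sum = 2   # one RESCALE before the chain, one after it
assert sums_in_model * rescales_per_sum == 6  # matches get_num_expected_rescales()
```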
