Skip to content

Commit 0b65f38

Browse files
Martin Lindström
authored and committed
Revert "Arm backend: Move rescales from SUB visitor to pass"
This reverts commit f21cf7f.
1 parent f21cf7f commit 0b65f38

File tree

3 files changed

+108
-46
lines changed

3 files changed

+108
-46
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ class InsertRescaleInt32Pass(ArmPass):
9393
exir_ops.edge.aten.lt.Tensor,
9494
exir_ops.edge.aten.maximum.default,
9595
exir_ops.edge.aten.minimum.default,
96-
exir_ops.edge.aten.sub.Tensor,
9796
]
9897

9998
def _int32_qargs(self, s):
@@ -134,33 +133,6 @@ def _get_inputs_rescaled_qparams(
134133
qparams = {
135134
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
136135
}
137-
elif target in [
138-
exir_ops.edge.aten.sub.Tensor,
139-
]:
140-
if input_qparams[0].dtype != input_qparams[1].dtype:
141-
raise ValueError(
142-
"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
143-
)
144-
145-
# We are handling two INT8 or two INT16 numbers. For INT8, if the
146-
# zero point is non-null, the result will be in the range [-255;
147-
# 255], therefore we need 9 bits for the result. We have a 32-bit
148-
# accumulator, so we can divide the scale by (1 << 20) which is
149-
# equivalent to shifting the INT8 operands 20 bits to the left
150-
# before rescaling them both to 2 * max(lhs, rhs).
151-
#
152-
# For INT16, similary logic can be applied, but we instead end up
153-
# with a left shift of 12.
154-
lhs_scale, rhs_scale = (
155-
qp.get_scale_per_tensor() for qp in input_qparams.values()
156-
)
157-
max_scale_2x = 2 * max(lhs_scale, rhs_scale)
158-
159-
# Select shift based on input dtype.
160-
shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
161-
162-
scale = max_scale_2x / (1 << shift_bits)
163-
qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
164136
else:
165137
raise ValueError(f"Not a valid target: {target}")
166138

@@ -176,7 +148,6 @@ def _get_output_qparams(
176148
exir_ops.edge.aten.abs.default,
177149
exir_ops.edge.aten.maximum.default,
178150
exir_ops.edge.aten.minimum.default,
179-
exir_ops.edge.aten.sub.Tensor,
180151
]:
181152
# The op has not altered the scale; the output scale is equal to
182153
# the operands' scales.
@@ -216,7 +187,7 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
216187
modified = False
217188
for i in qargs:
218189
qp = qargs[i]
219-
if qp.dtype not in (torch.int8, torch.int16):
190+
if qp.dtype != torch.int8:
220191
continue
221192

222193
arg_node = args_copy[i]
@@ -255,7 +226,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
255226
assert rescale_qargs is not None
256227

257228
qarg = qargs[0]
258-
if qarg.dtype not in (torch.int8, torch.int16):
229+
if qarg.dtype != torch.int8:
259230
return False
260231

261232
users_copy = list(node.users)
@@ -266,7 +237,7 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
266237
exir_ops.backend.tosa.RESCALE.default,
267238
(
268239
node,
269-
qarg.dtype,
240+
torch.int8,
270241
rescale_qargs.get_scale_per_tensor()
271242
/ qarg.get_scale_per_tensor(), # Old scale / new scale
272243
rescale_qargs.get_zp_per_tensor(), # Old zero point

backends/arm/operators/op_sub.py

Lines changed: 99 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
from typing import Any, List
99

10+
import executorch.backends.arm.tosa.quant_utils as tqutils
11+
import executorch.backends.arm.tosa.utils as tutils
12+
1013
from executorch.backends.arm.operators.node_visitor import (
1114
NodeVisitor,
1215
register_node_visitor,
@@ -16,20 +19,22 @@
1619
validate_same_dtype,
1720
validate_valid_dtype,
1821
)
22+
from executorch.backends.arm.tosa import TosaSpecification
1923
from executorch.backends.arm.tosa.mapping import TosaArg
20-
from executorch.backends.arm.tosa.specification import TosaSpecification
2124
from torch.fx import Node
2225

2326

2427
@register_node_visitor
25-
class SubVisitor(NodeVisitor):
28+
class SubVisitor_INT(NodeVisitor):
2629
target = "aten.sub.Tensor"
2730

2831
tosa_specs = [
2932
TosaSpecification.create_from_string("TOSA-1.0+INT"),
30-
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3133
]
3234

35+
def __init__(self, *args):
36+
super().__init__(*args)
37+
3338
def define_node(
3439
self,
3540
node: Node,
@@ -45,18 +50,105 @@ def define_node(
4550
validate_valid_dtype(
4651
self.target,
4752
[*inputs, output],
48-
[ts.DType.INT32, ts.DType.FP32],
53+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
4954
output.tosa_spec,
5055
)
5156

57+
scale_back = 1.0
58+
if inputs[0].dtype == ts.DType.INT8:
59+
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
60+
tosa_graph, inputs, node, self.tosa_spec
61+
)
62+
elif inputs[0].dtype == ts.DType.INT16:
63+
rescaled_inputs, scale_back = (
64+
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
65+
tosa_graph, inputs, node, self.tosa_spec
66+
)
67+
)
68+
else:
69+
# input[0].dtype == ts.DType.INT32
70+
# Non quantized input, natively support by TOSA.SUB
71+
rescaled_inputs = inputs
72+
73+
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
74+
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
75+
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
76+
else:
77+
# output.dtype == ts.DType.INT32
78+
sub_output = output
79+
80+
# Do the INT32 Sub
5281
self._serialize_operator(
5382
node,
5483
tosa_graph,
5584
ts.TosaOp.Op().SUB,
5685
[
57-
inputs[0].name,
58-
inputs[1].name,
86+
rescaled_inputs[0].name,
87+
rescaled_inputs[1].name,
5988
],
60-
[output.name],
89+
[sub_output.name],
6190
None,
6291
)
92+
93+
if output.dtype == ts.DType.INT8:
94+
# Scale output back to 8 bit
95+
# pyre-ignore
96+
tqutils.insert_rescale_op_to_int8(
97+
tosa_graph,
98+
sub_output,
99+
scale_back,
100+
node,
101+
compute_rescale=False,
102+
tosa_spec=self.tosa_spec,
103+
) # type: ignore[possibly-undefined]
104+
elif output.dtype == ts.DType.INT16:
105+
tqutils.insert_rescale_op_to_int16(
106+
tosa_graph,
107+
sub_output,
108+
scale_back,
109+
node,
110+
compute_rescale=False,
111+
tosa_spec=self.tosa_spec,
112+
) # type: ignore[possibly-undefined]
113+
114+
115+
@register_node_visitor
116+
class SubVisitor_FP(SubVisitor_INT):
117+
# inheriting 'target' from INT class
118+
119+
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
120+
121+
def __init__(self, *args):
122+
super().__init__(*args)
123+
124+
def define_node(
125+
self,
126+
node: Node,
127+
tosa_graph: Any,
128+
inputs: List[TosaArg],
129+
output: TosaArg,
130+
) -> None:
131+
132+
import serializer.tosa_serializer as ts # type: ignore
133+
134+
validate_num_inputs(self.target, inputs, 2)
135+
validate_same_dtype(self.target, [*inputs, output], ts)
136+
137+
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
138+
# Call the inherited define_node for handling integers
139+
super().define_node(node, tosa_graph, inputs, output)
140+
else:
141+
# FP32 Sub lowering
142+
validate_valid_dtype(
143+
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
144+
)
145+
146+
# MI lowering
147+
self._serialize_operator(
148+
node,
149+
tosa_graph,
150+
ts.TosaOp.Op().SUB,
151+
[inputs[0].name, inputs[1].name],
152+
[output.name],
153+
None,
154+
)

backends/arm/test/passes/test_insert_rescale_i32_pass.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@ def __init__(self):
2222
super().__init__()
2323

2424
def forward(self, x, y):
25-
a = x - y
26-
c = torch.maximum(a, y)
27-
d = torch.abs(c)
28-
e = d > c
29-
return e
25+
a = torch.maximum(x, y)
26+
b = torch.abs(a)
27+
c = a > b
28+
return c
3029

3130
def get_inputs(self, dtype) -> input_t:
3231
if dtype == torch.float32:
@@ -46,8 +45,8 @@ def test_insert_rescales():
4645
ops_not_before = {"executorch_exir_dialects_backend__ops_tosa_RESCALE_default"}
4746
ops_after = {
4847
# "number of op nodes with i8 output" + "number of i8 node inputs"
49-
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 3
50-
+ 7,
48+
"executorch_exir_dialects_backend__ops_tosa_RESCALE_default": 2
49+
+ 5,
5150
}
5251
pipeline = PassPipeline[input_t](
5352
module,

0 commit comments

Comments
 (0)