Skip to content

Commit 53bb98b

Browse files
committed
Revert "Arm backend: Support per-channel in TOSA.RESCALE (#15267)"
This reverts commit f7ca57e.
1 parent a3ff326 commit 53bb98b

14 files changed

+94
-110
lines changed

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta):
105105

106106
conv_output = super().call_operator(
107107
exir_ops.backend.tosa.RESCALE.default,
108-
(convolution, torch.int32, [conv_rescale_factor], 0, 0),
108+
(convolution, torch.int32, conv_rescale_factor, 0, 0),
109109
{},
110110
new_meta,
111111
)
112112

113113
bias_rescaled = super().call_operator(
114114
exir_ops.backend.tosa.RESCALE.default,
115-
(channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
115+
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
116116
{},
117117
new_meta,
118118
)
@@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta):
129129
(
130130
add,
131131
output_dtype,
132-
[(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
132+
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
133133
0,
134134
0,
135135
),

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
4545
(
4646
node.all_input_nodes[0],
4747
q_args.dtype,
48-
[new_scale],
48+
new_scale,
4949
dq_args.zp,
5050
q_args.zp,
5151
),
@@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
228228
(
229229
arg_node,
230230
torch.int32,
231-
[
232-
qp.get_scale_per_tensor()
233-
/ rescale_qargs[i].get_scale_per_tensor()
234-
], # [Old scale / new scale]
231+
qp.get_scale_per_tensor()
232+
/ rescale_qargs[
233+
i
234+
].get_scale_per_tensor(), # Old scale / new scale
235235
qp.get_zp_per_tensor(), # Old zero point
236236
rescale_qargs[i].get_zp_per_tensor(), # New zero point
237237
),
@@ -264,10 +264,8 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
264264
(
265265
node,
266266
qarg.dtype,
267-
[
268-
rescale_qargs.get_scale_per_tensor()
269-
/ qarg.get_scale_per_tensor()
270-
], # [Old scale / new scale]
267+
rescale_qargs.get_scale_per_tensor()
268+
/ qarg.get_scale_per_tensor(), # Old scale / new scale
271269
rescale_qargs.get_zp_per_tensor(), # Old zero point
272270
qarg.get_zp_per_tensor(), # New zero point
273271
),

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
286286
rescale_node = create_node(
287287
graph=graph_module.graph,
288288
op_target=exir_ops.backend.tosa.RESCALE.default,
289-
args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
289+
args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
290290
)
291291
output_node = rescale_node
292292

backends/arm/_passes/rewrite_conv2d_pass.py

Lines changed: 8 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7-
import itertools
87
from typing import Set, Type
98

109
import torch
@@ -17,10 +16,6 @@
1716
is_buffer,
1817
is_param,
1918
)
20-
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
21-
get_input_qparams,
22-
get_output_qparams,
23-
)
2419
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
2520
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2621
from executorch.backends.transforms.utils import create_constant_placeholder
@@ -161,40 +156,6 @@ def _add_bias(
161156
node.update_arg(2, bias_node)
162157
return bias_node
163158

164-
def insert_output_rescale(self, graph_module, node):
165-
input_qparams = get_input_qparams(node)
166-
output_qparams = get_output_qparams(node)[0]
167-
weight_qparams = input_qparams[1]
168-
input_qparams = input_qparams[0]
169-
is_per_channel = weight_qparams.per_channel
170-
if is_per_channel:
171-
weight_scale = weight_qparams.get_scale_per_channel()
172-
else:
173-
weight_scale = [weight_qparams.get_scale_per_tensor()]
174-
input_scale = input_qparams.get_scale_per_tensor()
175-
post_conv2d_scale = [
176-
(inp * w) / out
177-
for inp, w, out in zip(
178-
itertools.cycle([input_scale]),
179-
weight_scale,
180-
itertools.cycle([output_qparams.get_scale_per_tensor()]),
181-
)
182-
]
183-
with graph_module.graph.inserting_after(node):
184-
rescale_node = create_node(
185-
graph=graph_module.graph,
186-
op_target=exir_ops.backend.tosa.RESCALE.default,
187-
args=(
188-
node,
189-
output_qparams.dtype,
190-
post_conv2d_scale,
191-
0,
192-
output_qparams.get_zp_per_tensor(),
193-
),
194-
from_node=node,
195-
)
196-
return rescale_node
197-
198159
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
199160
modified = False
200161
for node in graph_module.graph.nodes:
@@ -219,20 +180,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
219180
) = node.args
220181

221182
pad = [val for val in pad for _ in (0, 1)]
222-
input_fake_tensor = get_first_fake_tensor(x)
223-
weight_fake_tensor = get_first_fake_tensor(weight)
183+
input_shape = get_first_fake_tensor(x).shape
184+
weight_shape = get_first_fake_tensor(weight).shape
224185
# Adjust the pad value if needed to meet the
225186
# strict convolution output shape calculation.
226187
pad[1] = self._adjust_pad_if_needed(
227-
input_fake_tensor.shape[2],
228-
weight_fake_tensor.shape[2],
188+
input_shape[2],
189+
weight_shape[2],
229190
stride[0],
230191
pad[1],
231192
dilation[0],
232193
)
233194
pad[3] = self._adjust_pad_if_needed(
234-
input_fake_tensor.shape[3],
235-
weight_fake_tensor.shape[3],
195+
input_shape[3],
196+
weight_shape[3],
236197
stride[1],
237198
pad[3],
238199
dilation[1],
@@ -243,8 +204,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
243204

244205
if self._is_depthwise_conv2d(node):
245206
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
246-
self._reshape_weights(weight, input_fake_tensor.shape[1])
247-
weight_fake_tensor = get_first_fake_tensor(weight)
207+
self._reshape_weights(weight, input_shape[1])
248208
else:
249209
target_op = exir_ops.backend.tosa.CONV2D.default
250210

@@ -267,29 +227,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
267227
args=conv2d_args,
268228
from_node=node,
269229
)
270-
bias_fake_tensor = get_first_fake_tensor(bias) if bias else None
271-
tosa_node_fake_tensor = target_op(
272-
input_fake_tensor,
273-
weight_fake_tensor,
274-
bias_fake_tensor,
275-
*conv2d_args[3:],
276-
)
277230

278-
if (
279-
tosa_node_fake_tensor.dtype == torch.int32
280-
and input_fake_tensor.dtype == torch.int8
281-
) or (
282-
tosa_node_fake_tensor.dtype == torch.int32
283-
and input_fake_tensor.dtype == torch.int16
284-
):
285-
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
286-
node.replace_all_uses_with(output_rescale)
287-
if input_fake_tensor.dtype == torch.int16:
288-
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
289-
else:
290231
node.replace_all_uses_with(tosa_op)
291-
292-
graph_module.graph.erase_node(node)
232+
graph_module.graph.erase_node(node)
293233

294234
if modified:
295235
graph_module.recompile()

backends/arm/_passes/rewrite_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
4444
rescale_node.args = (
4545
tosa_matmul_node,
4646
dtype,
47-
[scale],
47+
scale,
4848
0,
4949
output_qparams.get_zp_per_tensor(),
5050
)

backends/arm/_passes/rewrite_upsample.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def call(self, graph_module):
7474
rescale_node.args = (
7575
tosa_resize_node,
7676
output_dtype,
77-
[output_scale],
77+
output_scale,
7878
0, # zero point
7979
0, # zero point
8080
)

backends/arm/operators/op_tosa_conv2d.py

Lines changed: 55 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,12 +8,14 @@
88

99
"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""
1010

11+
import itertools
1112
from typing import Any, List
1213

1314
import torch
1415

1516
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1617
get_input_qparams,
18+
get_output_qparams,
1719
)
1820
from executorch.backends.arm.operators.node_visitor import (
1921
NodeVisitor,
@@ -24,7 +26,9 @@
2426
validate_valid_dtype,
2527
)
2628
from executorch.backends.arm.tosa.mapping import TosaArg
29+
from executorch.backends.arm.tosa.quant_utils import build_rescale
2730
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
31+
from executorch.backends.arm.tosa.utils import tosa_shape
2832

2933

3034
@register_node_visitor
@@ -54,8 +58,7 @@ def define_node(
5458
inputs: List[TosaArg],
5559
output: TosaArg,
5660
) -> None:
57-
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator."""
58-
61+
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
5962
input, weight, bias, stride, pad, dilation, _, _, group = inputs
6063
validate_num_inputs(self.target, inputs, 9)
6164

@@ -102,8 +105,23 @@ def define_node(
102105
input_qparams = get_input_qparams(node)
103106
weight_zp = input_qparams[1].zp # type: ignore[assignment]
104107

105-
conv2d_output_name = output.name
106-
acc_type = output.dtype
108+
# The output type is int32 when input type is int8.
109+
if inputs[0].dtype == ts.DType.INT8:
110+
conv2d_res = tosa_graph.addIntermediate(
111+
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
112+
)
113+
conv2d_output_name = conv2d_res.name
114+
acc_type = ts.DType.INT32
115+
elif inputs[0].dtype == ts.DType.INT16:
116+
conv2d_res = tosa_graph.addIntermediate(
117+
tosa_shape(output.shape, output.dim_order), ts.DType.INT48
118+
)
119+
conv2d_output_name = conv2d_res.name
120+
acc_type = ts.DType.INT48
121+
else:
122+
conv2d_output_name = output.name
123+
conv2d_res = output
124+
acc_type = ts.DType.FP32
107125

108126
tosa_graph.addConst(
109127
[1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -140,3 +158,36 @@ def define_node(
140158
[conv2d_output_name],
141159
attr,
142160
)
161+
162+
# For quantized convolution, rescale the output value back to the same
163+
# integer value domain of the next op. Otherwise return float32 output.
164+
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
165+
# Get scale_factor from input, weight, and output.
166+
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
167+
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]
168+
if per_channel_quant:
169+
weight_scale = input_qparams[1].get_scale_per_channel()
170+
else:
171+
weight_scale = [
172+
input_qparams[1].get_scale_per_tensor()
173+
] # pyre-ignore [61]
174+
output_qargs = get_output_qparams(node)
175+
post_conv2d_scale = [
176+
(inp * w) / out
177+
for inp, w, out in zip(
178+
itertools.cycle([input_scale]),
179+
weight_scale,
180+
itertools.cycle([output_qargs[0].get_scale_per_tensor()]),
181+
)
182+
]
183+
build_rescale(
184+
tosa_fb=tosa_graph,
185+
scale=post_conv2d_scale,
186+
input_node=conv2d_res, # type: ignore[possibly-undefined]
187+
output_name=output.name,
188+
output_type=output.dtype,
189+
input_zp=[0],
190+
output_zp=[output_qargs[0].get_zp_per_tensor()],
191+
per_channel=per_channel_quant,
192+
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
193+
)

backends/arm/operators/op_tosa_depthwise_conv2d.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
8-
"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP)."""
9-
107
import tosa_serializer as ts
11-
128
from executorch.backends.arm.operators.node_visitor import register_node_visitor
139
from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor
1410
from executorch.backends.arm.tosa import TosaSpecification

backends/arm/operators/op_tosa_rescale.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def define_node(
4141

4242
input_dtype = inputs[0].dtype
4343
output_dtype = cast(torch.dtype, node.args[1])
44-
scales = cast(list[float], node.args[2])
44+
scale = cast(float, node.args[2])
4545
input_zp = cast(int, node.args[3])
4646
output_zp = cast(int, node.args[4])
4747

@@ -63,12 +63,12 @@ def define_node(
6363

6464
build_rescale(
6565
tosa_graph,
66-
scale=scales,
66+
scale=[scale],
6767
input_node=inputs[0],
6868
output_name=output.name,
6969
output_type=output.dtype,
7070
input_zp=[input_zp],
7171
output_zp=[output_zp],
7272
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
73-
per_channel=len(scales) > 1,
73+
per_channel=False,
7474
)

backends/arm/test/misc/test_tosa_dialect_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def test_conv2d_tosa_INT():
3131
4,
3232
),
3333
(1, 8, 20, 20),
34-
torch.int32,
34+
torch.int8,
3535
),
3636
(
3737
(
@@ -46,7 +46,7 @@ def test_conv2d_tosa_INT():
4646
4,
4747
),
4848
(1, 4, 10, 10),
49-
torch.int32,
49+
torch.int8,
5050
),
5151
]
5252

0 commit comments

Comments
 (0)