Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ def call_operator(self, op, args, kwargs, meta):

conv_output = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(convolution, torch.int32, conv_rescale_factor, 0, 0),
(convolution, torch.int32, [conv_rescale_factor], 0, 0),
{},
new_meta,
)

bias_rescaled = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
(channel_bias, torch.int32, [bias_rescale_factor], 0, 0),
{},
new_meta,
)
Expand All @@ -129,7 +129,7 @@ def call_operator(self, op, args, kwargs, meta):
(
add,
output_dtype,
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
[(common_scale / (conv_output_scale * (1 << bits_left_to_shift)))],
0,
0,
),
Expand Down
16 changes: 9 additions & 7 deletions backends/arm/_passes/insert_rescales_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule
(
node.all_input_nodes[0],
q_args.dtype,
new_scale,
[new_scale],
dq_args.zp,
q_args.zp,
),
Expand Down Expand Up @@ -228,10 +228,10 @@ def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> b
(
arg_node,
torch.int32,
qp.get_scale_per_tensor()
/ rescale_qargs[
i
].get_scale_per_tensor(), # Old scale / new scale
[
qp.get_scale_per_tensor()
/ rescale_qargs[i].get_scale_per_tensor()
], # [Old scale / new scale]
qp.get_zp_per_tensor(), # Old zero point
rescale_qargs[i].get_zp_per_tensor(), # New zero point
),
Expand Down Expand Up @@ -264,8 +264,10 @@ def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> b
(
node,
qarg.dtype,
rescale_qargs.get_scale_per_tensor()
/ qarg.get_scale_per_tensor(), # Old scale / new scale
[
rescale_qargs.get_scale_per_tensor()
/ qarg.get_scale_per_tensor()
], # [Old scale / new scale]
rescale_qargs.get_zp_per_tensor(), # Old zero point
qarg.get_zp_per_tensor(), # New zero point
),
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/insert_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
rescale_node = create_node(
graph=graph_module.graph,
op_target=exir_ops.backend.tosa.RESCALE.default,
args=(table_op_node, output_qparams[0].dtype, scale, 0, 0),
args=(table_op_node, output_qparams[0].dtype, [scale], 0, 0),
)
output_node = rescale_node

Expand Down
76 changes: 68 additions & 8 deletions backends/arm/_passes/rewrite_conv2d_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import itertools
from typing import Set, Type

import torch
Expand All @@ -16,6 +17,10 @@
is_buffer,
is_param,
)
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.constants import HWCM_ORDER, NHWC_INVERSE_ORDER
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import create_constant_placeholder
Expand Down Expand Up @@ -156,6 +161,40 @@ def _add_bias(
node.update_arg(2, bias_node)
return bias_node

def insert_output_rescale(self, graph_module, node):
    """Insert a TOSA RESCALE node directly after ``node`` and return it.

    One scale factor is produced per weight scale, computed as
    ``(input_scale * weight_scale) / output_scale``. With per-channel
    weight quantization the list has one entry per channel; otherwise it
    holds a single entry.

    Args:
        graph_module: Graph module whose graph receives the new node.
        node: Node whose quantization parameters are read and after which
            the RESCALE node is inserted.

    Returns:
        The newly created RESCALE node.
    """
    all_input_qparams = get_input_qparams(node)
    out_qparams = get_output_qparams(node)[0]
    w_qparams = all_input_qparams[1]
    act_qparams = all_input_qparams[0]

    # Per-tensor weight quantization is treated as a one-element list so
    # both cases share the same computation below.
    if w_qparams.per_channel:
        weight_scales = w_qparams.get_scale_per_channel()
    else:
        weight_scales = [w_qparams.get_scale_per_tensor()]

    act_scale = act_qparams.get_scale_per_tensor()
    out_scale = out_qparams.get_scale_per_tensor()
    rescale_factors = [(act_scale * w) / out_scale for w in weight_scales]

    with graph_module.graph.inserting_after(node):
        rescale_args = (
            node,
            out_qparams.dtype,
            rescale_factors,
            0,  # input zero point
            out_qparams.get_zp_per_tensor(),  # output zero point
        )
        new_rescale = create_node(
            graph=graph_module.graph,
            op_target=exir_ops.backend.tosa.RESCALE.default,
            args=rescale_args,
            from_node=node,
        )
    return new_rescale

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
for node in graph_module.graph.nodes:
Expand All @@ -180,20 +219,20 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
) = node.args

pad = [val for val in pad for _ in (0, 1)]
input_shape = get_first_fake_tensor(x).shape
weight_shape = get_first_fake_tensor(weight).shape
input_fake_tensor = get_first_fake_tensor(x)
weight_fake_tensor = get_first_fake_tensor(weight)
# Adjust the pad value if needed to meet the
# strict convolution output shape calculation.
pad[1] = self._adjust_pad_if_needed(
input_shape[2],
weight_shape[2],
input_fake_tensor.shape[2],
weight_fake_tensor.shape[2],
stride[0],
pad[1],
dilation[0],
)
pad[3] = self._adjust_pad_if_needed(
input_shape[3],
weight_shape[3],
input_fake_tensor.shape[3],
weight_fake_tensor.shape[3],
stride[1],
pad[3],
dilation[1],
Expand All @@ -204,7 +243,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:

if self._is_depthwise_conv2d(node):
target_op = exir_ops.backend.tosa.DEPTHWISE_CONV2D.default
self._reshape_weights(weight, input_shape[1])
self._reshape_weights(weight, input_fake_tensor.shape[1])
weight_fake_tensor = get_first_fake_tensor(weight)
else:
target_op = exir_ops.backend.tosa.CONV2D.default

Expand All @@ -227,9 +267,29 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
args=conv2d_args,
from_node=node,
)
bias_fake_tensor = get_first_fake_tensor(bias) if bias else None
tosa_node_fake_tensor = target_op(
input_fake_tensor,
weight_fake_tensor,
bias_fake_tensor,
*conv2d_args[3:],
)

if (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int8
) or (
tosa_node_fake_tensor.dtype == torch.int32
and input_fake_tensor.dtype == torch.int16
):
output_rescale = self.insert_output_rescale(graph_module, tosa_op)
node.replace_all_uses_with(output_rescale)
if input_fake_tensor.dtype == torch.int16:
tosa_op.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48
else:
node.replace_all_uses_with(tosa_op)
graph_module.graph.erase_node(node)

graph_module.graph.erase_node(node)

if modified:
graph_module.recompile()
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/rewrite_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _insert_output_rescale(self, graph_module, node, tosa_matmul_node, dtype):
rescale_node.args = (
tosa_matmul_node,
dtype,
scale,
[scale],
0,
output_qparams.get_zp_per_tensor(),
)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/rewrite_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def call(self, graph_module):
rescale_node.args = (
tosa_resize_node,
output_dtype,
output_scale,
[output_scale],
0, # zero point
0, # zero point
)
Expand Down
59 changes: 4 additions & 55 deletions backends/arm/operators/op_tosa_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@

"""Provide a visitor for lowering 2D convolution to TOSA (INT/FP)."""

import itertools
from typing import Any, List

import torch

from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
Expand All @@ -26,9 +24,7 @@
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.quant_utils import build_rescale
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa.utils import tosa_shape


@register_node_visitor
Expand Down Expand Up @@ -58,7 +54,8 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator and post-rescale."""
"""Define the TOSA CONV2D/DEPTHWISE_CONV2D operator."""

input, weight, bias, stride, pad, dilation, _, _, group = inputs
validate_num_inputs(self.target, inputs, 9)

Expand Down Expand Up @@ -105,23 +102,8 @@ def define_node(
input_qparams = get_input_qparams(node)
weight_zp = input_qparams[1].zp # type: ignore[assignment]

# The output type is int32 when input type is int8.
if inputs[0].dtype == ts.DType.INT8:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)
conv2d_output_name = conv2d_res.name
acc_type = ts.DType.INT32
elif inputs[0].dtype == ts.DType.INT16:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT48
)
conv2d_output_name = conv2d_res.name
acc_type = ts.DType.INT48
else:
conv2d_output_name = output.name
conv2d_res = output
acc_type = ts.DType.FP32
conv2d_output_name = output.name
acc_type = output.dtype

tosa_graph.addConst(
[1], inputs[0].dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
Expand Down Expand Up @@ -158,36 +140,3 @@ def define_node(
[conv2d_output_name],
attr,
)

# For quantized convolution, rescale the output value back to the same
# integer value domain of the next op. Otherwise return float32 output.
if output.dtype == ts.DType.INT8 or output.dtype == ts.DType.INT16:
# Get scale_factor from input, weight, and output.
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]
if per_channel_quant:
weight_scale = input_qparams[1].get_scale_per_channel()
else:
weight_scale = [
input_qparams[1].get_scale_per_tensor()
] # pyre-ignore [61]
output_qargs = get_output_qparams(node)
post_conv2d_scale = [
(inp * w) / out
for inp, w, out in zip(
itertools.cycle([input_scale]),
weight_scale,
itertools.cycle([output_qargs[0].get_scale_per_tensor()]),
)
]
build_rescale(
tosa_fb=tosa_graph,
scale=post_conv2d_scale,
input_node=conv2d_res, # type: ignore[possibly-undefined]
output_name=output.name,
output_type=output.dtype,
input_zp=[0],
output_zp=[output_qargs[0].get_zp_per_tensor()],
per_channel=per_channel_quant,
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
)
4 changes: 4 additions & 0 deletions backends/arm/operators/op_tosa_depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

"""Provide a visitor for lowering 2D depthwise convolution to TOSA (INT/FP)."""

import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import register_node_visitor
from executorch.backends.arm.operators.op_tosa_conv2d import Conv2dVisitor
from executorch.backends.arm.tosa import TosaSpecification
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_tosa_rescale.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def define_node(

input_dtype = inputs[0].dtype
output_dtype = cast(torch.dtype, node.args[1])
scale = cast(float, node.args[2])
scales = cast(list[float], node.args[2])
input_zp = cast(int, node.args[3])
output_zp = cast(int, node.args[4])

Expand All @@ -63,12 +63,12 @@ def define_node(

build_rescale(
tosa_graph,
scale=[scale],
scale=scales,
input_node=inputs[0],
output_name=output.name,
output_type=output.dtype,
input_zp=[input_zp],
output_zp=[output_zp],
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
per_channel=False,
per_channel=len(scales) > 1,
)
4 changes: 2 additions & 2 deletions backends/arm/test/misc/test_tosa_dialect_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_conv2d_tosa_INT():
4,
),
(1, 8, 20, 20),
torch.int8,
torch.int32,
),
(
(
Expand All @@ -46,7 +46,7 @@ def test_conv2d_tosa_INT():
4,
),
(1, 4, 10, 10),
torch.int8,
torch.int32,
),
]

Expand Down
4 changes: 2 additions & 2 deletions backends/arm/test/misc/test_tosa_dialect_dw_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_depthwise_conv2d_tosa_INT():
8,
),
(1, 16, 20, 20),
torch.int8,
torch.int32,
),
(
(
Expand All @@ -48,7 +48,7 @@ def test_depthwise_conv2d_tosa_INT():
8,
),
(1, 32, 10, 10),
torch.int8,
torch.int32,
),
]

Expand Down
Loading
Loading