Merged
3 changes: 3 additions & 0 deletions backends/arm/_passes/__init__.py
@@ -46,6 +46,9 @@
from .decompose_glu_pass import DecomposeGluPass # noqa
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_int16_activation_conv2d_pass import ( # noqa
DecomposeConv2dWithInt16ActivationPass,
)
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
5 changes: 5 additions & 0 deletions backends/arm/_passes/add_bias_pass.py
@@ -8,6 +8,7 @@
import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import create_constant_placeholder

from executorch.exir.dialects._ops import ops as exir_ops
@@ -59,6 +60,10 @@ def call(self, graph_module):
persistent_buffer=True,
name=f"{node.name}_bias",
)
if node.args[0].meta["val"].dtype == torch.int16:
bias_node.meta[TosaSpecialDtype.meta_key()] = (
TosaSpecialDtype.INT48
)
node.update_arg(2, bias_node)

if modified:
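For context, a minimal sketch of how the TosaSpecialDtype meta key used above is intended to flow end to end: AddBiasPass tags the new bias constant as INT48, and process_inputs_to_parameters (further down in this diff) reads the tag back when serializing the constant. The FakeNode class below is a hypothetical stand-in for a torch.fx placeholder; only the meta-dict handling mirrors the diff.

import torch
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype

class FakeNode:
    """Hypothetical stand-in for a torch.fx placeholder node."""
    def __init__(self):
        self.meta = {}

bias_node = FakeNode()

# Producer side (AddBiasPass): mark the bias so later stages know it must be int48.
bias_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48

# Consumer side (process_inputs_to_parameters): prefer the special dtype if present.
special_type = bias_node.meta.get(TosaSpecialDtype.meta_key(), None)
if isinstance(special_type, TosaSpecialDtype):
    tosa_dtype = special_type.get_tosa_dtype()  # the TOSA dtype for int48
else:
    tosa_dtype = torch.int32  # the real code falls back to the node's own dtype here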
9 changes: 8 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -42,6 +42,7 @@
DecomposeAtanPass,
DecomposeAvgPool2d,
DecomposeBatchNormNoStatsPass,
DecomposeConv2dWithInt16ActivationPass,
DecomposeCoshPass,
DecomposeCosineSimilarityPass,
DecomposeCumsumPass,
@@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())

self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
@@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
# If we have a conv2d with int16 activation split up into a convolution
# and an addition, to work-around the lack of support for int48 in torch
@digantdesai (Contributor) commented on Sep 15, 2025:
> and an addition, to work-around the lack of support for int48 in torch

Or can it be done by using torch.dtype.int64 instead and then detecting and lowering it as int48 downstream?

Author (Collaborator) replied:
I was starting off in that direction, but it interferes a bit with the int64->int32 handling, so I'd rather keep it separate.

@digantdesai (Contributor) replied:
yeah given int64 is treated as radioactive :P

# This needs to happen before AddBiasPass, but after the table ops are inserted,
# to be able to validate that conv2d has the right dtype arguments.
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
self.add_pass(AddBiasPass(exported_program))

self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
145 changes: 145 additions & 0 deletions backends/arm/_passes/decompose_int16_activation_conv2d_pass.py
@@ -0,0 +1,145 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast

import torch
from executorch.backends.arm._passes.quant_args import QuantArgs

from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class DecomposeConv2dWithInt16ActivationPass(ExportPass):
"""
This pass decomposes a convolution with input dtype int16 and a bias into a
convolution without bias followed by an addition of the bias, since the TOSA
op requires the bias to be int48, which is hard to represent in torch.
Instead, the int48 convolution output is rescaled to int32, the bias is added
in int32, and the sum is rescaled back to int16.
"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.convolution.default:
return super().call_operator(op, args, kwargs, meta)

tosa_spec = get_context_spec()
if not tosa_spec.support_integer():
return super().call_operator(op, args, kwargs, meta)

# return if no bias
if args[2] is None:
return super().call_operator(op, args, kwargs, meta)

if args[0].data.dtype == torch.int8:
return super().call_operator(op, args, kwargs, meta)
elif args[0].data.dtype == torch.int16:
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
"int16"
):
raise ValueError(
"int16 activation for convolution requires TOSA int16 extension"
)
else:
raise NotImplementedError(
"Decomposition to conv+add only implemented for activation of int16 type"
)

# convolution with bias and activation is int16
# The bias is assumed to be quantized with the same quantization parameters
# as the output of the convolution.
bias = args[2]
assert (
meta.data["output_qparams"][0].dtype == bias.data.dtype
), "Bias needs to have same type as quantized output type"
no_bias_args = list(args)
no_bias_args[2] = None
# split up to convolution + bias
convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta)

# create a copy of the meta without the qparams, to be used with the new nodes
new_meta = meta.copy()
new_meta.data.pop("output_qparams", None)
new_meta.data.pop("input_qparams", None)

# reshape the tensor to the same rank as the convolution output to add the bias to the channels
channel_bias = super().call_operator(
exir_ops.edge.aten.view_copy.default,
(bias, [1, len(bias.data), 1, 1]),
{},
new_meta,
)

output_dtype = meta.data["output_qparams"][0].dtype

if output_dtype == torch.int16:
# The conv will get its int48 output scaled to int32 in the serialization step.
# To be able to add the bias we need to first rescale the output to int32.
# The resulting int32 sum will then need to be rescaled back to the output dtype.

# calculate common rescale factor from convolution output and bias quantization
output_qparams = cast(QuantArgs, meta.data["output_qparams"][0])
conv_output_scale = output_qparams.scale
bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2])
bias_scale = bias_qparams.scale

common_scale = max(bias_scale, conv_output_scale)

# calculate how we can rescale bias and conv to a common scale and maximize the output range
bias_rescale_factor = bias_scale / common_scale
conv_rescale_factor = conv_output_scale / common_scale

# Either the conv output or the bias now covers the full int16 range and the other one a smaller range.
# Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range.
# Worst case both the bias and the conv output cover the full int16 range, so we leave one bit of
# headroom for the addition and one for the sign bit.
bits_left_to_shift = 14

# update rescale factors
bias_rescale_factor *= 1 << bits_left_to_shift
conv_rescale_factor *= 1 << bits_left_to_shift

conv_output = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(convolution, torch.int32, conv_rescale_factor, 0, 0),
{},
new_meta,
)

bias_rescaled = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(channel_bias, torch.int32, bias_rescale_factor, 0, 0),
{},
new_meta,
)

add = super().call_operator(
exir_ops.edge.aten.add.Tensor,
(conv_output, bias_rescaled),
{},
new_meta,
)

res_rescale = super().call_operator(
exir_ops.backend.tosa.RESCALE.default,
(
add,
output_dtype,
(common_scale / (conv_output_scale * (1 << bits_left_to_shift))),
0,
0,
),
{},
new_meta,
)

else:
raise NotImplementedError(
f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}"
)

return res_rescale
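To make the rescale-factor arithmetic in this pass concrete, here is a small self-contained check with made-up quantization scales (the scale and tensor values are illustrative assumptions, and the rounding/saturation behaviour of the real RESCALE op is ignored):

conv_output_scale = 0.002  # hypothetical scale of the int16 convolution output
bias_scale = 0.0005        # hypothetical scale of the int16 bias

common_scale = max(bias_scale, conv_output_scale)        # 0.002
bias_rescale_factor = bias_scale / common_scale           # 0.25
conv_rescale_factor = conv_output_scale / common_scale    # 1.0

bits_left_to_shift = 14
bias_rescale_factor *= 1 << bits_left_to_shift            # 4096.0
conv_rescale_factor *= 1 << bits_left_to_shift            # 16384.0

conv_q, bias_q = 12000, -300  # example quantized int16 values

# Reference result: dequantize, add in real arithmetic, requantize to the output scale.
expected = round((conv_q * conv_output_scale + bias_q * bias_scale) / conv_output_scale)

# Decomposed path: rescale both operands into a shared int32 domain, add,
# then rescale the sum back to the int16 output domain.
conv_i32 = round(conv_q * conv_rescale_factor)
bias_i32 = round(bias_q * bias_rescale_factor)
result = round(
    (conv_i32 + bias_i32) * common_scale / (conv_output_scale * (1 << bits_left_to_shift))
)

assert result == expected == 11925

The 2**14 factor keeps the larger-scaled operand just under the int32 range: 32767 * 16384 is about 5.4e8, so even two full-range operands sum to well below 2**31.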
7 changes: 7 additions & 0 deletions backends/arm/_passes/fuse_equal_placeholders_pass.py
@@ -8,11 +8,13 @@
from typing import Set, Type

import torch

from executorch.backends.arm._passes.arm_pass_utils import (
get_constant_placeholder_kind,
get_param_tensor,
is_param_node,
)
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.backends.transforms.utils import (
create_constant_placeholder,
delete_constant_placeholder,
@@ -47,9 +49,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
continue
# Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
# Ensure tensor is on CPU and contiguous

# Ensure we don't merge any special-case int48_t tensors with int32_t tensors,
# since int48_t tensors need to be instantiated separately.
is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
t_cpu = tensor.detach().cpu().contiguous()
data_bytes = t_cpu.numpy().tobytes()
key = (
is_int48,
str(t_cpu.dtype),
tuple(t_cpu.shape),
hashlib.sha1(data_bytes).hexdigest(),
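A quick illustration of why the int48 tag has to be part of the fingerprint above: two byte-identical tensors must still get different keys when one of them is tagged as int48. The helper below just mirrors the key construction from the pass; it is a sketch, not the pass itself.

import hashlib

import torch
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype

def fingerprint(tensor: torch.Tensor, meta: dict) -> tuple:
    # Same key layout as in FuseEqualPlaceholdersPass: int48 tag, dtype, shape, SHA1 of bytes.
    t_cpu = tensor.detach().cpu().contiguous()
    return (
        meta.get(TosaSpecialDtype.meta_key(), None),
        str(t_cpu.dtype),
        tuple(t_cpu.shape),
        hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
    )

bias = torch.ones(8, dtype=torch.int32)
plain_key = fingerprint(bias, {})
int48_key = fingerprint(bias, {TosaSpecialDtype.meta_key(): TosaSpecialDtype.INT48})

# Identical data, dtype and shape, but the int48-tagged constant is not fused with the plain one.
assert plain_key != int48_key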
52 changes: 43 additions & 9 deletions backends/arm/operators/op_conv2d.py
@@ -19,10 +19,11 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_valid_dtype,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.quant_utils import build_rescale
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
from executorch.backends.arm.tosa.utils import tosa_shape


@@ -73,6 +74,32 @@ def define_node(
input, weight, bias, stride, pad, dilation, _, _, group = inputs
validate_num_inputs(self.target, inputs, 9)

valid_input_dtypes = []
if self.tosa_spec.support_float():
valid_input_dtypes.append(ts.DType.FP32)
if self.tosa_spec.support_integer():
valid_input_dtypes.append(ts.DType.INT8)

if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
valid_input_dtypes.append(ts.DType.INT16)
# Check constraints for int16 activations
if inputs[0].dtype == ts.DType.INT16:
validate_valid_dtype(
self.target, [inputs[1]], [ts.DType.INT8], self.tosa_spec
)
validate_valid_dtype(
self.target, [inputs[2]], [ts.DType.INT48], self.tosa_spec
)

validate_valid_dtype(
self.target,
[inputs[0]],
valid_input_dtypes,
self.tosa_spec,
)

# Get the attributes of convolution.
attr = ts.TosaSerializerAttribute()
pad_attr = [val for val in pad.special for _ in (0, 1)]
@@ -97,8 +124,8 @@ )
)

input_zp = 0
if inputs[0].dtype == ts.DType.INT8:
# int8 input requires quantization information
if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16):
# int8 and int16 inputs require quantization information
input_qparams = get_input_qparams(node)
input_zp = input_qparams[0].get_zp_per_tensor()

@@ -109,15 +136,22 @@ def define_node(
weight_zp = input_qparams[1].zp # type: ignore[assignment]

# The output type is int32 when the input type is int8, and int48 when it is int16.
conv2d_output_name = output.name
if output.dtype == ts.DType.INT8:
if inputs[0].dtype == ts.DType.INT8:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)
conv2d_output_name = conv2d_res.name
acc_type = (
inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32
)
acc_type = ts.DType.INT32
elif inputs[0].dtype == ts.DType.INT16:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT48
)
conv2d_output_name = conv2d_res.name
acc_type = ts.DType.INT48
else:
conv2d_output_name = output.name
conv2d_res = output
acc_type = ts.DType.FP32

tosa_graph.addConst(
[1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -207,7 +241,7 @@ def define_node(

# For quantized convolution, rescale the output value back to the same
# integer value domain of the next op. Otherwise return float32 output.
if inputs[0].dtype == ts.DType.INT8:
if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
# Get scale_factor from input, weight, and output.
input_scale = input_qparams[0].get_scale_per_tensor() # type: ignore[possibly-undefined] # pyre-ignore [61]
per_channel_quant = input_qparams[1].per_channel # pyre-ignore [61]
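As a rough, back-of-the-envelope sanity check (not part of the diff) on why the accumulator widens from int32 to int48 for int16 activations:

# Worst-case magnitude of a single int16 x int8 product.
max_product = 32767 * 127                 # 4,161,409

# An int32 accumulator overflows after roughly this many worst-case accumulations ...
print((2**31 - 1) // max_product)         # 516 -- e.g. a 3x3 kernel over 64 channels (576 taps) exceeds it

# ... while an int48 accumulator leaves ample headroom.
print((2**47 - 1) // max_product)         # about 33.8 million worst-case accumulations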
11 changes: 9 additions & 2 deletions backends/arm/process_node.py
@@ -12,7 +12,7 @@
import torch
import torch.fx
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
from executorch.backends.arm.tosa.specification import TosaSpecification
from executorch.backends.arm.tosa.utils import tosa_shape
from torch._export.utils import (
@@ -112,10 +112,17 @@ def process_inputs_to_parameters(
if tosa_arg.dtype == torch.float32:
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

# Handle special case for INT48 tensors
special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
if isinstance(special_type, TosaSpecialDtype):
tosa_dtype = special_type.get_tosa_dtype()
else:
tosa_dtype = tosa_arg.dtype

parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)

tosa_graph.addConst(
parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name
)

