From bc0f1344c44275f875ab22469a68c0e51c3a6e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 14 Aug 2025 19:01:30 +0200 Subject: [PATCH 1/6] Arm backend: Handle 16 bit activation for conv2d MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: I2b189b559f699c7eda6921ed515c0e8a849226ca --- backends/arm/operators/op_conv2d.py | 52 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 6bfe0ab21eb..41e422b5504 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -19,10 +19,11 @@ ) from executorch.backends.arm.operators.operator_validation_utils import ( validate_num_inputs, + validate_valid_dtype, ) -from executorch.backends.arm.tosa import TosaSpecification from executorch.backends.arm.tosa.mapping import TosaArg from executorch.backends.arm.tosa.quant_utils import build_rescale +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape @@ -73,6 +74,32 @@ def define_node( input, weight, bias, stride, pad, dilation, _, _, group = inputs validate_num_inputs(self.target, inputs, 9) + valid_input_dtypes = [] + if self.tosa_spec.support_float(): + valid_input_dtypes.append(ts.DType.FP32) + if self.tosa_spec.support_integer(): + valid_input_dtypes.append(ts.DType.INT8) + + if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension( + "int16" + ): + valid_input_dtypes.append(ts.DType.INT16) + # Check constraints for int16 activations + if inputs[0].dtype == ts.DType.INT16: + validate_valid_dtype( + self.target, [inputs[1]], [ts.DType.INT8], self.tosa_spec + ) + validate_valid_dtype( + self.target, [inputs[2]], [ts.DType.INT48], self.tosa_spec + ) + + validate_valid_dtype( + self.target, + [inputs[0]], + valid_input_dtypes, + self.tosa_spec, + ) + # Get the attributes of convolution. attr = ts.TosaSerializerAttribute() pad_attr = [val for val in pad.special for _ in (0, 1)] @@ -97,8 +124,8 @@ def define_node( ) input_zp = 0 - if inputs[0].dtype == ts.DType.INT8: - # int8 input requires quantization information + if inputs[0].dtype in (ts.DType.INT8, ts.DType.INT16): + # int8 and int16 input requires quantization information input_qparams = get_input_qparams(node) input_zp = input_qparams[0].get_zp_per_tensor() @@ -109,15 +136,22 @@ def define_node( weight_zp = input_qparams[1].zp # type: ignore[assignment] # The output type is int32 when input type is int8. 
-        conv2d_output_name = output.name
-        if output.dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8:
             conv2d_res = tosa_graph.addIntermediate(
                 tosa_shape(output.shape, output.dim_order), ts.DType.INT32
             )
             conv2d_output_name = conv2d_res.name
-        acc_type = (
-            inputs[0].dtype if inputs[0].dtype == ts.DType.FP32 else ts.DType.INT32
-        )
+            acc_type = ts.DType.INT32
+        elif inputs[0].dtype == ts.DType.INT16:
+            conv2d_res = tosa_graph.addIntermediate(
+                tosa_shape(output.shape, output.dim_order), ts.DType.INT48
+            )
+            conv2d_output_name = conv2d_res.name
+            acc_type = ts.DType.INT48
+        else:
+            conv2d_output_name = output.name
+            conv2d_res = output
+            acc_type = ts.DType.FP32
 
         tosa_graph.addConst(
             [1], output.dtype, [input_zp], name=f"{conv2d_output_name}_input_zp"
@@ -207,7 +241,7 @@ def define_node(
 
         # For quantized convolution, rescale the output value back to the same
         # integer value domain of the next op. Otherwise return float32 output.
-        if inputs[0].dtype == ts.DType.INT8:
+        if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16:
             # Get scale_factor from input, weight, and output.
             input_scale = input_qparams[0].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore [61]
             per_channel_quant = input_qparams[1].per_channel  # pyre-ignore [61]

From 60033f9ff59ed1c8961b58bc81b06247782b44e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Per=20=C3=85strand?=
Date: Thu, 27 Feb 2025 15:50:42 +0100
Subject: [PATCH 2/6] Arm backend: Decompose conv2d with 16 bit activation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Support quantization to 16a8w. Since the resulting TOSA operator needs
the bias to be int48, which isn't available as a dtype in torch, the
conv2d is decomposed into a conv + add, where the conv result is
rescaled to 32 bits before the bias is added.
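
As a rough standalone illustration of the numerics (not part of the
change itself; the function name, the use of plain torch ops and the
double-precision rounding below are stand-ins for the backend's
RESCALE lowering):

    import torch

    def add_bias_16a8w(conv_q, bias_q, conv_scale, bias_scale):
        # conv_q: NCHW conv output requantized to int16 with scale
        # conv_scale (also the final output scale); bias_q: per-channel
        # bias quantized to int16 with scale bias_scale.
        common = max(bias_scale, conv_scale)
        shift = 14  # of the 16 extra int32 bits, keep one for overflow and one for sign
        # Bring both operands to the shared scale common / 2**shift.
        conv_i32 = torch.round(conv_q.double() * (conv_scale / common) * (1 << shift))
        bias_i32 = torch.round(bias_q.double() * (bias_scale / common) * (1 << shift))
        acc = conv_i32 + bias_i32.view(1, -1, 1, 1)  # int32-range sum, bias per channel
        # Rescale the sum back to the int16 output domain.
        out = torch.round(acc * (common / (conv_scale * (1 << shift))))
        return out.clamp(-32768, 32767).to(torch.int16)

For example, with conv_scale=0.01 and bias_scale=0.02 the common scale
is 0.02, so the bias term gets the full headroom factor 2^14 while the
conv term is scaled by 2^13.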
Signed-off-by: Per Åstrand Change-Id: Ib8cae694035796374a55a9909e501596e983abf5 --- backends/arm/_passes/__init__.py | 3 + backends/arm/_passes/arm_pass_manager.py | 9 +- .../decompose_int16_activation_conv2d_pass.py | 145 ++++++++++++++++++ backends/arm/quantizer/quantization_config.py | 63 +++++--- 4 files changed, 197 insertions(+), 23 deletions(-) create mode 100644 backends/arm/_passes/decompose_int16_activation_conv2d_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index f9e23f73cc5..a5d8e17f0cd 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -46,6 +46,9 @@ from .decompose_glu_pass import DecomposeGluPass # noqa from .decompose_grouped_conv import DecomposeGroupedConv # noqa from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa +from .decompose_int16_activation_conv2d_pass import ( # noqa + DecomposeConv2dWithInt16ActivationPass, +) from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index c6530357f3b..70470890317 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -42,6 +42,7 @@ DecomposeAtanPass, DecomposeAvgPool2d, DecomposeBatchNormNoStatsPass, + DecomposeConv2dWithInt16ActivationPass, DecomposeCoshPass, DecomposeCosineSimilarityPass, DecomposeCumsumPass, @@ -183,6 +184,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(ComputeConstantOpsAOT(exported_program)) self.add_pass(DecomposeGroupedConv()) + self.add_pass(ConvertExpandCopyToRepeatPass()) self.add_pass(UnsqueezeBeforeRepeatPass()) self.add_pass(CastInt64BuffersToInt32Pass(exported_program)) @@ -196,9 +198,14 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule: self.add_pass(FuseViewCopyTransform()) self.add_pass(FuseConstantArgsPass(exported_program)) + self.add_pass(InsertTableOpsPass(exported_program)) + # If we have a conv2d with int16 activation split up into a convolution + # and an addition, to work-around the lack of support for int48 in torch + # needs to happen before AddBiasPass, but after the table ops are inserted + # to be able to validate that conv2d has right dtype arguments. + self.add_pass(DecomposeConv2dWithInt16ActivationPass()) self.add_pass(AddBiasPass(exported_program)) - self.add_pass(InsertTableOpsPass(exported_program)) self.add_pass(FuseEqualPlaceholdersPass(exported_program)) self.add_pass(ToTosaMemoryFormatPass(exported_program)) self.add_pass(RemoveNoopPass()) diff --git a/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py new file mode 100644 index 00000000000..d43c2a8c89c --- /dev/null +++ b/backends/arm/_passes/decompose_int16_activation_conv2d_pass.py @@ -0,0 +1,145 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +from typing import cast + +import torch +from executorch.backends.arm._passes.quant_args import QuantArgs + +from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00 +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeConv2dWithInt16ActivationPass(ExportPass): + """ + This pass decomposes a convolution with input dtype int16 and bias + into a convolution without bias followed by an addition of the bias + since the TOSA op requires the bias to be int48 which is hard to represent + in torch. Instead rescale the int48 output to int16 and add the bias in int16. + """ + + def call_operator(self, op, args, kwargs, meta): + if op != exir_ops.edge.aten.convolution.default: + return super().call_operator(op, args, kwargs, meta) + + tosa_spec = get_context_spec() + if not tosa_spec.support_integer(): + return super().call_operator(op, args, kwargs, meta) + + # return if no bias + if args[2] is None: + return super().call_operator(op, args, kwargs, meta) + + if args[0].data.dtype == torch.int8: + return super().call_operator(op, args, kwargs, meta) + elif args[0].data.dtype == torch.int16: + if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension( + "int16" + ): + raise ValueError( + "int16 activation for convolution requires TOSA int16 extension" + ) + else: + raise NotImplementedError( + "Decomposition to conv+add only implemented for activation of int16 type" + ) + + # convolution with bias and activation is int16 + # The bias is assumed to be quantized with the same quantization parameters as + # as the output of the convolution + bias = args[2] + assert ( + meta.data["output_qparams"][0].dtype == bias.data.dtype + ), "Bias needs to have same type as quantized output type" + no_bias_args = list(args) + no_bias_args[2] = None + # split up to convolution + bias + convolution = super().call_operator(op, tuple(no_bias_args), kwargs, meta) + + # create a copy of the meta without the qparams, to be used with the new nodes + new_meta = meta.copy() + new_meta.data.pop("output_qparams", None) + new_meta.data.pop("input_qparams", None) + + # reshape the tensor to the same rank as the convolution output to add the bias to the channels + channel_bias = super().call_operator( + exir_ops.edge.aten.view_copy.default, + (bias, [1, len(bias.data), 1, 1]), + {}, + new_meta, + ) + + output_dtype = meta.data["output_qparams"][0].dtype + + if output_dtype == torch.int16: + # The conv will get the output int48 scaled to int32 in serialization step. + # To be able to add the bias we need to first scale (cast?) the output to int32. + # The resulting i32 sum will then need to be scaled back to the output dtype. + + # calculate common rescale factor from convolution output and bias quantization + output_qparams = cast(QuantArgs, meta.data["output_qparams"][0]) + conv_output_scale = output_qparams.scale + bias_qparams = cast(QuantArgs, meta.data["input_qparams"][2]) + bias_scale = bias_qparams.scale + + common_scale = max(bias_scale, conv_output_scale) + + # calculate how we can rescale bias and conv to a common scale and maximize the output range + bias_rescale_factor = bias_scale / common_scale + conv_rescale_factor = conv_output_scale / common_scale + + # Either of conv output or bias now covers the full int16 range and the other one a smaller range. + # Since we are upscaling to int32 we have 16 additional bits to work with to maximize the output range. 
+ # Worst case here is that both bias and conv output covers the full int16 range so we leave one bit + # and then one for the sign bit. + bits_left_to_shift = 14 + + # update rescale factors + bias_rescale_factor *= 1 << bits_left_to_shift + conv_rescale_factor *= 1 << bits_left_to_shift + + conv_output = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + (convolution, torch.int32, conv_rescale_factor, 0, 0), + {}, + new_meta, + ) + + bias_rescaled = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + (channel_bias, torch.int32, bias_rescale_factor, 0, 0), + {}, + new_meta, + ) + + add = super().call_operator( + exir_ops.edge.aten.add.Tensor, + (conv_output, bias_rescaled), + {}, + new_meta, + ) + + res_rescale = super().call_operator( + exir_ops.backend.tosa.RESCALE.default, + ( + add, + output_dtype, + (common_scale / (conv_output_scale * (1 << bits_left_to_shift))), + 0, + 0, + ), + {}, + new_meta, + ) + + else: + raise NotImplementedError( + f"Decomposition to conv+add only implemented for activation of int16 type, not for {output_dtype}" + ) + + return res_rescale diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index d5c3aab1060..29af10dfd1d 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -89,29 +89,48 @@ def _derive_qparams_fn( torch.ops.aten.linear.default, torch.ops.aten.conv2d.padding, ]: - input_act = node.args[0] - weight = node.args[1] - # If the weights are quantized per_tensor, do the same with bias - qscheme = ( - torch.per_tensor_symmetric - if self.weight is None - else self.weight.qscheme - ) - ch_axis = None - if self.weight is not None: - if qscheme == torch.per_channel_symmetric: - ch_axis = self.weight.ch_axis + if self.input_activation is None or self.weight is None: + raise ValueError( + "Input activation and weight QuantizationConfig must be specified." + ) + if self.input_activation.dtype == self.weight.dtype == torch.int8: + # This is the default int8 quantization which uses the derived quantization + # calculated from the activation and weight scale + input_act = node.args[0] + weight = node.args[1] - quantization_spec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] - derive_qparams_fn=_derive_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max - 1, - qscheme=qscheme, - ch_axis=ch_axis, - ) - return quantization_spec # type: ignore[return-value] + # If the weights are quantized per_tensor, do the same with bias + qscheme = ( + torch.per_tensor_symmetric + if self.weight is None + else self.weight.qscheme + ) + ch_axis = None + if self.weight is not None: + if qscheme == torch.per_channel_symmetric: + ch_axis = self.weight.ch_axis + + quantization_spec = DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] + derive_qparams_fn=_derive_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max - 1, + qscheme=qscheme, + ch_axis=ch_axis, + ) + return quantization_spec # type: ignore[return-value] + elif ( + self.input_activation.dtype == torch.int16 + and self.weight.dtype == torch.int8 + ): + # In case the activation is quantized to int16, the bias needs to be + # added after the convolution, so use the output quantization for this case. 
+ return self.output_activation + else: + raise NotImplementedError( + f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented" + ) if self.bias is None: return None From edc1c42a5e040b274a73e57e79ba352f0b786342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 14 Aug 2025 19:00:20 +0200 Subject: [PATCH 3/6] Arm backend: Handle i48 special case for bias tensor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For the case when the activation is 16 bit the bias in TOSA must be a int48_t tensor. Since that can't be represented using torch.dtypes the corresponding node.meta is set with a key 'tosa_dtype_48bit' to pass through the note to the creation of the TOSA Tensor. Also make sure to distinguish between int32 and int48 tensors in fuse constant ops pass. Signed-off-by: Per Åstrand Change-Id: Iefe64f2b02f388c905c9c818ee7d2a6af40bc9e3 --- backends/arm/_passes/add_bias_pass.py | 2 ++ backends/arm/_passes/fuse_equal_placeholders_pass.py | 5 +++++ backends/arm/process_node.py | 8 +++++++- backends/arm/tosa/mapping.py | 8 +++++++- 4 files changed, 21 insertions(+), 2 deletions(-) diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index a8a76c0a47b..83f91351557 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -59,6 +59,8 @@ def call(self, graph_module): persistent_buffer=True, name=f"{node.name}_bias", ) + if node.args[0].meta["val"].dtype == torch.int16: + bias_node.meta["tosa_dtype_48bit"] = True node.update_arg(2, bias_node) if modified: diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index cf1177a0448..8263798c99e 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -47,9 +47,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: continue # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes # Ensure tensor is on CPU and contiguous + + # ensure we don't merge any special case int48_t tensors with int32_t tensors + # since int48_t tensors needs to be instantiated separately. 
+ is_int48 = node.meta.get("tosa_dtype_48bit", False) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( + is_int48, str(t_cpu.dtype), tuple(t_cpu.shape), hashlib.sha1(data_bytes).hexdigest(), diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 5093ea32d4c..3a9e92e1a78 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -112,10 +112,16 @@ def process_inputs_to_parameters( if tosa_arg.dtype == torch.float32: assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" + # Handle special case for INT48 tensors + if node.meta.get("tosa_dtype_48bit", False): + tosa_dtype = ts.DType.INT48 + else: + tosa_dtype = tosa_arg.dtype + parameter_values = np.transpose(parameter_values, tosa_arg.dim_order) tosa_graph.addConst( - parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name + parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name ) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 935d9f8da77..f57f057ff10 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -130,10 +130,16 @@ def __process_node(self, argument: torch.fx.Node): """ self.name: str = argument.name - self.dtype, self.shape, self.dim_order = extract_tensor_meta( + output_dtype, self.shape, self.dim_order = extract_tensor_meta( argument.meta, self.tosa_spec ) + # Handle special case of int + if argument.meta.get("tosa_dtype_48bit", False): + output_dtype = ts.DType.INT48 + + self.dtype = output_dtype + def __process_list(self, argument): """Capture a sequence argument as ``special``. From 2bd09f9a6d0e2ee48214ef04559f223fbf2d6faa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Thu, 14 Aug 2025 19:11:18 +0200 Subject: [PATCH 4/6] Arm backend: Fix mult and scale calculation for int48_t MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Per Åstrand Change-Id: Ibe158d8d35a632547290f1b9a055d061ae267d77 --- backends/arm/tosa/quant_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/backends/arm/tosa/quant_utils.py b/backends/arm/tosa/quant_utils.py index 027c26fc20a..68ceec8d97c 100644 --- a/backends/arm/tosa/quant_utils.py +++ b/backends/arm/tosa/quant_utils.py @@ -268,6 +268,9 @@ def compute_multiplier_and_shift( if shift > 62: multiplier = multiplier >> min(31, shift - 62) shift = 62 + + assert multiplier >= 0, "Multiplier should be non-negative" + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" multipliers.append(multiplier) shifts.append(shift) return multipliers, shifts @@ -322,8 +325,8 @@ def build_rescale( import tosa.Op as TosaOp # type: ignore - scaleWidth = 32 - is_scale32 = True + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth) rescale_inputs = create_const_ops_for_rescale( tosa_fb, From aafcede019fc8d2d83dcfe5c0a870fcedcd756a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Tue, 26 Aug 2025 09:18:38 +0200 Subject: [PATCH 5/6] Arm backend: Enable linear 16a8w tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable tests of int16 activations and int8 weight quantization. Test for large_rand is disabled to sort out why the test is flaky. 
Signed-off-by: Per Åstrand Change-Id: I9de5d472f8862edebcf82c140399985db930c069 --- backends/arm/scripts/parse_test_names.py | 3 +++ backends/arm/test/ops/test_linear.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index c6eaafa597b..2629d8eb257 100644 --- a/backends/arm/scripts/parse_test_names.py +++ b/backends/arm/scripts/parse_test_names.py @@ -95,6 +95,9 @@ def parse_test_name( op = op.removesuffix("_1d") op = op.removesuffix("_2d") + # Remove suffix for 16 bit activation and 8 bit weight test cases + op = op.removesuffix("_16a8w") + assert target != "None", f"{test_name} does not contain one of {TARGETS}" assert ( op in op_name_map.keys() diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index f9aa4f14048..ebc2ead8a83 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -277,10 +277,14 @@ def get_symmetric_a16w8_linear_quantizer( ) -@common.parametrize("test_data", test_data_rank1_INT | test_data_rank4_INT) -@pytest.mark.xfail( - reason="missing int16 linear ops support; fails at TOSA reference model run with Invalid TOSA graph" -) +test_data_all_16a8w = test_data_rank1_INT | test_data_rank4_INT +# TODO: Remove large rand test as they are flaky until sorted out why: MLETORCH-1377 +for k in list(test_data_all_16a8w.keys()): + if "large_rand" in k: + test_data_all_16a8w.pop(k) + + +@common.parametrize("test_data", test_data_all_16a8w) def test_linear_16a8w_tosa_INT(test_data: torch.Tensor): """Test linear operation with 16A8W quantization (16-bit activations, 8-bit weights)""" test_data, out_features, has_bias, per_channel_quantization = test_data() From 18c9985be3ef9aa6c66f882e8adccb38a59a4c93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 15 Sep 2025 23:03:09 +0200 Subject: [PATCH 6/6] Arm backend: Add special dtype TOSA handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a enum class to handle special dtypes that can't be represented in torch (i.e. int48_t) to avoid leaking serializer types into the pass handling of the backend. 
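
A minimal sketch of the intended round trip (the two helper names are
hypothetical; the enum, its meta key and the serializer mapping are the
ones introduced in this patch):

    import torch
    import torch.fx
    from executorch.backends.arm.tosa.mapping import TosaSpecialDtype

    def tag_bias_as_int48(node: torch.fx.Node) -> None:
        # A pass marks the constant placeholder; torch has no int48
        # dtype, so the marker lives in node.meta under the enum's key.
        node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.INT48

    def resolve_tosa_dtype(node: torch.fx.Node, default_dtype):
        # Serialization reads the marker back and maps it to
        # ts.DType.INT48, falling back to the dtype derived from the
        # torch tensor meta.
        special = node.meta.get(TosaSpecialDtype.meta_key(), None)
        return special.get_tosa_dtype() if special is not None else default_dtype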
Signed-off-by: Per Åstrand Change-Id: I3388cec3c8a26f28790eedc3f124c336b6724cb4 --- backends/arm/_passes/add_bias_pass.py | 5 +++- .../_passes/fuse_equal_placeholders_pass.py | 4 +++- backends/arm/process_node.py | 7 +++--- backends/arm/tosa/mapping.py | 23 ++++++++++++++++--- 4 files changed, 31 insertions(+), 8 deletions(-) diff --git a/backends/arm/_passes/add_bias_pass.py b/backends/arm/_passes/add_bias_pass.py index 83f91351557..fd5476f51b8 100644 --- a/backends/arm/_passes/add_bias_pass.py +++ b/backends/arm/_passes/add_bias_pass.py @@ -8,6 +8,7 @@ import torch from executorch.backends.arm._passes import ArmPass from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import create_constant_placeholder from executorch.exir.dialects._ops import ops as exir_ops @@ -60,7 +61,9 @@ def call(self, graph_module): name=f"{node.name}_bias", ) if node.args[0].meta["val"].dtype == torch.int16: - bias_node.meta["tosa_dtype_48bit"] = True + bias_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) node.update_arg(2, bias_node) if modified: diff --git a/backends/arm/_passes/fuse_equal_placeholders_pass.py b/backends/arm/_passes/fuse_equal_placeholders_pass.py index 8263798c99e..b8b8143e6c5 100644 --- a/backends/arm/_passes/fuse_equal_placeholders_pass.py +++ b/backends/arm/_passes/fuse_equal_placeholders_pass.py @@ -8,11 +8,13 @@ from typing import Set, Type import torch + from executorch.backends.arm._passes.arm_pass_utils import ( get_constant_placeholder_kind, get_param_tensor, is_param_node, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.transforms.utils import ( create_constant_placeholder, delete_constant_placeholder, @@ -50,7 +52,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # ensure we don't merge any special case int48_t tensors with int32_t tensors # since int48_t tensors needs to be instantiated separately. 
- is_int48 = node.meta.get("tosa_dtype_48bit", False) + is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None) t_cpu = tensor.detach().cpu().contiguous() data_bytes = t_cpu.numpy().tobytes() key = ( diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index 3a9e92e1a78..50257bc9180 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -12,7 +12,7 @@ import torch import torch.fx from executorch.backends.arm.operators.node_visitor import NodeVisitor -from executorch.backends.arm.tosa.mapping import TosaArg +from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype from executorch.backends.arm.tosa.specification import TosaSpecification from executorch.backends.arm.tosa.utils import tosa_shape from torch._export.utils import ( @@ -113,8 +113,9 @@ def process_inputs_to_parameters( assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float" # Handle special case for INT48 tensors - if node.meta.get("tosa_dtype_48bit", False): - tosa_dtype = ts.DType.INT48 + special_type = node.meta.get(TosaSpecialDtype.meta_key(), None) + if isinstance(special_type, TosaSpecialDtype): + tosa_dtype = special_type.get_tosa_dtype() else: tosa_dtype = tosa_arg.dtype diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index f57f057ff10..64e4ae96e08 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -11,6 +11,7 @@ """ +from enum import Enum from typing import Any, Optional, Sequence import serializer.tosa_serializer as ts # type: ignore @@ -31,6 +32,22 @@ ) +class TosaSpecialDtype(Enum): + """ + Special TOSA data types that are not natively supported in PyTorch, to be + used in specific scenarios as a value in the key from meta_key(). + """ + + INT48 = ts.DType.INT48 + + def get_tosa_dtype(self) -> ts.TosaDType.DType: + return self.value + + @staticmethod + def meta_key() -> str: + return "tosa_special_dtype" + + def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any: """Map a ``torch.dtype`` to a ``ts.DType``. @@ -134,9 +151,9 @@ def __process_node(self, argument: torch.fx.Node): argument.meta, self.tosa_spec ) - # Handle special case of int - if argument.meta.get("tosa_dtype_48bit", False): - output_dtype = ts.DType.INT48 + # Handle special case of types not representable in torch (i.e. i48_t) + if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None): + output_dtype = special_type.get_tosa_dtype() self.dtype = output_dtype