Skip to content

Commit ada91d5

Browse files
committed
Add INT16 support to rescale operation
Pull Request resolved: #13802

Add INT16 support for RequantizeNode rescale operations in the ExecuTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, view/transpose, cat, and FCNode operations, extending int16 support to RequantizeNode rescale operations.

Changes:
- Add INT16 dtype validation support in op_rescale.py
- Enable rescale operations for the 16A8W quantization configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. RequantizeNode rescale operations are essential for proper quantization scaling in the 16A8W pipeline.

ghstack-source-id: 308021436
@exported-using-ghexport

Differential Revision: [D80513725](https://our.internmc.facebook.com/intern/diff/D80513725/)
1 parent 2133199 commit ada91d5

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

backends/arm/operators/op_rescale.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,20 @@ def define_node(
4646
input_zp = cast(int, node.args[3])
4747
output_zp = cast(int, node.args[4])
4848

49-
if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0:
49+
if (
50+
input_dtype
51+
not in [
52+
map_dtype(torch.int8, self.tosa_spec),
53+
map_dtype(torch.int16, self.tosa_spec),
54+
]
55+
and input_zp != 0
56+
):
5057
raise ValueError(
51-
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
58+
f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
5259
)
53-
if output_dtype != torch.int8 and output_zp != 0:
60+
if output_dtype not in [torch.int8, torch.int16] and output_zp != 0:
5461
raise ValueError(
55-
f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
62+
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
5663
)
5764

5865
build_rescale(

0 commit comments

Comments
 (0)