Skip to content

Commit d62d5fb

Browse files
committed
Add INT16 support to rescale operation
Add INT16 support for RequantizeNode rescale operations in the ExecuTorch ARM backend. This follows the pattern established for linear, mul, sigmoid, tanh, slice, view/transpose, cat, and FCNode operations, extending int16 support to RequantizeNode rescale operations.

Changes:
- Add INT16 dtype validation support in op_rescale.py
- Enable rescale operations for the 16A8W quantization configuration

The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. RequantizeNode rescale operations are essential for proper quantization scaling in the 16A8W pipeline.

Differential Revision: [D80513725](https://our.internmc.facebook.com/intern/diff/D80513725/)
ghstack-source-id: 304555411
Pull Request resolved: #13802
1 parent 1fac836 commit d62d5fb

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

backends/arm/operators/op_rescale.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,20 @@ def define_node(
4646
input_zp = cast(int, node.args[3])
4747
output_zp = cast(int, node.args[4])
4848

49-
if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0:
49+
if (
50+
input_dtype
51+
not in [
52+
map_dtype(torch.int8, self.tosa_spec),
53+
map_dtype(torch.int16, self.tosa_spec),
54+
]
55+
and input_zp != 0
56+
):
5057
raise ValueError(
51-
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
58+
f"If input dtype is not int8 or int16, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
5259
)
53-
if output_dtype != torch.int8 and output_zp != 0:
60+
if output_dtype not in [torch.int8, torch.int16] and output_zp != 0:
5461
raise ValueError(
55-
f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
62+
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
5663
)
5764

5865
build_rescale(

0 commit comments

Comments
 (0)