Skip to content

Commit 8b65cf0

Browse files
committed
Add 16A8W support and test for tanh operation
Pull Request resolved: #13797 Add 16A8W quantization support and test for the tanh operation in ExecutorTorch ARM backend. This follows the pattern established for linear, mul, and sigmoid operations, extending int16 support to tanh operations. Changes: - Add INT16 dtype validation support in op_tanh.py - Add test_tanh_tensor_16a8w_tosa_INT test function - Enable test_tanh.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 308986671 @exported-using-ghexport Differential Revision: [D80510815](https://our.internmc.facebook.com/intern/diff/D80510815/)
1 parent b8320ed commit 8b65cf0

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

backends/arm/test/ops/test_tanh.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,23 @@
66

77
from typing import Tuple
88

9+
import pytest
910
import torch
11+
from executorch.backends.arm.quantizer.arm_quantizer import (
12+
get_symmetric_a16w8_quantization_config,
13+
TOSAQuantizer,
14+
)
1015

11-
from executorch.backends.arm.test import common
16+
from executorch.backends.arm.test import common, conftest
1217
from executorch.backends.arm.test.tester.test_pipeline import (
1318
EthosU55PipelineINT,
1419
EthosU85PipelineINT,
1520
TosaPipelineFP,
1621
TosaPipelineINT,
1722
VgfPipeline,
1823
)
24+
from executorch.backends.arm.tosa.specification import TosaSpecification
25+
from executorch.backends.xnnpack.test.tester import Quantize
1926

2027
aten_op = "torch.ops.aten.tanh.default"
2128
input_t1 = Tuple[torch.Tensor] # Input x
@@ -105,3 +112,105 @@ def test_tanh_vgf_INT(test_data: Tuple):
105112
tosa_version="TOSA-1.0+INT",
106113
)
107114
pipeline.run()
115+
116+
117+
def get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False):
    """Build a ``Quantize`` test stage for 16A8W tanh (16-bit activations,
    8-bit weights) symmetric quantization.

    Args:
        per_channel_quantization: forwarded to the quantization config;
            False (the default) keeps per-tensor quantization.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` whose global
        config is the symmetric a16w8 config.

    Raises:
        KeyError: if the configured ``tosa_version`` has no int16-capable
            TOSA profile registered here (only "1.0" is supported).
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and reuse it for both the quantizer's global
    # setting and the Quantize stage, so the two cannot drift apart (the
    # original constructed two independent config objects).
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
134+
135+
136+
@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
    reason="missing int16 tanh ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13975"
)
def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
    """Run tanh through the TOSA INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    # Per-tensor quantization only; per-channel is not exercised here.
    pipeline = TosaPipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_op=[],
        per_channel_quantization=False,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Replace the pipeline's default quantize stage with the a16w8 one.
    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=False),
    )
    pipeline.run()
161+
162+
163+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
)
def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
    """Run tanh on the Ethos-U55 FVP with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    per_channel = False  # per-tensor quantization only

    pipeline = EthosU55PipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap in the a16w8 quantizer before running the pipeline.
    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=per_channel),
    )
    pipeline.run()
189+
190+
191+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 tanh operations"
)
def test_tanh_16a8w_u85_INT16(test_data: torch.Tensor):
    """Run tanh on the Ethos-U85 FVP with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    per_channel = False  # per-tensor quantization only

    pipeline = EthosU85PipelineINT[input_t1](
        Tanh(),
        (test_data(),),
        aten_op,
        exir_ops=[],
        per_channel_quantization=per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap in the a16w8 quantizer before running the pipeline.
    pipeline.change_args(
        "quantize",
        get_symmetric_a16w8_tanh_quantizer(per_channel_quantization=per_channel),
    )
    pipeline.run()

0 commit comments

Comments
 (0)