Skip to content

Commit b8320ed

Browse files
committed
Add 16A8W support and test for sigmoid operation
Pull Request resolved: #13796 Add 16A8W quantization support and test for the sigmoid operation in ExecutorTorch ARM backend. This follows the pattern established for linear and mul operations, extending int16 support to sigmoid operations. Changes: - Add INT16 dtype validation support in op_sigmoid.py - Add test_sigmoid_tensor_16a8w_tosa_INT test function - Enable test_sigmoid.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 308986667 @exported-using-ghexport Differential Revision: [D80510729](https://our.internmc.facebook.com/intern/diff/D80510729/)
1 parent 1d37845 commit b8320ed

File tree

1 file changed

+110
-1
lines changed

1 file changed

+110
-1
lines changed

backends/arm/test/ops/test_sigmoid.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,22 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
12-
from executorch.backends.arm.test import common
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
17+
from executorch.backends.arm.test import common, conftest
1318
from executorch.backends.arm.test.tester.test_pipeline import (
1419
EthosU55PipelineINT,
1520
EthosU85PipelineINT,
1621
TosaPipelineFP,
1722
TosaPipelineINT,
1823
VgfPipeline,
1924
)
25+
from executorch.backends.arm.tosa.specification import TosaSpecification
26+
from executorch.backends.xnnpack.test.tester import Quantize
2027

2128
aten_op = "torch.ops.aten.sigmoid.default"  # Used for checking that we do not have sigmoid in the graph after decompose
2229
exir_op = "executorch_exir_dialects_edge__ops_aten_sigmoid_default"
@@ -253,3 +260,105 @@ def test_sigmoid_vgf_INT_add_3():
253260
tosa_version="TOSA-1.0+INT",
254261
)
255262
pipeline.run()
263+
264+
265+
def get_symmetric_a16w8_sigmoid_quantizer(per_channel_quantization=False):
    """Build a Quantize stage configured for 16-bit activations / 8-bit weights.

    Args:
        per_channel_quantization: Whether weights are quantized per-channel.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` whose global config
        is the symmetric a16w8 quantization config.

    Raises:
        KeyError: If the ``tosa_version`` test option is not a version mapped
            below (only "1.0" is supported here).
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and reuse it for both the quantizer's global
    # setting and the Quantize stage (previously constructed twice).
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
282+
283+
284+
@common.parametrize("test_data", test_data_suite)
@pytest.mark.xfail(
    reason="missing int16 sigmoid ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13974"
)
def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
    """Test sigmoid operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
    use_per_channel = False

    # TOSA INT pipeline with the int16 extension enabled for the activations.
    pipeline = TosaPipelineINT[input_t1](
        Sigmoid(),
        (test_data(),),
        aten_op,
        exir_op=[],
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Swap in the 16A8W quantize stage before running the pipeline.
    quantize_stage = get_symmetric_a16w8_sigmoid_quantizer(
        per_channel_quantization=use_per_channel
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()
309+
310+
311+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
)
def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
    """Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
    use_per_channel = False

    # Ethos-U55 pipeline, executed on the Corstone-300 FVP.
    pipeline = EthosU55PipelineINT[input_t1](
        Sigmoid(),
        (test_data(),),
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap in the 16A8W quantize stage before running the pipeline.
    quantize_stage = get_symmetric_a16w8_sigmoid_quantizer(
        per_channel_quantization=use_per_channel
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()
337+
338+
339+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
@pytest.mark.xfail(
    reason="Vela compilation fails with 'Invalid arguments' for int16 sigmoid operations"
)
def test_sigmoid_16a8w_u85_INT16(test_data: torch.Tensor):
    """Test sigmoid operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
    use_per_channel = False

    # Ethos-U85 pipeline, executed on the Corstone-320 FVP.
    pipeline = EthosU85PipelineINT[input_t1](
        Sigmoid(),
        (test_data(),),
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        run_on_fvp=True,
    )

    # Swap in the 16A8W quantize stage before running the pipeline.
    quantize_stage = get_symmetric_a16w8_sigmoid_quantizer(
        per_channel_quantization=use_per_channel
    )
    pipeline.change_args("quantize", quantize_stage)
    pipeline.run()

0 commit comments

Comments
 (0)