Skip to content

Commit 683984b

Browse files
committed
Add 16A8W support and test for mul operation
Pull Request resolved: #13795 Add 16A8W quantization support and test for the mul operation in ExecutorTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations. Changes: - Add INT16 dtype validation support in op_mul.py - Add test_mul_tensor_16a8w_tosa_INT test function - Enable test_mul.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. ghstack-source-id: 307682119 @exported-using-ghexport Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/)
1 parent 3533b59 commit 683984b

File tree

3 files changed

+111
-4
lines changed

3 files changed

+111
-4
lines changed

backends/arm/operators/op_mul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151
validate_valid_dtype(
5252
self.target,
5353
[*inputs, output],
54-
[ts.DType.INT8, ts.DType.INT32],
54+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
5555
output.tosa_spec,
5656
)
5757

@@ -80,15 +80,15 @@ def define_node(
8080
tosa_spec=self.tosa_spec,
8181
)
8282
else:
83-
# input[0].dtype == ts.DType.INT32
83+
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
8484
# Non-quantized input, natively supported by TOSA.MUL
8585
input_A_rescaled, input_B_rescaled = inputs[0], inputs[1]
8686

8787
if output.dtype == ts.DType.INT8:
8888
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
8989
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
9090
else:
91-
# output.dtype == ts.DType.INT32
91+
# output.dtype == ts.DType.INT16 or ts.DType.INT32
9292
mul_output = output
9393

9494
# Do the INT32 Mul

backends/arm/test/ops/test_mul.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,23 @@
88

99
from typing import Tuple
1010

11+
import pytest
1112
import torch
13+
from executorch.backends.arm.quantizer.arm_quantizer import (
14+
get_symmetric_a16w8_quantization_config,
15+
TOSAQuantizer,
16+
)
1217

13-
from executorch.backends.arm.test import common
18+
from executorch.backends.arm.test import common, conftest
1419
from executorch.backends.arm.test.tester.test_pipeline import (
1520
EthosU55PipelineINT,
1621
EthosU85PipelineINT,
1722
TosaPipelineFP,
1823
TosaPipelineINT,
1924
VgfPipeline,
2025
)
26+
from executorch.backends.arm.tosa_specification import TosaSpecification
27+
from executorch.backends.xnnpack.test.tester import Quantize
2128

2229
input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x
2330
aten_op = "torch.ops.aten.mul.Tensor"
@@ -284,3 +291,102 @@ def test_mul_tensor_vgf_INT_int32(test_data: torch.Tensor):
284291
)
285292
pipeline.pop_stage("check.quant_nodes")
286293
pipeline.run()
294+
295+
296+
def get_symmetric_a16w8_mul_quantizer(per_channel_quantization=False):
297+
tosa_version = conftest.get_option("tosa_version")
298+
tosa_profiles = {
299+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
300+
}
301+
302+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
303+
quantizer.set_global(
304+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
305+
)
306+
307+
return Quantize(
308+
quantizer,
309+
get_symmetric_a16w8_quantization_config(
310+
is_per_channel=per_channel_quantization
311+
),
312+
)
313+
314+
315+
@common.parametrize("test_data", test_data_suite)
316+
def test_mul_tensor_16a8w_tosa_INT(test_data: input_t1):
317+
"""Test mul operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
318+
per_channel_quantization = False
319+
320+
pipeline = TosaPipelineINT[input_t1](
321+
Mul(),
322+
test_data(),
323+
aten_op,
324+
exir_op=[],
325+
per_channel_quantization=per_channel_quantization,
326+
use_to_edge_transform_and_lower=True,
327+
tosa_extensions=["int16"],
328+
)
329+
330+
pipeline.change_args(
331+
"quantize",
332+
get_symmetric_a16w8_mul_quantizer(
333+
per_channel_quantization=per_channel_quantization
334+
),
335+
)
336+
pipeline.run()
337+
338+
339+
@common.parametrize("test_data", test_data_suite)
340+
@common.XfailIfNoCorstone300
341+
@pytest.mark.xfail(
342+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
343+
)
344+
def test_mul_tensor_16a8w_u55_INT16(test_data: input_t1):
345+
"""Test mul operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
346+
per_channel_quantization = False
347+
348+
pipeline = EthosU55PipelineINT[input_t1](
349+
Mul(),
350+
test_data(),
351+
aten_op,
352+
exir_ops=[],
353+
per_channel_quantization=per_channel_quantization,
354+
use_to_edge_transform_and_lower=True,
355+
run_on_fvp=True,
356+
)
357+
358+
pipeline.change_args(
359+
"quantize",
360+
get_symmetric_a16w8_mul_quantizer(
361+
per_channel_quantization=per_channel_quantization
362+
),
363+
)
364+
pipeline.run()
365+
366+
367+
@common.parametrize("test_data", test_data_suite)
368+
@common.XfailIfNoCorstone320
369+
@pytest.mark.xfail(
370+
reason="Vela compilation fails with 'Invalid arguments' for int16 mul operations. See: https://github.com/pytorch/executorch/issues/13947"
371+
)
372+
def test_mul_tensor_16a8w_u85_INT16(test_data: input_t1):
373+
"""Test mul operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
374+
per_channel_quantization = False
375+
376+
pipeline = EthosU85PipelineINT[input_t1](
377+
Mul(),
378+
test_data(),
379+
aten_op,
380+
exir_ops=[],
381+
per_channel_quantization=per_channel_quantization,
382+
use_to_edge_transform_and_lower=True,
383+
run_on_fvp=True,
384+
)
385+
386+
pipeline.change_args(
387+
"quantize",
388+
get_symmetric_a16w8_mul_quantizer(
389+
per_channel_quantization=per_channel_quantization
390+
),
391+
)
392+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def define_arm_tests():
1616
"ops/test_add.py",
1717
"ops/test_avg_pool2d.py",
1818
"ops/test_linear.py",
19+
"ops/test_mul.py",
1920
"ops/test_slice.py",
2021
"ops/test_sigmoid.py",
2122
"ops/test_tanh.py",

0 commit comments

Comments
 (0)