Skip to content

Commit e9030db

Browse files
3l1 and facebook-github-bot
authored and committed
Enable int16 for op permute (#15256)
Summary: Enable int16 for op permute. Reviewed By: Ninja91, digantdesai. Differential Revision: D84948536.
1 parent 5d71c9b commit e9030db

File tree

3 files changed

+115
-4
lines changed

3 files changed

+115
-4
lines changed

backends/arm/operators/op_permute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def define_node(
117117
validate_valid_dtype(
118118
self.target,
119119
[inputs[0], output],
120-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
120+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
121121
output.tosa_spec,
122122
)
123123

backends/arm/test/ops/test_permute.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

1318
from executorch.backends.arm.test import common
1419

@@ -19,7 +24,8 @@
1924
TosaPipelineINT,
2025
VgfPipeline,
2126
)
22-
from torchvision.ops import Permute
27+
from executorch.backends.arm.tosa import TosaSpecification
28+
from executorch.backends.xnnpack.test.tester import Quantize
2329

2430
input_t1 = Tuple[torch.Tensor] # Input x
2531

@@ -42,10 +48,10 @@ class SimplePermute(torch.nn.Module):
4248
def __init__(self, dims: list[int]):
4349
super().__init__()
4450

45-
self.permute = Permute(dims=dims)
51+
self.dims = dims
4652

4753
def forward(self, x):
48-
return self.permute(x)
54+
return torch.permute(x, self.dims)
4955

5056

5157
@common.parametrize("test_data", test_data_suite)
@@ -128,3 +134,107 @@ def test_permute_vgf_INT(test_data):
128134
tosa_version="TOSA-1.0+INT",
129135
)
130136
pipeline.run()
137+
138+
139+
140+
def get_symmetric_a16w8_permute_quantizer(
    u55_config=False, per_channel_quantization=False
):
    """Build the ``Quantize`` stage used by the 16A8W permute tests.

    The TOSA version is read from the ``tosa_version`` test option and looked
    up in a table that only has an entry for ``"1.0"`` (mapped to the
    ``TOSA-1.0+INT+int16`` specification); any other value raises
    ``KeyError`` from the lookup.

    Args:
        u55_config: Accepted for signature compatibility; it is not read
            anywhere in this function.
        per_channel_quantization: Forwarded as ``is_per_channel`` to
            ``get_symmetric_a16w8_quantization_config``.

    Returns:
        A ``Quantize`` stage wrapping a ``TOSAQuantizer`` whose global config
        and ``torch.nn.Linear`` module-type config are both the symmetric
        a16w8 quantization config.
    """

    def make_a16w8_config():
        # Single place that spells out how the a16w8 config is requested.
        return get_symmetric_a16w8_quantization_config(
            is_per_channel=per_channel_quantization
        )

    requested_version = conftest.get_option("tosa_version")
    spec_by_version = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }
    a16w8_quantizer = TOSAQuantizer(spec_by_version[requested_version])

    a16w8_quantizer.set_global(make_a16w8_config())
    a16w8_quantizer.set_module_type(torch.nn.Linear, make_a16w8_config())

    return Quantize(a16w8_quantizer, make_a16w8_config())
165+
166+
167+
@common.parametrize("test_data", test_data_suite)
def test_permute_16a8w_tosa_INT(test_data: torch.Tensor):
    """Run permute through the TOSA INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. Despite the
    annotation, ``test_data`` is called here and is expected to return the
    input tensor together with the permutation dims.
    """
    tensor, dims = test_data()

    pipeline = TosaPipelineINT[input_t1](
        SimplePermute(dims=dims),
        (tensor,),
        aten_op,
        exir_op=[],
        per_channel_quantization=False,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )

    # Hand the custom 16A8W quantize stage to the pipeline's "quantize" step.
    a16w8_quantize = get_symmetric_a16w8_permute_quantizer(
        per_channel_quantization=False
    )
    pipeline.change_args("quantize", a16w8_quantize)

    pipeline.run()
190+
191+
192+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone300
def test_permute_16a8w_u55_INT16(test_data: torch.Tensor):
    """Run permute through the Ethos-U55 INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. Despite the
    annotation, ``test_data`` is called here and is expected to return the
    input tensor together with the permutation dims. Per-channel
    quantization is requested both on the pipeline and on the quantizer.
    """
    tensor, dims = test_data()

    pipeline = EthosU55PipelineINT[input_t1](
        SimplePermute(dims=dims),
        (tensor,),
        aten_op,
        exir_ops=[],
        per_channel_quantization=True,
        use_to_edge_transform_and_lower=True,
        atol=1e-03,
        rtol=1e-03,
        run_on_fvp=True,
    )

    # Hand the custom 16A8W quantize stage to the pipeline's "quantize" step.
    a16w8_quantize = get_symmetric_a16w8_permute_quantizer(
        per_channel_quantization=True
    )
    pipeline.change_args("quantize", a16w8_quantize)

    pipeline.run()
216+
217+
218+
@common.parametrize("test_data", test_data_suite)
@common.XfailIfNoCorstone320
def test_permute_16a8w_u85_INT16(test_data: torch.Tensor):
    """Run permute through the Ethos-U85 INT pipeline with 16A8W quantization.

    16A8W means 16-bit activations and 8-bit weights. Despite the
    annotation, ``test_data`` is called here and is expected to return the
    input tensor together with the permutation dims. The quantizer is built
    with per-channel quantization disabled.
    """
    tensor, dims = test_data()

    pipeline = EthosU85PipelineINT[input_t1](
        SimplePermute(dims=dims),
        (tensor,),
        aten_op,
        exir_ops=[],
        use_to_edge_transform_and_lower=True,
        atol=1e-03,
        rtol=1e-03,
        run_on_fvp=True,
    )

    # Hand the custom 16A8W quantize stage to the pipeline's "quantize" step.
    a16w8_quantize = get_symmetric_a16w8_permute_quantizer(
        per_channel_quantization=False
    )
    pipeline.change_args("quantize", a16w8_quantize)

    pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_linear.py",
2222
"ops/test_mul.py",
23+
"ops/test_permute.py",
2324
"ops/test_slice.py",
2425
"ops/test_sigmoid.py",
2526
"ops/test_sub.py",

0 commit comments

Comments
 (0)