Skip to content

Commit 7ec2317

Browse files
3l1facebook-github-bot
authored andcommitted
Enable int16 for op permute (#15256)
Summary: Enable int16 for op permute Reviewed By: Ninja91, digantdesai Differential Revision: D84948536
1 parent 5d71c9b commit 7ec2317

File tree

3 files changed

+114
-4
lines changed

3 files changed

+114
-4
lines changed

backends/arm/operators/op_permute.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def define_node(
117117
validate_valid_dtype(
118118
self.target,
119119
[inputs[0], output],
120-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
120+
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
121121
output.tosa_spec,
122122
)
123123

backends/arm/test/ops/test_permute.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from typing import Tuple
1010

1111
import torch
12+
from executorch.backends.arm.quantizer.arm_quantizer import (
13+
get_symmetric_a16w8_quantization_config,
14+
TOSAQuantizer,
15+
)
16+
from executorch.backends.arm.test import common, conftest
1217

1318
from executorch.backends.arm.test import common
1419

@@ -19,7 +24,8 @@
1924
TosaPipelineINT,
2025
VgfPipeline,
2126
)
22-
from torchvision.ops import Permute
27+
from executorch.backends.arm.tosa import TosaSpecification
28+
from executorch.backends.xnnpack.test.tester import Quantize
2329

2430
input_t1 = Tuple[torch.Tensor] # Input x
2531

@@ -42,10 +48,10 @@ class SimplePermute(torch.nn.Module):
4248
def __init__(self, dims: list[int]):
4349
super().__init__()
4450

45-
self.permute = Permute(dims=dims)
51+
self.dims = dims
4652

4753
def forward(self, x):
48-
return self.permute(x)
54+
return torch.permute(x, self.dims)
4955

5056

5157
@common.parametrize("test_data", test_data_suite)
@@ -128,3 +134,106 @@ def test_permute_vgf_INT(test_data):
128134
tosa_version="TOSA-1.0+INT",
129135
)
130136
pipeline.run()
137+
138+
139+
140+
def get_symmetric_a16w8_permute_quantizer(
141+
u55_config=False, per_channel_quantization=False
142+
):
143+
tosa_version = conftest.get_option("tosa_version")
144+
tosa_profiles = {
145+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
146+
}
147+
148+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
149+
quantizer.set_global(
150+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
151+
)
152+
quantizer.set_module_type(
153+
torch.nn.Linear,
154+
get_symmetric_a16w8_quantization_config(
155+
is_per_channel=per_channel_quantization
156+
),
157+
)
158+
159+
return Quantize(
160+
quantizer,
161+
get_symmetric_a16w8_quantization_config(
162+
is_per_channel=per_channel_quantization
163+
),
164+
)
165+
166+
167+
@common.parametrize("test_data", test_data_suite)
168+
def test_permute_16a8w_tosa_INT(test_data: torch.Tensor):
169+
"""Test permute operation with int16 quantization"""
170+
test_data, dims = test_data()
171+
pipeline = TosaPipelineINT[input_t1](
172+
SimplePermute(dims=dims),
173+
(test_data,),
174+
aten_op,
175+
exir_op=[],
176+
per_channel_quantization=False,
177+
use_to_edge_transform_and_lower=True,
178+
tosa_extensions=["int16"],
179+
)
180+
181+
pipeline.change_args(
182+
"quantize",
183+
get_symmetric_a16w8_permute_quantizer(
184+
per_channel_quantization=False
185+
),
186+
)
187+
# Run the pipeline
188+
pipeline.run()
189+
190+
191+
@common.parametrize("test_data", test_data_suite)
192+
@common.XfailIfNoCorstone300
193+
def test_permute_16a8w_u55_INT16(test_data: torch.Tensor):
194+
"""Test permute operation with int16 quantization on U55"""
195+
test_data, dims = test_data()
196+
pipeline = EthosU55PipelineINT[input_t1](
197+
SimplePermute(dims=dims),
198+
(test_data,),
199+
aten_op,
200+
exir_ops=[],
201+
per_channel_quantization=True,
202+
use_to_edge_transform_and_lower=True,
203+
atol=1e-03,
204+
rtol=1e-03,
205+
run_on_fvp=True,
206+
)
207+
208+
pipeline.change_args(
209+
"quantize",
210+
get_symmetric_a16w8_permute_quantizer(
211+
per_channel_quantization=True
212+
),
213+
)
214+
pipeline.run()
215+
216+
217+
@common.parametrize("test_data", test_data_suite)
218+
@common.XfailIfNoCorstone320
219+
def test_permute_16a8w_u85_INT16(test_data: torch.Tensor):
220+
"""Test permute operation with int16 quantization on U85"""
221+
test_data, dims = test_data()
222+
pipeline = EthosU85PipelineINT[input_t1](
223+
SimplePermute(dims=dims),
224+
(test_data,),
225+
aten_op,
226+
exir_ops=[],
227+
use_to_edge_transform_and_lower=True,
228+
atol=1e-03,
229+
rtol=1e-03,
230+
run_on_fvp=True,
231+
)
232+
233+
pipeline.change_args(
234+
"quantize",
235+
get_symmetric_a16w8_permute_quantizer(
236+
per_channel_quantization=False
237+
),
238+
)
239+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_cat.py",
2121
"ops/test_linear.py",
2222
"ops/test_mul.py",
23+
"ops/test_permute.py",
2324
"ops/test_slice.py",
2425
"ops/test_sigmoid.py",
2526
"ops/test_sub.py",

0 commit comments

Comments
 (0)