Skip to content

Commit 6f12fec

Browse files
authored
Arm backend: Add 16A8W support for view and transpose operations
Differential Revision: D80511313 Pull Request resolved: #13799
1 parent 113c70a commit 6f12fec

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

backends/arm/test/ops/test_view.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99

1010
from typing import Tuple
1111

12+
import pytest
1213
import torch
14+
from executorch.backends.arm.quantizer.arm_quantizer import (
15+
get_symmetric_a16w8_quantization_config,
16+
TOSAQuantizer,
17+
)
1318

14-
from executorch.backends.arm.test import common
19+
from executorch.backends.arm.test import common, conftest
1520
from executorch.backends.arm.test.tester.test_pipeline import (
1621
EthosU55PipelineINT,
1722
EthosU85PipelineINT,
@@ -20,6 +25,8 @@
2025
TosaPipelineINT,
2126
VgfPipeline,
2227
)
28+
from executorch.backends.arm.tosa.specification import TosaSpecification
29+
from executorch.backends.xnnpack.test.tester import Quantize
2330

2431
aten_op = "torch.ops.aten.view.default"
2532

@@ -147,3 +154,108 @@ def test_view_u85_INT(test_data: Tuple):
147154
exir_ops=[],
148155
)
149156
pipeline.run()
157+
158+
159+
def get_symmetric_a16w8_view_quantizer(per_channel_quantization=False):
160+
tosa_version = conftest.get_option("tosa_version")
161+
tosa_profiles = {
162+
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
163+
}
164+
165+
quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
166+
quantizer.set_global(
167+
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
168+
)
169+
170+
return Quantize(
171+
quantizer,
172+
get_symmetric_a16w8_quantization_config(
173+
is_per_channel=per_channel_quantization
174+
),
175+
)
176+
177+
178+
@common.parametrize("test_data", View.needs_transpose_tests)
179+
@pytest.mark.xfail(
180+
reason="missing int16 view ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13977"
181+
)
182+
def test_view_16a8w_tosa_INT(test_data: Tuple):
183+
"""Test view operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
184+
per_channel_quantization = False
185+
test_tensor, new_shape = test_data()
186+
187+
pipeline = TosaPipelineINT[input_t1](
188+
View(new_shape),
189+
(test_tensor,),
190+
aten_op,
191+
exir_op=[],
192+
per_channel_quantization=per_channel_quantization,
193+
use_to_edge_transform_and_lower=True,
194+
tosa_extensions=["int16"],
195+
)
196+
197+
pipeline.change_args(
198+
"quantize",
199+
get_symmetric_a16w8_view_quantizer(
200+
per_channel_quantization=per_channel_quantization
201+
),
202+
)
203+
pipeline.run()
204+
205+
206+
@common.parametrize("test_data", View.needs_transpose_tests)
207+
@common.XfailIfNoCorstone300
208+
@pytest.mark.xfail(
209+
reason="Vela compilation fails with 'Invalid arguments' for int16 view operations"
210+
)
211+
def test_view_16a8w_u55_INT16(test_data: Tuple):
212+
"""Test view operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
213+
per_channel_quantization = False
214+
test_tensor, new_shape = test_data()
215+
216+
pipeline = EthosU55PipelineINT[input_t1](
217+
View(new_shape),
218+
(test_tensor,),
219+
aten_op,
220+
exir_ops=[],
221+
per_channel_quantization=per_channel_quantization,
222+
use_to_edge_transform_and_lower=True,
223+
run_on_fvp=True,
224+
)
225+
226+
pipeline.change_args(
227+
"quantize",
228+
get_symmetric_a16w8_view_quantizer(
229+
per_channel_quantization=per_channel_quantization
230+
),
231+
)
232+
pipeline.run()
233+
234+
235+
@common.parametrize("test_data", View.needs_transpose_tests)
236+
@common.XfailIfNoCorstone320
237+
@pytest.mark.xfail(
238+
reason="Vela compilation fails with 'Invalid arguments' for int16 view operations"
239+
)
240+
def test_view_16a8w_u85_INT16(test_data: Tuple):
241+
"""Test view operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
242+
per_channel_quantization = False
243+
test_tensor, new_shape = test_data()
244+
245+
pipeline = EthosU85PipelineINT[input_t1](
246+
View(new_shape),
247+
(test_tensor,),
248+
aten_op,
249+
exir_ops=[],
250+
per_channel_quantization=per_channel_quantization,
251+
use_to_edge_transform_and_lower=True,
252+
run_on_fvp=True,
253+
)
254+
255+
pipeline.change_args(
256+
"quantize",
257+
get_symmetric_a16w8_view_quantizer(
258+
per_channel_quantization=per_channel_quantization
259+
),
260+
)
261+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def define_arm_tests():
2020
"ops/test_slice.py",
2121
"ops/test_sigmoid.py",
2222
"ops/test_tanh.py",
23+
"ops/test_view.py",
2324
"ops/test_cos.py",
2425
]
2526

0 commit comments

Comments
 (0)