99from typing import Tuple
1010
1111import torch
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1217
1318from executorch .backends .arm .test import common
1419
1924 TosaPipelineINT ,
2025 VgfPipeline ,
2126)
22- from torchvision .ops import Permute
27+ from executorch .backends .arm .tosa import TosaSpecification
28+ from executorch .backends .xnnpack .test .tester import Quantize
2329
2430input_t1 = Tuple [torch .Tensor ] # Input x
2531
@@ -42,10 +48,10 @@ class SimplePermute(torch.nn.Module):
4248 def __init__ (self , dims : list [int ]):
4349 super ().__init__ ()
4450
45- self .permute = Permute ( dims = dims )
51+ self .dims = dims
4652
4753 def forward (self , x ):
48- return self .permute (x )
54+ return torch .permute (x , self . dims )
4955
5056
5157@common .parametrize ("test_data" , test_data_suite )
@@ -128,3 +134,106 @@ def test_permute_vgf_INT(test_data):
128134 tosa_version = "TOSA-1.0+INT" ,
129135 )
130136 pipeline .run ()
137+
138+
139+
140+ def get_symmetric_a16w8_permute_quantizer (
141+ u55_config = False , per_channel_quantization = False
142+ ):
143+ tosa_version = conftest .get_option ("tosa_version" )
144+ tosa_profiles = {
145+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
146+ }
147+
148+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
149+ quantizer .set_global (
150+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
151+ )
152+ quantizer .set_module_type (
153+ torch .nn .Linear ,
154+ get_symmetric_a16w8_quantization_config (
155+ is_per_channel = per_channel_quantization
156+ ),
157+ )
158+
159+ return Quantize (
160+ quantizer ,
161+ get_symmetric_a16w8_quantization_config (
162+ is_per_channel = per_channel_quantization
163+ ),
164+ )
165+
166+
167+ @common .parametrize ("test_data" , test_data_suite )
168+ def test_permute_16a8w_tosa_INT (test_data : torch .Tensor ):
169+ """Test permute operation with int16 quantization"""
170+ test_data , dims = test_data ()
171+ pipeline = TosaPipelineINT [input_t1 ](
172+ SimplePermute (dims = dims ),
173+ (test_data ,),
174+ aten_op ,
175+ exir_op = [],
176+ per_channel_quantization = False ,
177+ use_to_edge_transform_and_lower = True ,
178+ tosa_extensions = ["int16" ],
179+ )
180+
181+ pipeline .change_args (
182+ "quantize" ,
183+ get_symmetric_a16w8_permute_quantizer (
184+ per_channel_quantization = False
185+ ),
186+ )
187+ # Run the pipeline
188+ pipeline .run ()
189+
190+
191+ @common .parametrize ("test_data" , test_data_suite )
192+ @common .XfailIfNoCorstone300
193+ def test_permute_16a8w_u55_INT16 (test_data : torch .Tensor ):
194+ """Test permute operation with int16 quantization on U55"""
195+ test_data , dims = test_data ()
196+ pipeline = EthosU55PipelineINT [input_t1 ](
197+ SimplePermute (dims = dims ),
198+ (test_data ,),
199+ aten_op ,
200+ exir_ops = [],
201+ per_channel_quantization = True ,
202+ use_to_edge_transform_and_lower = True ,
203+ atol = 1e-03 ,
204+ rtol = 1e-03 ,
205+ run_on_fvp = True ,
206+ )
207+
208+ pipeline .change_args (
209+ "quantize" ,
210+ get_symmetric_a16w8_permute_quantizer (
211+ per_channel_quantization = True
212+ ),
213+ )
214+ pipeline .run ()
215+
216+
217+ @common .parametrize ("test_data" , test_data_suite )
218+ @common .XfailIfNoCorstone320
219+ def test_permute_16a8w_u85_INT16 (test_data : torch .Tensor ):
220+ """Test permute operation with int16 quantization on U85"""
221+ test_data , dims = test_data ()
222+ pipeline = EthosU85PipelineINT [input_t1 ](
223+ SimplePermute (dims = dims ),
224+ (test_data ,),
225+ aten_op ,
226+ exir_ops = [],
227+ use_to_edge_transform_and_lower = True ,
228+ atol = 1e-03 ,
229+ rtol = 1e-03 ,
230+ run_on_fvp = True ,
231+ )
232+
233+ pipeline .change_args (
234+ "quantize" ,
235+ get_symmetric_a16w8_permute_quantizer (
236+ per_channel_quantization = False
237+ ),
238+ )
239+ pipeline .run ()
0 commit comments