99from typing import Tuple
1010
1111import torch
12+ from executorch .backends .arm .quantizer .arm_quantizer import (
13+ get_symmetric_a16w8_quantization_config ,
14+ TOSAQuantizer ,
15+ )
16+ from executorch .backends .arm .test import common , conftest
1217
1318from executorch .backends .arm .test import common
1419
1924 TosaPipelineINT ,
2025 VgfPipeline ,
2126)
22- from torchvision .ops import Permute
27+ from executorch .backends .arm .tosa import TosaSpecification
28+ from executorch .backends .xnnpack .test .tester import Quantize
2329
2430input_t1 = Tuple [torch .Tensor ] # Input x
2531
@@ -42,10 +48,10 @@ class SimplePermute(torch.nn.Module):
4248 def __init__ (self , dims : list [int ]):
4349 super ().__init__ ()
4450
45- self .permute = Permute ( dims = dims )
51+ self .dims = dims
4652
4753 def forward (self , x ):
48- return self .permute (x )
54+ return torch .permute (x , self . dims )
4955
5056
5157@common .parametrize ("test_data" , test_data_suite )
@@ -128,3 +134,107 @@ def test_permute_vgf_INT(test_data):
128134 tosa_version = "TOSA-1.0+INT" ,
129135 )
130136 pipeline .run ()
137+
138+
139+
140+ def get_symmetric_a16w8_permute_quantizer (
141+ u55_config = False , per_channel_quantization = False
142+ ):
143+ tosa_version = conftest .get_option ("tosa_version" )
144+ tosa_profiles = {
145+ "1.0" : TosaSpecification .create_from_string ("TOSA-1.0+INT+int16" ),
146+ }
147+
148+ quantizer = TOSAQuantizer (tosa_profiles [tosa_version ])
149+ quantizer .set_global (
150+ get_symmetric_a16w8_quantization_config (is_per_channel = per_channel_quantization )
151+ )
152+ quantizer .set_module_type (
153+ torch .nn .Linear ,
154+ get_symmetric_a16w8_quantization_config (
155+ is_per_channel = per_channel_quantization
156+ ),
157+ )
158+
159+ return Quantize (
160+ quantizer ,
161+ get_symmetric_a16w8_quantization_config (
162+ is_per_channel = per_channel_quantization
163+ ),
164+ )
165+
166+
167+ @common .parametrize ("test_data" , test_data_suite )
168+ def test_permute_16a8w_tosa_INT (test_data : torch .Tensor ):
169+ """Test permute operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
170+ # Create pipeline with custom 16A8W quantization config
171+ test_data , dims = test_data ()
172+ pipeline = TosaPipelineINT [input_t1 ](
173+ SimplePermute (dims = dims ),
174+ (test_data ,),
175+ aten_op ,
176+ exir_op = [],
177+ per_channel_quantization = False ,
178+ use_to_edge_transform_and_lower = True ,
179+ tosa_extensions = ["int16" ],
180+ )
181+
182+ pipeline .change_args (
183+ "quantize" ,
184+ get_symmetric_a16w8_permute_quantizer (
185+ per_channel_quantization = False
186+ ),
187+ )
188+ # Run the pipeline
189+ pipeline .run ()
190+
191+
192+ @common .parametrize ("test_data" , test_data_suite )
193+ @common .XfailIfNoCorstone300
194+ def test_permute_16a8w_u55_INT16 (test_data : torch .Tensor ):
195+ """Test permute operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
196+ test_data , dims = test_data ()
197+ pipeline = EthosU55PipelineINT [input_t1 ](
198+ SimplePermute (dims = dims ),
199+ (test_data ,),
200+ aten_op ,
201+ exir_ops = [],
202+ per_channel_quantization = True ,
203+ use_to_edge_transform_and_lower = True ,
204+ atol = 1e-03 ,
205+ rtol = 1e-03 ,
206+ run_on_fvp = True ,
207+ )
208+
209+ pipeline .change_args (
210+ "quantize" ,
211+ get_symmetric_a16w8_permute_quantizer (
212+ per_channel_quantization = True
213+ ),
214+ )
215+ pipeline .run ()
216+
217+
218+ @common .parametrize ("test_data" , test_data_suite )
219+ @common .XfailIfNoCorstone320
220+ def test_permute_16a8w_u85_INT16 (test_data : torch .Tensor ):
221+ """Test permute operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
222+ test_data , dims = test_data ()
223+ pipeline = EthosU85PipelineINT [input_t1 ](
224+ SimplePermute (dims = dims ),
225+ (test_data ,),
226+ aten_op ,
227+ exir_ops = [],
228+ use_to_edge_transform_and_lower = True ,
229+ atol = 1e-03 ,
230+ rtol = 1e-03 ,
231+ run_on_fvp = True ,
232+ )
233+
234+ pipeline .change_args (
235+ "quantize" ,
236+ get_symmetric_a16w8_permute_quantizer (
237+ per_channel_quantization = False
238+ ),
239+ )
240+ pipeline .run ()
0 commit comments