|  | 
| 8 | 8 | 
 | 
| 9 | 9 | from typing import Tuple | 
| 10 | 10 | 
 | 
|  | 11 | +import pytest | 
| 11 | 12 | import torch | 
| 12 |  | -from executorch.backends.arm.test import common | 
|  | 13 | +from executorch.backends.arm.quantizer.arm_quantizer import ( | 
|  | 14 | +    get_symmetric_a16w8_quantization_config, | 
|  | 15 | +    TOSAQuantizer, | 
|  | 16 | +) | 
|  | 17 | +from executorch.backends.arm.test import common, conftest | 
| 13 | 18 | 
 | 
| 14 | 19 | from executorch.backends.arm.test.tester.test_pipeline import ( | 
| 15 | 20 |     EthosU55PipelineINT, | 
|  | 
| 18 | 23 |     TosaPipelineINT, | 
| 19 | 24 |     VgfPipeline, | 
| 20 | 25 | ) | 
|  | 26 | +from executorch.backends.arm.tosa.specification import TosaSpecification | 
|  | 27 | +from executorch.backends.xnnpack.test.tester import Quantize | 
| 21 | 28 | 
 | 
| 22 | 29 | input_t1 = Tuple[torch.Tensor]  # Input x | 
| 23 | 30 | 
 | 
| @@ -151,3 +158,105 @@ def test_cat_vgf_INT(test_data: Tuple): | 
| 151 | 158 |         tosa_version="TOSA-1.0+INT", | 
| 152 | 159 |     ) | 
| 153 | 160 |     pipeline.run() | 
|  | 161 | + | 
|  | 162 | + | 
|  | 163 | +def get_symmetric_a16w8_cat_quantizer(per_channel_quantization=False): | 
|  | 164 | +    tosa_version = conftest.get_option("tosa_version") | 
|  | 165 | +    tosa_profiles = { | 
|  | 166 | +        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), | 
|  | 167 | +    } | 
|  | 168 | + | 
|  | 169 | +    quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) | 
|  | 170 | +    quantizer.set_global( | 
|  | 171 | +        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) | 
|  | 172 | +    ) | 
|  | 173 | + | 
|  | 174 | +    return Quantize( | 
|  | 175 | +        quantizer, | 
|  | 176 | +        get_symmetric_a16w8_quantization_config( | 
|  | 177 | +            is_per_channel=per_channel_quantization | 
|  | 178 | +        ), | 
|  | 179 | +    ) | 
|  | 180 | + | 
|  | 181 | + | 
|  | 182 | +@common.parametrize("test_data", Cat.test_parameters) | 
|  | 183 | +@pytest.mark.xfail( | 
|  | 184 | +    reason="missing int16 cat ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13978" | 
|  | 185 | +) | 
|  | 186 | +def test_cat_16a8w_tosa_INT(test_data: Tuple): | 
|  | 187 | +    """Test cat operation with 16A8W quantization (16-bit activations, 8-bit weights)""" | 
|  | 188 | +    per_channel_quantization = False | 
|  | 189 | + | 
|  | 190 | +    pipeline = TosaPipelineINT[input_t1]( | 
|  | 191 | +        Cat(), | 
|  | 192 | +        test_data(), | 
|  | 193 | +        aten_op, | 
|  | 194 | +        exir_op=[], | 
|  | 195 | +        per_channel_quantization=per_channel_quantization, | 
|  | 196 | +        use_to_edge_transform_and_lower=True, | 
|  | 197 | +        tosa_extensions=["int16"], | 
|  | 198 | +    ) | 
|  | 199 | + | 
|  | 200 | +    pipeline.change_args( | 
|  | 201 | +        "quantize", | 
|  | 202 | +        get_symmetric_a16w8_cat_quantizer( | 
|  | 203 | +            per_channel_quantization=per_channel_quantization | 
|  | 204 | +        ), | 
|  | 205 | +    ) | 
|  | 206 | +    pipeline.run() | 
|  | 207 | + | 
|  | 208 | + | 
|  | 209 | +@common.parametrize("test_data", Cat.test_parameters) | 
|  | 210 | +@common.XfailIfNoCorstone300 | 
|  | 211 | +@pytest.mark.xfail( | 
|  | 212 | +    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations" | 
|  | 213 | +) | 
|  | 214 | +def test_cat_16a8w_u55_INT16(test_data: Tuple): | 
|  | 215 | +    """Test cat operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" | 
|  | 216 | +    per_channel_quantization = False | 
|  | 217 | + | 
|  | 218 | +    pipeline = EthosU55PipelineINT[input_t1]( | 
|  | 219 | +        Cat(), | 
|  | 220 | +        test_data(), | 
|  | 221 | +        aten_op, | 
|  | 222 | +        exir_op, | 
|  | 223 | +        per_channel_quantization=per_channel_quantization, | 
|  | 224 | +        use_to_edge_transform_and_lower=True, | 
|  | 225 | +        run_on_fvp=True, | 
|  | 226 | +    ) | 
|  | 227 | + | 
|  | 228 | +    pipeline.change_args( | 
|  | 229 | +        "quantize", | 
|  | 230 | +        get_symmetric_a16w8_cat_quantizer( | 
|  | 231 | +            per_channel_quantization=per_channel_quantization | 
|  | 232 | +        ), | 
|  | 233 | +    ) | 
|  | 234 | +    pipeline.run() | 
|  | 235 | + | 
|  | 236 | + | 
|  | 237 | +@common.parametrize("test_data", Cat.test_parameters) | 
|  | 238 | +@common.XfailIfNoCorstone320 | 
|  | 239 | +@pytest.mark.xfail( | 
|  | 240 | +    reason="Vela compilation fails with 'Invalid arguments' for int16 cat operations" | 
|  | 241 | +) | 
|  | 242 | +def test_cat_16a8w_u85_INT16(test_data: Tuple): | 
|  | 243 | +    """Test cat operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" | 
|  | 244 | +    per_channel_quantization = False | 
|  | 245 | + | 
|  | 246 | +    pipeline = EthosU85PipelineINT[input_t1]( | 
|  | 247 | +        Cat(), | 
|  | 248 | +        test_data(), | 
|  | 249 | +        aten_op, | 
|  | 250 | +        exir_op, | 
|  | 251 | +        per_channel_quantization=per_channel_quantization, | 
|  | 252 | +        use_to_edge_transform_and_lower=True, | 
|  | 253 | +        run_on_fvp=True, | 
|  | 254 | +    ) | 
|  | 255 | + | 
|  | 256 | +    pipeline.change_args( | 
|  | 257 | +        "quantize", | 
|  | 258 | +        get_symmetric_a16w8_cat_quantizer( | 
|  | 259 | +            per_channel_quantization=per_channel_quantization | 
|  | 260 | +        ), | 
|  | 261 | +    ) | 
|  | 262 | +    pipeline.run() | 
0 commit comments