# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm.quantizer import (
    get_symmetric_a16w8_quantization_config,
    get_symmetric_quantization_config,
    is_annotated,
    QuantizationConfig,
    TOSAQuantizer,
)
from executorch.backends.arm.quantizer.quantization_config import QuantizationSpec
from executorch.backends.arm.tosa import TosaSpecification
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

DQ_PER_CHANNEL = torch.ops.quantized_decomposed.dequantize_per_channel.default
DQ_PER_TENSOR = torch.ops.quantized_decomposed.dequantize_per_tensor.default
Q_PER_TENSOR = torch.ops.quantized_decomposed.quantize_per_tensor.default


class ConvModel(torch.nn.Module):
    """Small model with three convolutions, so that per-module configs can be exercised."""

    def __init__(self):
        super().__init__()
        self.conv0 = torch.nn.Conv2d(
            3,
            16,
            kernel_size=4,
        )
        self.conv1 = torch.nn.Conv2d(16, 32, kernel_size=3, bias=False)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv0(x)
        x = torch.sigmoid(x)
        x = self.conv1(x)
        x = torch.tanh(x)
        x = self.conv2(x)
        return x


test_inputs = (torch.randn(1, 3, 64, 64),)


def validate_per_tensor_quant(node: torch.fx.Node, qspec: QuantizationSpec):
    # Per-tensor q/dq ops take (input, scale, zero_point, quant_min, quant_max, dtype).
    _, _, zero_point, qmin, qmax, dtype = node.args
    if qspec.qscheme == torch.per_tensor_symmetric:
        assert (
            zero_point == 0
        ), f"Zero point {zero_point} is not zero for symmetric quantization"
    assert (
        qmin == qspec.quant_min
    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
    assert (
        qmax == qspec.quant_max
    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"


def validate_per_channel_quant(node: torch.fx.Node, qspec: QuantizationSpec):
    # Per-channel dq ops take (input, scales, zero_points, axis, quant_min, quant_max, dtype).
    _, _, _, channel_axis, qmin, qmax, dtype = node.args
    assert (
        channel_axis == qspec.ch_axis
    ), f"Channel axis {channel_axis} does not match expected {qspec.ch_axis}"
    assert (
        qmin == qspec.quant_min
    ), f"Quant min {qmin} does not match expected {qspec.quant_min}"
    assert (
        qmax == qspec.quant_max
    ), f"Quant max {qmax} does not match expected {qspec.quant_max}"
    assert dtype == qspec.dtype, f"Dtype {dtype} does not match expected {qspec.dtype}"


def validate_input(input_node: torch.fx.Node, qspec: QuantizationSpec | None):
    # A quantized input should be produced by a dequantize op matching the qspec;
    # a None qspec means there is nothing to check.
    if qspec is None:
        return

    per_channel = qspec.qscheme == torch.per_channel_symmetric
    expected_dequant_op = DQ_PER_CHANNEL if per_channel else DQ_PER_TENSOR
    assert (
        input_node.target == expected_dequant_op
    ), f"Input node {input_node} is not quantized as expected"
    if per_channel:
        validate_per_channel_quant(input_node, qspec)
    else:
        validate_per_tensor_quant(input_node, qspec)


def validate_output(node: torch.fx.Node, qspec: QuantizationSpec | None):
    # The node's single user should be a per-tensor quantize op matching the qspec.
    if qspec is None:
        return
    users = list(node.users)
    assert len(users) == 1, f"Node {node} should have exactly one user"
    assert (
        users[0].target == Q_PER_TENSOR
    ), f"Output node {users[0]} is not quantized as expected"
    validate_per_tensor_quant(users[0], qspec)


def validate_node(
    node: torch.fx.Node, quantization_config: QuantizationConfig | None
) -> None:
    # A None config means the node should have been left un-annotated; otherwise
    # check input, weight, output and (if present) bias against the config.
    if quantization_config is None:
        assert not is_annotated(node), f"Node {node} is unexpectedly annotated"
        return

    assert is_annotated(node), f"Node {node} is not annotated"
    input_qspec = quantization_config.get_input_act_qspec()
    output_qspec = quantization_config.get_output_act_qspec()
    weight_qspec = quantization_config.get_weight_qspec()

    if len(node.all_input_nodes) == 3:
        # Convolutions with a bias have (input, weight, bias) as input nodes.
        input_node, weight_node, bias_node = node.all_input_nodes
        bias_qspec = quantization_config.get_bias_qspec(node)
        validate_input(bias_node, bias_qspec)
    else:
        input_node, weight_node = node.all_input_nodes

    validate_input(input_node, input_qspec)
    validate_input(weight_node, weight_qspec)
    validate_output(node, output_qspec)


def test_set_module_name() -> None:
    model = ConvModel()
    model.eval()

    # Set up quantizer with different configs for different modules
    tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(tosa_spec)
    int8_config = get_symmetric_quantization_config(is_per_channel=False)
    a16w8_config = get_symmetric_a16w8_quantization_config()
    # Set module-specific configurations but don't set a global config, to test that
    # only the specified modules are quantized
    quantizer.set_module_name("conv0", int8_config)
    quantizer.set_module_name("conv1", a16w8_config)

    # Export model
    exported_model = torch.export.export(model, test_inputs)

    # Prepare, calibrate and convert model
    prepared_model = prepare_pt2e(exported_model.module(), quantizer)
    prepared_model(*test_inputs)
    converted_model = convert_pt2e(prepared_model)

    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d"][0],
        int8_config,
    )
    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d_1"][0],
        a16w8_config,
    )
    validate_node(
        [node for node in converted_model.graph.nodes if node.name == "conv2d_2"][0],
        None,
    )