|  | 
| 4 | 4 | # LICENSE file in the root directory of this source tree. | 
| 5 | 5 | 
 | 
| 6 | 6 | from copy import copy | 
| 7 |  | -from typing import cast, Set, Type | 
|  | 7 | +from typing import cast, Dict, Optional, Set, Tuple, Type | 
| 8 | 8 | 
 | 
| 9 |  | -from executorch.backends.arm._passes.arm_pass_utils import create_node | 
|  | 9 | +import torch | 
|  | 10 | +from executorch.backends.arm._passes.arm_pass import ArmPass | 
|  | 11 | +from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg | 
|  | 12 | +from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( | 
|  | 13 | +    get_output_qparams, | 
|  | 14 | +) | 
| 10 | 15 | from executorch.backends.arm._passes.quant_args import QuantArgs | 
| 11 | 16 | from executorch.backends.arm.constants import DQ_OPS, Q_OPS | 
| 12 | 17 | from executorch.exir.dialects._ops import ops as exir_ops | 
| @@ -65,3 +70,234 @@ def call(self, graph_module: GraphModule) -> PassResult: | 
| 65 | 70 |         graph_module = super().call(graph_module).graph_module | 
| 66 | 71 |         graph_module.recompile() | 
| 67 | 72 |         return PassResult(graph_module, modified) | 
|  | 73 | + | 
|  | 74 | + | 
class InsertRescaleInt32Pass(ArmPass):
    """
    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
    quantized implementations. This pass treats such operator nodes by
    inserting rescale ops before and after them if needed. Note that extra logic
    that handles the scales and zero points must be in place because the affected
    TOSA have naive implementations that do not account for the quantization
    parameters.
    """

    _passes_required_after: Set[Type[ExportPass]] = set()

    # Node targets this pass considers for rescale insertion. Their quantized
    # TOSA lowerings operate on INT32 operands.
    included_targets = [
        exir_ops.edge.aten.abs.default,
        exir_ops.edge.aten.eq.Tensor,
        exir_ops.edge.aten.ge.Tensor,
        exir_ops.edge.aten.gt.Tensor,
        exir_ops.edge.aten.le.Tensor,
        exir_ops.edge.aten.lt.Tensor,
        exir_ops.edge.aten.maximum.default,
        exir_ops.edge.aten.minimum.default,
    ]

    def _int32_qargs(self, s: float) -> QuantArgs:
        """Helper creator function for INT32-based QuantArgs"""

        # Zero point is fixed at 0 for the INT32 intermediate representation;
        # only the scale differs between operands/outputs.
        return QuantArgs(
            scale=s,
            zp=0,
            qmin=torch.iinfo(torch.int32).min,
            qmax=torch.iinfo(torch.int32).max,
            dtype=torch.int32,
        )

    def _get_inputs_rescaled_qparams(
        self, target, input_qparams: Dict[int, QuantArgs]
    ) -> Dict[int, QuantArgs]:
        """Get the qparams for the INT32 operands to the op ``target``

        Inputs to the INT32-based operator must be rescaled from INT8 to INT32.
        This function computes the ``QuantArgs`` for each of the operands and returns
        it as a dict, mapping tensor index to ``QuantArgs``.

        Raises:
            ValueError: if ``target`` is not one of the handled operators.
        """

        if target in [
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.eq.Tensor,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.gt.Tensor,
            exir_ops.edge.aten.le.Tensor,
            exir_ops.edge.aten.lt.Tensor,
            exir_ops.edge.aten.minimum.default,
            exir_ops.edge.aten.maximum.default,
        ]:
            # For these ops, use the smallest scale among the INT8 operands.
            min_scale = min(
                [qp.get_scale_per_tensor() for qp in input_qparams.values()]
            )
            # NOTE(review): assumes input_qparams is keyed 0..N-1 (range over
            # its length rather than its keys) — confirm against producers of
            # the "input_qparams" metadata.
            qparams = {
                i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
            }
        else:
            raise ValueError(f"Not a valid target: {target}")

        return qparams

    def _get_output_qparams(
        self, target, inputs_qparams: Dict[int, QuantArgs]
    ) -> Optional[QuantArgs]:
        """Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute
        the scale of the output based on how the operator itself affects it.

        Returns ``None`` for comparison ops whose output is bool (no qparams).

        Raises:
            ValueError: if ``target`` is not one of the handled operators.
        """

        if target in [
            exir_ops.edge.aten.abs.default,
            exir_ops.edge.aten.maximum.default,
            exir_ops.edge.aten.minimum.default,
        ]:
            # The op has not altered the scale; the output scale is equal to
            # the operands' scales.
            return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
        elif target in [
            exir_ops.edge.aten.eq.Tensor,
            exir_ops.edge.aten.ge.Tensor,
            exir_ops.edge.aten.gt.Tensor,
            exir_ops.edge.aten.le.Tensor,
            exir_ops.edge.aten.lt.Tensor,
        ]:
            # Output is bool for these ops and thus no qparams are present
            return None
        else:
            raise ValueError(f"Not a valid target: {target}")

    def _get_rescale_qparams(
        self, target, input_qparams: Dict[int, QuantArgs]
    ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
        """
        Get the quantization parameters of the INT32 inputs/outputs that will
        surround the node after the new RESCALE ops have been inserted.
        """

        inputs_rescaled_qparams = self._get_inputs_rescaled_qparams(
            target, input_qparams
        )
        # Output qparams are derived from the (rescaled) input qparams, not
        # the original INT8 ones.
        output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams)

        return (inputs_rescaled_qparams, output_qparams)

    def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
        """Insert INT8 -> INT32 RESCALE nodes in front of ``node``'s INT8 inputs.

        Returns True if at least one rescale node was inserted.
        """
        qargs = node.meta["input_qparams"]

        args_copy = list(node.args)
        seen_args = set()
        modified = False
        for i in qargs:
            qp = qargs[i]
            # Only INT8 operands need a widening rescale.
            if qp.dtype != torch.int8:
                continue

            arg_node = args_copy[i]
            # replace_input_with() below swaps *every* use of arg_node in
            # node, so an argument appearing at multiple indices only needs
            # one rescale.
            if arg_node in seen_args:
                continue
            seen_args.add(arg_node)

            with graph.inserting_after(arg_node):
                rescale_node = create_node(
                    graph,
                    exir_ops.backend.tosa.RESCALE.default,
                    (
                        arg_node,
                        torch.int32,
                        qp.get_scale_per_tensor()
                        / rescale_qargs[
                            i
                        ].get_scale_per_tensor(),  # Old scale / new scale
                        qp.get_zp_per_tensor(),  # Old zero point
                        rescale_qargs[i].get_zp_per_tensor(),  # New zero point
                    ),
                    from_node=node,
                )

                node.replace_input_with(arg_node, rescale_node)
                modified = True

        return modified

    def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
        """Insert an INT32 -> INT8 RESCALE node after ``node`` and reroute its users.

        No-op (returns False) when the node has no output qparams (e.g. bool
        outputs of comparison ops) or when the output is not INT8.
        """
        if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0:
            return False

        qargs = get_output_qparams(node)
        # This pass only handles single-output nodes; a None rescale_qargs
        # here would mean output qparams exist for a bool-output op.
        assert len(qargs) == 1
        assert rescale_qargs is not None

        qarg = qargs[0]
        if qarg.dtype != torch.int8:
            return False

        # Snapshot users before inserting the rescale node, which itself
        # becomes a user of ``node``.
        users_copy = list(node.users)

        with graph.inserting_after(node):
            rescale_node = create_node(
                graph,
                exir_ops.backend.tosa.RESCALE.default,
                (
                    node,
                    torch.int8,
                    rescale_qargs.get_scale_per_tensor()
                    / qarg.get_scale_per_tensor(),  # Old scale / new scale
                    rescale_qargs.get_zp_per_tensor(),  # Old zero point
                    qarg.get_zp_per_tensor(),  # New zero point
                ),
                from_node=node,
            )

        # All original consumers now read the narrowed INT8 result.
        for user in users_copy:
            user.replace_input_with(node, rescale_node)

        return True

    def call(self, graph_module: GraphModule) -> PassResult:
        """Run the pass: wrap eligible quantized nodes with RESCALE ops and
        update their qparam metadata accordingly."""
        graph = graph_module.graph

        modified = False
        # Iterate over a snapshot since rescale insertion mutates the graph.
        for node in list(graph.nodes):
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.included_targets:
                continue

            # Nodes without input qparams are not quantized; nothing to do.
            if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0:
                continue
            input_qparams = node.meta["input_qparams"]

            inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams(
                node.target, input_qparams
            )

            inputs_was_rescaled = self._rescale_inputs(
                graph, node, inputs_rescale_qargs
            )
            # Outputs are only narrowed back when inputs were actually
            # widened; otherwise the node was already INT32-clean.
            outputs_was_rescaled = False
            if inputs_was_rescaled:
                outputs_was_rescaled = self._rescale_outputs(
                    graph, node, output_rescale_qargs
                )
                modified = True

            # Update node metadata

            if inputs_was_rescaled:
                assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"])
                node.meta["input_qparams"] = inputs_rescale_qargs

            if outputs_was_rescaled:
                assert len(node.meta["output_qparams"]) == 1
                node.meta["output_qparams"] = {0: output_rescale_qargs}

                # If the output type is specified in the node, change it such
                # that it matches the subsequent rescale node(s) that this node
                # now has output edges to.
                if "dtype" in node.kwargs:
                    set_node_arg(node, "dtype", torch.int32)

        if modified:
            # Retrace the graph to update the fake tensor types
            graph_module = super().call(graph_module).graph_module
            graph_module.recompile()

        return PassResult(graph_module, modified)
0 commit comments