diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp index 1430f72aa33..43d0b588403 100644 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp +++ b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp @@ -32,6 +32,28 @@ std::unique_ptr CreateQuantizationParamWrapper( quantize_param_wrapper = std::make_unique( axis, scale_offset); + } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + uint32_t bitwidth = quant_info["bitwidth"].cast(); + int32_t axis = quant_info["axis"].cast(); + std::vector scale_offset = + quant_info["scale_offset"].cast>(); + uint32_t num_elements = scale_offset.size(); + std::vector scales; + std::vector offsets; + for (const auto& scale_offset : scale_offset) { + scales.push_back(scale_offset.scale); + offsets.push_back(scale_offset.offset); + } + quantize_param_wrapper = + std::make_unique( + bitwidth, axis, num_elements, scales, offsets); + } else if (encoding == QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + uint32_t bitwidth = quant_info["bitwidth"].cast(); + float scale = quant_info["scale"].cast(); + int32_t offset = quant_info["offset"].cast(); + quantize_param_wrapper = + std::make_unique( + bitwidth, scale, offset); } else if (encoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { float scale = quant_info["scale"].cast(); int32_t offset = quant_info["offset"].cast(); diff --git a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp index 9051a337b2e..b7de6426cb2 100644 --- a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp +++ b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.cpp @@ -27,6 +27,33 @@ std::unique_ptr CreateQuantizationParamWrapper( quantize_param_wrapper = std::make_unique( quantization.axisScaleOffsetEncoding.axis, scale_offset); + } else if ( + quantization.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET) { + std::vector scales( + quantization.bwAxisScaleOffsetEncoding.scales, + quantization.bwAxisScaleOffsetEncoding.scales + + quantization.bwAxisScaleOffsetEncoding.numElements); + std::vector offsets( + quantization.bwAxisScaleOffsetEncoding.offsets, + quantization.bwAxisScaleOffsetEncoding.offsets + + quantization.bwAxisScaleOffsetEncoding.numElements); + + quantize_param_wrapper = + std::make_unique( + quantization.bwAxisScaleOffsetEncoding.bitwidth, + quantization.bwAxisScaleOffsetEncoding.axis, + quantization.bwAxisScaleOffsetEncoding.numElements, + scales, + offsets); + } else if ( + quantization.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET) { + quantize_param_wrapper = + std::make_unique( + quantization.bwScaleOffsetEncoding.bitwidth, + quantization.bwScaleOffsetEncoding.scale, + quantization.bwScaleOffsetEncoding.offset); } else if ( quantization.quantizationEncoding == QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { diff --git a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h index 5bb6486b4c3..2cd594735ed 100644 --- a/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h +++ b/backends/qualcomm/aot/wrappers/QuantizeParamsWrapper.h @@ -77,6 +77,117 @@ class UndefinedQuantizeParamsWrapper final : public QuantizeParamsWrapper { } }; +class BwAxisScaleOffsetQuantizeParamsWrapper final + : public QuantizeParamsWrapper { + public: + explicit BwAxisScaleOffsetQuantizeParamsWrapper( + std::uint32_t bitwidth, + std::int32_t axis, + std::uint32_t num_elements, + 
std::vector scales, + std::vector offsets) + : QuantizeParamsWrapper( + QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET), + bitwidth_(bitwidth), + axis_(axis), + num_elements_(num_elements), + scales_(scales), + offsets_(offsets) {} + + BwAxisScaleOffsetQuantizeParamsWrapper( + const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) + : QuantizeParamsWrapper( + rhs.GetEncodingDefinition(), + rhs.GetQuantizationEncoding()), + bitwidth_(rhs.bitwidth_), + axis_(rhs.axis_), + num_elements_(rhs.num_elements_), + scales_(rhs.scales_), + offsets_(rhs.offsets_) {} + BwAxisScaleOffsetQuantizeParamsWrapper( + BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; + BwAxisScaleOffsetQuantizeParamsWrapper& operator=( + const BwAxisScaleOffsetQuantizeParamsWrapper& rhs) = delete; + BwAxisScaleOffsetQuantizeParamsWrapper& operator=( + BwAxisScaleOffsetQuantizeParamsWrapper&& rhs) = delete; + + ~BwAxisScaleOffsetQuantizeParamsWrapper() override = default; + + std::unique_ptr Clone() override { + return std::make_unique(*this); + } + + Qnn_QuantizeParams_t CreateQuantizeParams() override { + Qnn_QuantizeParams_t rval; + rval.encodingDefinition = GetEncodingDefinition(); + rval.quantizationEncoding = GetQuantizationEncoding(); + rval.bwAxisScaleOffsetEncoding.bitwidth = bitwidth_; + rval.bwAxisScaleOffsetEncoding.axis = axis_; + rval.bwAxisScaleOffsetEncoding.numElements = num_elements_; + rval.bwAxisScaleOffsetEncoding.scales = scales_.data(); + rval.bwAxisScaleOffsetEncoding.offsets = offsets_.data(); + return rval; + } + + private: + std::uint32_t bitwidth_; + std::int32_t axis_; + std::uint32_t num_elements_; + std::vector scales_; + std::vector offsets_; +}; + +class BwScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper { + public: + explicit BwScaleOffsetQuantizeParamsWrapper( + std::uint32_t bitwidth, + float scale, + std::int32_t offset) + : QuantizeParamsWrapper( + QNN_DEFINITION_DEFINED, + QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET), + bitwidth_(bitwidth), + scale_(scale), + offset_(offset) {} + + BwScaleOffsetQuantizeParamsWrapper( + const BwScaleOffsetQuantizeParamsWrapper& rhs) + : QuantizeParamsWrapper( + rhs.GetEncodingDefinition(), + rhs.GetQuantizationEncoding()), + bitwidth_(rhs.bitwidth_), + scale_(rhs.scale_), + offset_(rhs.offset_) {} + BwScaleOffsetQuantizeParamsWrapper(BwScaleOffsetQuantizeParamsWrapper&& rhs) = + delete; + BwScaleOffsetQuantizeParamsWrapper& operator=( + const BwScaleOffsetQuantizeParamsWrapper& rhs) = delete; + BwScaleOffsetQuantizeParamsWrapper& operator=( + BwScaleOffsetQuantizeParamsWrapper&& rhs) = delete; + + ~BwScaleOffsetQuantizeParamsWrapper() override = default; + + std::unique_ptr Clone() override { + return std::make_unique(*this); + } + + Qnn_QuantizeParams_t CreateQuantizeParams() override { + Qnn_QuantizeParams_t rval; + rval.encodingDefinition = GetEncodingDefinition(); + rval.quantizationEncoding = GetQuantizationEncoding(); + rval.bwScaleOffsetEncoding.bitwidth = bitwidth_; + rval.bwScaleOffsetEncoding.scale = scale_; + rval.bwScaleOffsetEncoding.offset = offset_; + return rval; + } + + private: + std::uint32_t bitwidth_; + float scale_; + std::int32_t offset_; +}; + class ScaleOffsetQuantizeParamsWrapper final : public QuantizeParamsWrapper { public: explicit ScaleOffsetQuantizeParamsWrapper(float scale, std::int32_t offset) diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index c6c7b47b38e..e5ce8ec2d74 100644 --- a/backends/qualcomm/builders/node_visitor.py 
+++ b/backends/qualcomm/builders/node_visitor.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy from typing import Any, Dict, Tuple import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -38,16 +39,16 @@ float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, } -PER_CHANNEL_ENCODING_MAPPING = { - exir_ops.edge.quantized_decomposed.quantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, - exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, +PER_CHANNEL_ENCODING = { + exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, } -PER_TENSOR_ENCODING_MAPPING = { - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, +PER_TENSOR_ENCODING = { + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, } @@ -87,6 +88,68 @@ def _get_tensor(node, index): tensor = tensor.permute(dims=op_node.meta["axis_order"]).contiguous() return tensor + def make_qnn_per_channel_config(self, node: torch.fx.Node, quant_attrs: Dict): + quant_config = copy.deepcopy(quant_attrs) + + scales = quant_attrs["scales"] + zero_points = quant_attrs["zero_points"] + assert len(scales) == len( + zero_points + ), f"Per channel encoding of node {node}, has different size for scales {len(scales)} and zero_points {len(zero_points)}" + + scale_offset = [] + for i in range(len(scales)): + # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h + scale_offset.append( + PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) + ) + + user_0 = list(node.users)[0] + # Memory layout of QNN conv weight always ends in Output. 
Like conv2d is HWIO + if ( + "convolution" in user_0.target.__name__ + and list(node.users)[0].args[1] == node + ): + quant_config["axis"] = 3 + + else: + quant_config["axis"] = quant_attrs["axis"] + + quant_config["scale_offset"] = scale_offset + # special case for 4 bits + if ( + quant_config["dtype"] == torch.int8 + and quant_config["quant_max"] - quant_config["quant_min"] <= 15 + ): + quant_config["bitwidth"] = 4 + return ( + PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET, + quant_config, + ) + return ( + PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET, + quant_config, + ) + + def make_qnn_per_tensor_config(self, quant_attrs: Dict): + quant_config = copy.deepcopy(quant_attrs) + # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h + quant_config["offset"] = -quant_attrs["zero_point"] + # special case for 4 bits + if ( + quant_config["dtype"] == torch.int8 + and quant_config["quant_max"] - quant_config["quant_min"] <= 15 + ): + quant_config["bitwidth"] = 4 + return ( + PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_BW_SCALE_OFFSET, + quant_config, + ) + return ( + PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_SCALE_OFFSET, + quant_config, + ) + def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]: if not node.meta.get("quant_attrs", None): return ( @@ -99,66 +162,35 @@ def get_quant_encoding_conf(self, node: torch.fx.Node) -> Tuple[Any, Dict]: if "requantize" in node.meta else node.meta["quant_attrs"] ) - encoding = quant_attrs["encoding"] - - quant_config = {} - if encoding in PER_CHANNEL_ENCODING_MAPPING: - scales = quant_attrs["scales"] - zero_points = quant_attrs["zero_points"] - assert len(scales) == len( - zero_points - ), f"Per channel encoding of node {node}, has differnt size fo scales {len(scales)} and zero_points {len(zero_points)}" - - scale_offset = [] - for i in range(len(scales)): - scale_offset.append( - PyQnnWrapper.Qnn_ScaleOffset_t(scales[i], -zero_points[i]) - ) - user_0 = list(node.users)[0] - # Memory layout of QNN conv is NHW"C", need to set axis as 3 - if ( - type(user_0.target) != str - and user_0.target.__name__ in ["aten.convolution.default"] - and list(node.users)[0].args[1] == node - ): - quant_config["axis"] = 3 - else: - quant_config["axis"] = quant_attrs["axis"] - - quant_config["scale_offset"] = scale_offset - quant_config["quant_max"] = quant_attrs["quant_max"] - quant_config["quant_min"] = quant_attrs["quant_min"] - quant_config["dtype"] = quant_attrs["dtype"] - return PER_CHANNEL_ENCODING_MAPPING[encoding], quant_config - - # per tensor situation - quant_config["scale"] = quant_attrs["scale"] - # check Qnn_ScaleOffset_t in QNN/include/QnnTypes.h - quant_config["offset"] = -quant_attrs["zero_point"] - # Distinguish what data type the node is - quant_config["quant_max"] = quant_attrs["quant_max"] - quant_config["quant_min"] = quant_attrs["quant_min"] - quant_config["dtype"] = quant_attrs["dtype"] - return PER_TENSOR_ENCODING_MAPPING[encoding], quant_config + if quant_attrs["encoding"] in PER_CHANNEL_ENCODING: + return self.make_qnn_per_channel_config(node, quant_attrs) + + return self.make_qnn_per_tensor_config(quant_attrs) def get_quant_tensor_value( - self, node: torch.fx.Node, tensor: torch.Tensor, dtype + self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth ) -> torch.Tensor: - quant_attrs = node.meta["quant_attrs"] - encoding = quant_attrs["encoding"] - - if encoding in PER_CHANNEL_ENCODING_MAPPING: - 
scales = quant_attrs["scales"] - offsets = quant_attrs["zero_points"] - return tensor.div(scales).add(offsets).round().to(quant_attrs["dtype"]) + if quant_attrs["encoding"] in PER_TENSOR_ENCODING: + scale = quant_attrs["scale"] + zero_point = quant_attrs["zero_point"] + else: # per channel case + scale = quant_attrs["scales"] + zero_point = quant_attrs["zero_points"] + + # To bypass torch.uint16 quantization is not supported + dtype = ( + torch.int32 + if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16 + else quant_attrs["dtype"] + ) - # per tensor situation - scale = quant_attrs["scale"] - offset = quant_attrs["zero_point"] - if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16: - return tensor.div(scale).add(offset).round().to(torch.int32) - return tensor.div(scale).add(offset).round().to(quant_attrs["dtype"]) + tensor = tensor.div(scale).add(zero_point).round().to(dtype) + # Make the backends access data correctly + if bitwidth == 4: + mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) + tensor = torch.bitwise_and(mask, tensor) + return tensor def get_tensor_type( self, @@ -278,7 +310,12 @@ def define_value( ) else: if quant_configs: - tensor = self.get_quant_tensor_value(node, tensor, dtype) + tensor = self.get_quant_tensor_value( + tensor, + node.meta["quant_attrs"], + dtype, + quant_configs.get("bitwidth"), + ) tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, tensor_type, diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index d8a957de55c..f899e98efd4 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -248,6 +248,7 @@ def define_node( filter_node = node.args[1] filter_tensor = get_parameter(filter_node, self.edge_program) + # weight of pytorch OIHW, yet QNN is HWIO filter_axis_order = (2, 3, 1, 0) filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() filter_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/passes/annotate_quant_attrs.py b/backends/qualcomm/passes/annotate_quant_attrs.py index 7501a8c8d1d..7ada90be6a6 100644 --- a/backends/qualcomm/passes/annotate_quant_attrs.py +++ b/backends/qualcomm/passes/annotate_quant_attrs.py @@ -7,11 +7,11 @@ from typing import Any, Dict import torch -from executorch.backends.qualcomm.builders.utils import set_parameter +from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import dq_ops, get_parameter, get_quant_attrs, q_ops +from .utils import dq_ops, get_quant_attrs, q_ops class AnnotateQuantAttrs(ExportPass): diff --git a/backends/qualcomm/passes/convert_to_linear.py b/backends/qualcomm/passes/convert_to_linear.py index 16fb151ddd3..6f96438c9c0 100644 --- a/backends/qualcomm/passes/convert_to_linear.py +++ b/backends/qualcomm/passes/convert_to_linear.py @@ -19,7 +19,7 @@ SourcePartition, ) -from .utils import dq_ops, q_ops +from .utils import dq_ops, get_quant_attrs, q_ops class ConvertToLinear(ExportPass): @@ -43,6 +43,7 @@ class ConvertToLinear(ExportPass): bmm_patterns = [ {view_copy: 3, permute_copy: 1, expand_copy: 2, add: 1, bmm: 1}, + {view_copy: 3, permute_copy: 1, expand_copy: 2, bmm: 1}, ] mm_patterns = [ @@ -60,22 +61,6 @@ def _get_original_input( cur_node = cur_node.args[0] return cur_node - def _annotate_quant_attrs( - self, gm: torch.fx.GraphModule, node: torch.fx.Node, q_node: torch.fx.Node - ) -> 
torch.fx.Node: - quant_attr_keys = [arg.name for arg in q_node.target._schema.arguments][1:] - quant_attrs = dict.fromkeys(quant_attr_keys) - - for i in range(1, len(q_node.args)): - attr_n = q_node.args[i] - value = attr_n - if type(attr_n) == torch.fx.node.Node: - value = getattr(gm, attr_n.target) - quant_attrs[quant_attr_keys[i - 1]] = value - quant_attrs["encoding"] = q_node.target - node.meta["quant_attrs"] = quant_attrs - return node - def _convert_to_linear( self, gm: torch.fx.GraphModule, @@ -100,8 +85,8 @@ def _convert_to_linear( # qnn htp does not support keepdim, the view_copy(reshape) should exist for now if self._get_original_input(inputs, input_node).target in dq_ops: - input_node = self._annotate_quant_attrs( - gm, input_node, self._get_original_input(inputs, input_node).args[0] + input_node.meta["quant_attrs"] = get_quant_attrs( + gm, self._get_original_input(inputs, input_node).args[0] ) args = [input_node, weight_node] if bias_node: @@ -113,8 +98,8 @@ def _convert_to_linear( ) linear_node.meta = fn_node.meta if list(output.users)[0].target in q_ops: - linear_node = self._annotate_quant_attrs( - gm, linear_node, list(output.users)[0] + linear_node.meta["quant_attrs"] = get_quant_attrs( + gm, list(output.users)[0] ) for user in fn_node.users.copy(): user.replace_input_with(fn_node, linear_node) @@ -138,14 +123,18 @@ def _extract_addmm_ops( def _extract_bmm_ops(self, partitioned_nodes: List[edge_op]) -> List[torch.fx.Node]: bmm_node = [n for n in partitioned_nodes if n.target == self.bmm][0] - add_node = [n for n in partitioned_nodes if n.target == self.add][0] + add_node = [n for n in partitioned_nodes if n.target == self.add] # weight -> expand_copy -> view_copy -> input of bmm weight_node = bmm_node.args[1].args[0].args[0].args[0] # input -> expand_copy -> view_copy -> input of bmm input_node = bmm_node.args[0].args[0].args[0] - bias_node = add_node.args[1] - return [input_node, weight_node, bias_node] + + ret = [input_node, weight_node, bmm_node] + if add_node: + bias_node = add_node[0].args[1] + ret += bias_node + return ret def _convert(self, graph_module: torch.fx.GraphModule): partitions = get_source_partitions(graph_module.graph, [torch.nn.Linear]) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 7cb07158c2a..674314d991c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from enum import IntEnum, unique from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple import torch @@ -35,11 +36,24 @@ __all__ = [ "QnnQuantizer", - "get_default_8bit_qnn_ptq_config", + "QuantDtype", + "get_16a4w_qnn_ptq_config", "get_default_16bit_qnn_ptq_config", + "get_default_8bit_qnn_ptq_config", ] +@unique +class QuantDtype(IntEnum): + """ + bits of activation and bits of weight + """ + + use_16a16w = 0 + use_16a4w = 1 + use_8a8w = 2 + + def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: def _derive_bias_qparams_fn( obs_or_fqs: List, @@ -112,25 +126,64 @@ def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: return quantization_config +# 4 bits quantization only supports specific ops. 
+def get_16a4w_qnn_ptq_config() -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-20} act_quantization_spec = QuantizationSpec( dtype=torch.int32, - quant_min=0, - quant_max=65535, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, qscheme=torch.per_tensor_affine, observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( dtype=torch.int16, - quant_min=-32767, - quant_max=32767, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, qscheme=torch.per_tensor_symmetric, ch_axis=0, observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), ) + # torch does not support uint16 quantization, use int32 to bypass bias_quantization_spec = QuantizationSpec( dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min, @@ -150,34 +203,39 @@ def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: def get_ptq_per_channel_weight_config( - input_dtype=torch.uint8, weight_dtype=torch.int8 + act_dtype=torch.uint8, weight_dtype=torch.int8 ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} - supported_types = { + supported_act_types = { torch.uint8, + torch.uint16, torch.int8, torch.int16, } + # TODO accept "int4" temporally. 
Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} assert ( - input_dtype in supported_types - ), f"input_dtype, {input_dtype} is not one of supported_types, {supported_types}" + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + assert ( - weight_dtype in supported_types - ), f"weight_dtype, {input_dtype} is not one of supported_types, {supported_types}" + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + # torch do not support uint16 quantization, use int32 to bypass act_quantization_spec = QuantizationSpec( - dtype=input_dtype, - quant_min=torch.iinfo(input_dtype).min, - quant_max=torch.iinfo(input_dtype).max, + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, qscheme=torch.per_tensor_affine, observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), ) weight_quantization_spec = QuantizationSpec( - dtype=weight_dtype, - quant_min=torch.iinfo(weight_dtype).min + 1, - quant_max=torch.iinfo(weight_dtype).max, + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, qscheme=torch.per_channel_symmetric, ch_axis=0, observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), @@ -200,53 +258,34 @@ class QnnQuantizer(Quantizer): def __init__(self): super().__init__() - self.enable_per_channel_conv_quant: bool = True self.bit8_quant_config: QuantizationConfig = get_default_8bit_qnn_ptq_config() self.bit16_quant_config: QuantizationConfig = get_default_16bit_qnn_ptq_config() self.bit8_quant_ops: Set[OpOverload] = self.SUPPORTED_OPS.copy() self.bit16_quant_ops: Set[OpOverload] = set() - self.discard_nodes: Set[str] = set() self.custom_quant_annotations: Sequence[Callable] = [] + self.discard_nodes: Set[str] = set() - def set_per_channel_quant(self, enable: bool) -> None: - self.enable_per_channel_conv_quant = enable - - def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: - self.bit8_quant_config = quantization_config - - def set_bit16_op_quant_config( - self, quantization_config: QuantizationConfig - ) -> None: - self.bit16_quant_config = quantization_config - - def get_supported_ops(self) -> Set[OpOverload]: - return self.SUPPORTED_OPS - - def add_discard_nodes(self, nodes: Sequence[str]) -> None: - self.discard_nodes = set(nodes) - - def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: - for op in ops: - if op in self.bit8_quant_ops: - self.bit8_quant_ops.remove(op) - if op in self.bit16_quant_ops: - self.bit16_quant_ops.remove(op) + self.enable_per_channel_conv_quant: bool = True + # the weight quantized for activation 8 bits and 16 bits + self.per_channel_weight_dtype: Dict = { + "8bit_act": torch.int8, + "16bit_act": torch.int16, + } - def add_custom_quant_annotations( - self, custom_quant_annotations: Sequence[Callable] - ) -> None: - self.custom_quant_annotations = custom_quant_annotations + def _annotate(self, gm: GraphModule) -> None: + for node in gm.graph.nodes: + if node.name in self.discard_nodes: + continue - def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: - for op in ops: - assert ( - op in self.SUPPORTED_OPS - ), f"The annotation of op 
{op} is not implemented" + quant_config = self._get_quant_config(node.target) + if quant_config: + OP_ANNOTATOR[node.target](node, quant_config) - self.bit8_quant_ops.remove(op) - self.bit16_quant_ops.add(op) + def _annotate_custom_annotation(self, gm: GraphModule) -> None: + for annotation_func in self.custom_quant_annotations: + annotation_func(gm) def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig]: """ @@ -262,8 +301,12 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig torch.ops.aten.conv2d.default, ]: if op in self.bit16_quant_ops: - return get_ptq_per_channel_weight_config(torch.int16, torch.int16) - return get_ptq_per_channel_weight_config() + return get_ptq_per_channel_weight_config( + torch.uint16, self.per_channel_weight_dtype["16bit_act"] + ) + return get_ptq_per_channel_weight_config( + weight_dtype=self.per_channel_weight_dtype["8bit_act"] + ) if op in self.bit8_quant_ops: return self.bit8_quant_config @@ -273,28 +316,29 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig print(f"No quant config is implemented for op, {op}") - def transform_for_annotation(self, model: GraphModule) -> GraphModule: - model = RemoveClone()(model).graph_module - model = ReduceDynamicRange()(model).graph_module - model = ConvertHardsigmoid(quantization_capture=True)(model).graph_module - model = DecomposeScaledDotProductAttention()(model).graph_module - model = DecomposeSilu()(model).graph_module - model = ReplaceInfBuffer()(model).graph_module + def add_16bit_quant_ops(self, ops: Set[OpOverload]) -> None: + for op in ops: + assert ( + op in self.SUPPORTED_OPS + ), f"The annotation of op {op} is not implemented" - return model + self.bit8_quant_ops.remove(op) + self.bit16_quant_ops.add(op) - def _annotate(self, gm: GraphModule) -> None: - for node in gm.graph.nodes: - if node.name in self.discard_nodes: - continue + def add_custom_quant_annotations( + self, custom_quant_annotations: Sequence[Callable] + ) -> None: + self.custom_quant_annotations = custom_quant_annotations - quant_config = self._get_quant_config(node.target) - if quant_config: - OP_ANNOTATOR[node.target](node, quant_config) + def add_discard_nodes(self, nodes: Sequence[str]) -> None: + self.discard_nodes = set(nodes) - def _annotate_custom_annotation(self, gm: GraphModule) -> None: - for annotation_func in self.custom_quant_annotations: - annotation_func(gm) + def add_discard_ops(self, ops: Sequence[OpOverload]) -> None: + for op in ops: + if op in self.bit8_quant_ops: + self.bit8_quant_ops.remove(op) + if op in self.bit16_quant_ops: + self.bit16_quant_ops.remove(op) def annotate(self, model: GraphModule) -> GraphModule: self._annotate(model) @@ -302,5 +346,40 @@ def annotate(self, model: GraphModule) -> GraphModule: return model + def get_supported_ops(self) -> Set[OpOverload]: + return self.SUPPORTED_OPS + + def set_bit16_op_quant_config( + self, quantization_config: QuantizationConfig + ) -> None: + self.bit16_quant_config = quantization_config + + def set_bit8_op_quant_config(self, quantization_config: QuantizationConfig) -> None: + self.bit8_quant_config = quantization_config + + def set_per_channel_weight_dtype( + self, + weight_dtype_for_8bit_act: Optional[str | torch.dtype] = None, + weight_dtype_for_16bit_act: Optional[str | torch.dtype] = None, + ) -> None: + # TODO accept temporally str type. 
Remove it when torch support torch.int4 dtype + if weight_dtype_for_8bit_act: + self.per_channel_weight_dtype["8bit_act"] = weight_dtype_for_8bit_act + if weight_dtype_for_16bit_act: + self.per_channel_weight_dtype["16bit_act"] = weight_dtype_for_16bit_act + + def set_per_channel_quant(self, enable: bool) -> None: + self.enable_per_channel_conv_quant = enable + + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + model = RemoveClone()(model).graph_module + model = ReduceDynamicRange()(model).graph_module + model = ConvertHardsigmoid(quantization_capture=True)(model).graph_module + model = DecomposeScaledDotProductAttention()(model).graph_module + model = DecomposeSilu()(model).graph_module + model = ReplaceInfBuffer()(model).graph_module + + return model + def validate(self, model: GraphModule) -> None: pass diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 5694fe8b198..bbc89276854 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -232,6 +232,21 @@ def forward(self, x): return self.second(self.first(x)) +class Conv2DSingle(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + class Conv2dAvgPool2d(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index be69ebfa55f..c5acd0016e1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -10,7 +10,12 @@ from multiprocessing.connection import Listener import torch -from executorch.backends.qualcomm.tests.utils import QnnPartitioner, TestQNN, to_backend +from executorch.backends.qualcomm.tests.utils import ( + QnnPartitioner, + QuantDtype, + TestQNN, + to_backend, +) from executorch.backends.qualcomm.utils.utils import ( canonicalize_program, @@ -33,7 +38,6 @@ from executorch.examples.models.mobilenet_v3 import MV3Model from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel from executorch.examples.models.wav2letter import Wav2LetterModel -from executorch.examples.qualcomm.scripts.edsr import annotate_forward from executorch.exir.backend.backend_api import disable_validation from executorch.exir.program._program import EdgeCompileConfig, ExirExportedProgram @@ -479,6 +483,22 @@ def setUp(self): tensor_dump_output_path="", ) + def test_qnn_backend_16a4w_conv2d(self): + module = Conv2DSingle() # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a4w + ) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_16a4w_linear(self): + module = Linear() # noqa: F405 + sample_input = (torch.randn([3, 4]),) + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a4w + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_arange(self): module = Arange(5) # noqa: F405 sample_input = (torch.randn(5),) @@ -888,31 +908,10 @@ def test_qnn_backend_conv2d_max_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_residual_block(self): - module = ResidualBlockModule() # noqa: F405 - sample_input = (torch.randn(1, 32, 28, 28),) - module = 
self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) - - def test_qnn_backend_simple_model(self): - module = SimpleModel() # noqa: F405 - sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) - - def test_qnn_backend_view_permute_matmul(self): - module = ViewPermuteMatMul() # noqa: F405 - sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256])) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) - # check if requantization work - module = self.get_qdq_module(module, sample_input, use_16bit_quant=True) - self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_example_models(self): instances = [ {"module": DeepLabV3ResNet101Model(), "annotation": ()}, - {"module": EdsrModel(), "annotation": (annotate_forward,)}, + {"module": EdsrModel(), "annotation": ()}, {"module": InceptionV3Model(), "annotation": ()}, {"module": InceptionV4Model(), "annotation": ()}, {"module": Llama2Model(), "annotation": ()}, @@ -954,6 +953,29 @@ def test_qnn_backend_example_models(self): assert_output_equal=False, ) + def test_qnn_backend_residual_block(self): + module = ResidualBlockModule() # noqa: F405 + sample_input = (torch.randn(1, 32, 28, 28),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_simple_model(self): + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + + def test_qnn_backend_view_permute_matmul(self): + module = ViewPermuteMatMul() # noqa: F405 + sample_input = (torch.randn([1, 8, 512]), torch.randn([1, 2, 8, 256])) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + # check if requantization work + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a16w + ) + self.lower_module_and_test_output(module, sample_input) + class TestQNNFloatingPointUtils(TestQNN): # TODO: refactor to support different backends @@ -1346,6 +1368,7 @@ def test_dummy_llama2(self): self.ip, "--port", str(self.port), + "--use_fp16", ] if self.host: cmds.extend(["--host", self.host]) @@ -1377,7 +1400,6 @@ def test_ptq_dummy_llama2(self): self.ip, "--port", str(self.port), - "--ptq", ] if self.host: cmds.extend(["--host", self.host]) @@ -1410,6 +1432,7 @@ def test_mobilebert(self): self.ip, "--port", str(self.port), + "--use_fp16", ] if self.host: cmds.extend(["--host", self.host]) @@ -1450,7 +1473,6 @@ def test_ptq_mobilebert(self): self.ip, "--port", str(self.port), - "--ptq", ] if self.host: cmds.extend(["--host", self.host]) diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index e521705f806..c39307e5203 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -16,8 +16,10 @@ from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, get_default_16bit_qnn_ptq_config, QnnQuantizer, + QuantDtype, ) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( 
QcomChipset, @@ -55,6 +57,9 @@ class TestQNN(unittest.TestCase): image_dataset: Literal = "" pretrained_weight: Literal = "" online_prepare: bool = False + use_8a8w: str = "8a8w" + use_16a16w: str = "16a16w" + use_16a4w: str = "16a4w" def _assert_outputs_equal(self, model_output, ref_output): self.assertTrue(len(ref_output) == len(model_output)) @@ -179,7 +184,7 @@ def get_qdq_module( inputs: Tuple[torch.Tensor], is_conv_per_channel: Optional[bool] = True, custom_quant_annotations: Tuple[Callable] = (), - use_16bit_quant: Optional[bool] = False, + quant_dtype: QuantDtype = QuantDtype.use_8a8w, ) -> torch.fx.GraphModule: m = torch._export.capture_pre_autograd_graph(module, inputs) @@ -187,9 +192,17 @@ def get_qdq_module( quantizer.add_custom_quant_annotations(custom_quant_annotations) quantizer.set_per_channel_quant(is_conv_per_channel) - if use_16bit_quant: + if quant_dtype == QuantDtype.use_8a8w: + pass # default setting + elif quant_dtype == QuantDtype.use_16a16w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") prepared = prepare_pt2e(m, quantizer) prepared(*inputs) diff --git a/examples/qualcomm/scripts/deeplab_v3.py b/examples/qualcomm/scripts/deeplab_v3.py index 6d94ef64f97..133e64d8568 100755 --- a/examples/qualcomm/scripts/deeplab_v3.py +++ b/examples/qualcomm/scripts/deeplab_v3.py @@ -13,6 +13,7 @@ import numpy as np +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -107,6 +108,7 @@ def get_dataset(data_size, dataset_dir, download): inputs, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, ) if args.compile_only: diff --git a/examples/qualcomm/scripts/dummy_llama2.py b/examples/qualcomm/scripts/dummy_llama2.py index 94e2e323c1d..3e7cbb6d35c 100755 --- a/examples/qualcomm/scripts/dummy_llama2.py +++ b/examples/qualcomm/scripts/dummy_llama2.py @@ -11,6 +11,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.llama2 import Llama2Model from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -56,13 +57,20 @@ def create_device_inputs(example_inputs, use_kv_cache): ) parser.add_argument( - "-P", - "--ptq", - help="If specified, will do PTQ.", + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", action="store_true", default=False, ) + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 8bits activation and 8bits weight. Support 8a8w, 16a16w and 16a4w.", + default="8a8w", + ) + # QNN_SDK_ROOT might also be an argument, but it is used in various places. # So maybe it's fine to just use the environment. 
if "QNN_SDK_ROOT" not in os.environ: @@ -89,7 +97,20 @@ def create_device_inputs(example_inputs, use_kv_cache): pte_filename = "dummy_llama2_qnn" - use_fp16 = False if args.ptq else True + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a16w": + quant_dtype = QuantDtype.use_16a16w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w, 16a16w and 16a4w." + ) + + if args.use_fp16: + quant_dtype = None + build_executorch_binary( instance.get_eager_model().eval(), inputs, @@ -97,7 +118,7 @@ def create_device_inputs(example_inputs, use_kv_cache): f"{args.artifact}/{pte_filename}", inputs, custom_annotations=(), - use_fp16=use_fp16, + quant_dtype=quant_dtype, ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), diff --git a/examples/qualcomm/scripts/edsr.py b/examples/qualcomm/scripts/edsr.py index 2cb47614f82..f844b094c03 100755 --- a/examples/qualcomm/scripts/edsr.py +++ b/examples/qualcomm/scripts/edsr.py @@ -13,6 +13,7 @@ import numpy as np import piq import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.edsr import EdsrModel from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -90,66 +91,6 @@ def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str return SrDataset(hr_dir, lr_dir) -def annotate_forward(gm: torch.fx.GraphModule) -> None: - """ - This function is specific for EDSR. It constructs a nn module, which is - inherited from nn.conv2d. - The source_fn of the rewritten nn module turns out to be a string "forward" - """ - import itertools - - from executorch.backends.qualcomm.quantizer.quantizer import ( - get_ptq_per_channel_weight_config, - ) - from executorch.backends.qualcomm.quantizer.utils import ( - _is_annotated, - QUANT_ANNOTATION_KEY, - ) - from torch.ao.quantization.quantize_pt2e import QuantizationAnnotation - from torch.fx import Node - from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - conv_partitions = get_source_partitions(gm.graph, ["forward"]) - conv_partitions = list(itertools.chain(*conv_partitions.values())) - quantization_config = get_ptq_per_channel_weight_config() - for conv_partition in conv_partitions: - if len(conv_partition.output_nodes) > 1: - raise ValueError("conv partition has more than one output node") - conv_node = conv_partition.output_nodes[0] - if ( - conv_node.op != "call_function" - or conv_node.target != torch.ops.aten.conv2d.default - ): - raise ValueError(f"{conv_node} is not an aten conv2d operator") - # skip annotation if it is already annotated - if _is_annotated([conv_node]): - continue - - input_qspec_map = {} - input_act = conv_node.args[0] - assert isinstance(input_act, Node) - input_spec = quantization_config.input_activation - input_qspec_map[input_act] = input_spec - - weight = conv_node.args[1] - assert isinstance(weight, Node) - input_qspec_map[weight] = quantization_config.weight - - if len(conv_node.args) > 2: - bias = conv_node.args[2] - if isinstance(bias, Node): - if callable(quantization_config.bias): - input_qspec_map[bias] = quantization_config.bias(conv_node) - else: - input_qspec_map[bias] = quantization_config.bias - - conv_node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) - - if __name__ == "__main__": parser = 
setup_common_args_and_variables() @@ -212,9 +153,9 @@ def annotate_forward(gm: torch.fx.GraphModule) -> None: args.model, f"{args.artifact}/{pte_filename}", [(input,) for input in inputs], - custom_annotations=(annotate_forward,), skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, ) if args.compile_only: diff --git a/examples/qualcomm/scripts/inception_v3.py b/examples/qualcomm/scripts/inception_v3.py index 842ba5dcd43..244e38edbe5 100755 --- a/examples/qualcomm/scripts/inception_v3.py +++ b/examples/qualcomm/scripts/inception_v3.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.inception_v3.model import InceptionV3Model from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -109,6 +110,7 @@ def get_data_loader(): inputs, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, ) if args.compile_only: diff --git a/examples/qualcomm/scripts/inception_v4.py b/examples/qualcomm/scripts/inception_v4.py index 2b6338cd908..db3feda2708 100755 --- a/examples/qualcomm/scripts/inception_v4.py +++ b/examples/qualcomm/scripts/inception_v4.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.inception_v4 import InceptionV4Model from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -108,6 +109,7 @@ def get_data_loader(): inputs, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, ) if args.compile_only: diff --git a/examples/qualcomm/scripts/mobilebert_fine_tune.py b/examples/qualcomm/scripts/mobilebert_fine_tune.py index d241d5da3ee..dc148afa8eb 100755 --- a/examples/qualcomm/scripts/mobilebert_fine_tune.py +++ b/examples/qualcomm/scripts/mobilebert_fine_tune.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, make_output_dir, @@ -236,9 +237,9 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): ) parser.add_argument( - "-u", - "--use_16bit_quant", - help="If specified, quantize model with 16 bits, otherwise, quantize model with 8 bits. Will only be used when ptq set to true.", + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", action="store_true", default=False, ) @@ -246,9 +247,8 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): parser.add_argument( "-P", "--ptq", - help="If specified, will do PTQ.", - action="store_true", - default=False, + help="If specified, will do PTQ quantization. default is 8bits activation and 8bits weight. 
Support 8a8w, 16a16w and 16a4w.", + default="8a8w", ) args = parser.parse_args() @@ -271,30 +271,31 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): ) inputs, input_list = get_dataset(data_val) - if args.ptq: - build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - use_fp16=False, - use_16bit_quant=args.use_16bit_quant, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - ) + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a16w": + quant_dtype = QuantDtype.use_16a16w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w else: - build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - None, - use_fp16=True, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w, 16a16w and 16a4w." ) + if args.use_fp16: + quant_dtype = None + + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=quant_dtype, + ) + if args.compile_only: sys.exit(0) diff --git a/examples/qualcomm/scripts/mobilenet_v2.py b/examples/qualcomm/scripts/mobilenet_v2.py index cbd6ad9de67..5f214a6f8ca 100755 --- a/examples/qualcomm/scripts/mobilenet_v2.py +++ b/examples/qualcomm/scripts/mobilenet_v2.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.mobilenet_v2 import MV2Model from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -109,6 +110,7 @@ def get_data_loader(): inputs, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, ) if args.compile_only: diff --git a/examples/qualcomm/scripts/torchvision_vit.py b/examples/qualcomm/scripts/torchvision_vit.py index 4717b3e7dbe..ff22f93c4f4 100755 --- a/examples/qualcomm/scripts/torchvision_vit.py +++ b/examples/qualcomm/scripts/torchvision_vit.py @@ -12,6 +12,7 @@ import numpy as np import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel from executorch.examples.qualcomm.scripts.utils import ( build_executorch_binary, @@ -148,6 +149,7 @@ def get_data_loader(): args.model, f"{args.artifact}/{pte_filename}", inputs, + quant_dtype=QuantDtype.use_8a8w, ) # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index 0ff80cb54f2..a0a0e8725ac 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -10,14 +10,17 @@ import sys from pathlib import Path +from typing import Optional + import numpy as np import torch from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a4w_qnn_ptq_config, get_default_16bit_qnn_ptq_config, - get_default_8bit_qnn_ptq_config, QnnQuantizer, + QuantDtype, ) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, @@ -143,20 +146,26 @@ def build_executorch_binary( soc_model, file_name, dataset, - use_fp16=False, - use_16bit_quant=False, custom_annotations=(), 
     skip_node_id_set=None,
     skip_node_op_set=None,
+    quant_dtype: Optional[QuantDtype] = None,
 ):
-    if not use_fp16:
+    if quant_dtype:
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_annotations)
-        if use_16bit_quant:
+
+        if quant_dtype == QuantDtype.use_8a8w:
+            pass  # default setting
+        elif quant_dtype == QuantDtype.use_16a16w:
             quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
             quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+        elif quant_dtype == QuantDtype.use_16a4w:
+            quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
+            quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")
         else:
-            quantizer.set_bit8_op_quant_config(get_default_8bit_qnn_ptq_config())
+            raise AssertionError(f"No support for QuantDtype {quant_dtype}.")

         captured_model = torch._export.capture_pre_autograd_graph(model, inputs)
         annotated_model = prepare_pt2e(captured_model, quantizer)
@@ -177,7 +186,9 @@ def build_executorch_binary(
         "SM8450": QcomChipset.SM8450,
     }

-    backend_options = generate_htp_compiler_spec(use_fp16=use_fp16)
+    backend_options = generate_htp_compiler_spec(
+        use_fp16=False if quant_dtype else True
+    )
     qnn_partitioner = QnnPartitioner(
         generate_qnn_executorch_compiler_spec(
             soc_model=arch_table[soc_model],
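
The sketch below is a reviewer note appended after the diff, not part of the change itself: it shows end to end how the new 16a4w path is expected to be driven from user code, mirroring the quantizer setup added in backends/qualcomm/tests/utils.py and examples/qualcomm/scripts/utils.py. The toy Conv2d model and the quantize_16a4w helper name are illustrative assumptions, not APIs introduced by this PR.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import (
    QnnQuantizer,
    get_16a4w_qnn_ptq_config,
)


def quantize_16a4w(module: torch.nn.Module, inputs):
    # Same quantizer setup that get_qdq_module() and build_executorch_binary()
    # perform for QuantDtype.use_16a4w in this diff: 16-bit activations,
    # 4-bit ("int4") per-channel weights.
    quantizer = QnnQuantizer()
    quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
    quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
    quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")

    captured = torch._export.capture_pre_autograd_graph(module, inputs)
    prepared = prepare_pt2e(captured, quantizer)
    prepared(*inputs)  # calibrate on the sample inputs
    return convert_pt2e(prepared)


if __name__ == "__main__":
    model = torch.nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1).eval()
    sample_input = (torch.randn(1, 1, 3, 3),)
    qdq_module = quantize_16a4w(model, sample_input)

In the example scripts the same path is now selected with --ptq 16a4w, which maps to quant_dtype=QuantDtype.use_16a4w before calling build_executorch_binary.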
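
One further illustrative note on the builder changes: after this PR, get_quant_tensor_value quantizes constant tensors with the node's scale and zero point and, when the quant config reports bitwidth == 4, keeps only the low nibble of each int8 value so the backend reads the data correctly, while make_qnn_per_channel_config negates zero points because Qnn_ScaleOffset_t stores an offset. The snippet below reproduces that arithmetic outside the builder with made-up numbers; it is not part of the change.

import torch

# Per-channel weight with channel axis 0; values are arbitrary.
weight = torch.tensor([[0.10, -0.06], [0.25, -0.30]])
scales = torch.tensor([0.02, 0.05])
zero_points = torch.tensor([0, 0])  # symmetric weight quantization

# Quantize the way the builder does: w / scale + zero_point, rounded to int8.
q = weight.div(scales.unsqueeze(1)).add(zero_points.unsqueeze(1)).round().to(torch.int8)

# bitwidth == 4: keep only the low nibble, e.g. -3 (0xFD as int8) becomes 0x0D.
mask = torch.full(q.size(), 0x0F, dtype=torch.int8)
q4 = torch.bitwise_and(mask, q)

# Qnn_ScaleOffset_t stores offset = -zero_point, hence the negation in
# make_qnn_per_channel_config and make_qnn_per_tensor_config.
offsets = -zero_points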