diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index b63a5583b10..5bd02ce53af 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -10,7 +10,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -41,9 +40,13 @@ op_skip_ops, op_slice_copy, op_softmax, + op_split, + op_sqrt, op_squeeze, op_sub, + op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, @@ -55,7 +58,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -85,9 +87,13 @@ op_skip_ops, op_slice_copy, op_softmax, + op_split, op_squeeze, + op_sqrt, op_sub, + op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 3dae32f882e..060dd77fa66 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -14,9 +14,13 @@ from executorch.exir.dialects._ops import ops as exir_ops -from .qnn_constants import QNN_uint16 - -from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter +from .utils import ( + deduce_dtype, + get_parameter, + is_graph_input, + is_graph_output, + is_parameter, +) QNN_QUANT_TYPE_MAP = { @@ -26,16 +30,17 @@ # Note that there is no int64 tensor data type in Qnn. torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UNDEFINED, torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, - QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, } QNN_TENSOR_TYPE_MAP = { + torch.bool: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, torch.float32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, torch.int8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_8, torch.int16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_16, torch.int32: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32, torch.int64: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_64, torch.uint8: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_8, - QNN_uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16, + torch.uint16: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16, float: PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, } @@ -169,7 +174,7 @@ def get_quant_encoding_conf( return self.make_qnn_per_tensor_config(quant_attrs) def get_quant_tensor_value( - self, tensor: torch.Tensor, quant_attrs: Dict, dtype, bitwidth + self, tensor: torch.Tensor, quant_attrs: Dict, quant_configs: Dict ) -> torch.Tensor: if quant_attrs["encoding"] in PER_TENSOR_ENCODING: scale = quant_attrs["scale"] @@ -178,16 +183,11 @@ def get_quant_tensor_value( scale = quant_attrs["scales"] zero_point = quant_attrs["zero_points"] - # To bypass torch.uint16 quantization is not supported - dtype = ( - torch.int32 - if dtype == PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_16 - else quant_attrs["dtype"] - ) + dtype = quant_configs["dtype"] tensor = tensor.div(scale).add(zero_point).round().to(dtype) # Make the backends access data correctly - if bitwidth == 4: + if quant_configs.get("bitwidth") == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) tensor = torch.bitwise_and(mask, tensor) return tensor @@ -221,24 +221,9 @@ def get_data_type( self, tensor: torch.Tensor, quant_config: Dict, - is_tensor: bool, ) -> PyQnnWrapper.Qnn_TensorType_t: - if quant_config and is_tensor: - quant_range = 
quant_config["quant_max"] - quant_config["quant_min"] - unsigned = quant_config["quant_min"] >= 0 - if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: - if unsigned: - quant_config["dtype"] = torch.uint8 - else: - quant_config["dtype"] = torch.int8 - elif ( - quant_range - <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min - ): - if unsigned: - quant_config["dtype"] = QNN_uint16 - else: - quant_config["dtype"] = torch.int16 + if quant_config: + quant_config["dtype"] = deduce_dtype(tensor, quant_config) return QNN_QUANT_TYPE_MAP[quant_config["dtype"]] else: return QNN_TENSOR_TYPE_MAP[tensor.dtype] @@ -283,7 +268,7 @@ def define_tensor( nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], is_input_tensor: bool, node_name: str = None, - is_tensor: bool = True, + wrapper_idx: int = 0, ) -> PyQnnWrapper.TensorWrapper: """ Covert torch.Tensor to TensorWrapper @@ -299,9 +284,12 @@ def define_tensor( if node_name is None: node_name = node.name - if node_name in nodes_to_wrappers: - return nodes_to_wrappers[node_name] + if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): + return cached + tensor_name = node.name + if is_graph_input(node, self.edge_program): + tensor_name = "QnnInput_" + str(self.external_ids[node]) + "_" + tensor_name if is_graph_output(node): tensor_name = "output_" + tensor_name dims = [1] if len(tensor.size()) == 0 else tensor.size() @@ -309,7 +297,7 @@ def define_tensor( quant_encoding, quant_configs = self.get_quant_encoding_conf( node, is_input_tensor ) - dtype = self.get_data_type(tensor, quant_configs, is_tensor) + dtype = self.get_data_type(tensor, quant_configs) if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, @@ -327,8 +315,7 @@ def define_tensor( tensor = self.get_quant_tensor_value( tensor, node.meta["quant_attrs"], - dtype, - quant_configs.get("bitwidth"), + quant_configs, ) tensor_wrapper = PyQnnWrapper.TensorWrapper( tensor_name, @@ -341,7 +328,7 @@ def define_tensor( tensor.detach().numpy(), True, ) - nodes_to_wrappers[node_name] = tensor_wrapper + nodes_to_wrappers[node_name][wrapper_idx] = tensor_wrapper return tensor_wrapper def define_node( diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 905578790c0..a5d6aae1702 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -34,7 +34,7 @@ def define_node( weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, + is_input_tensor=True, ) indices_node = node.args[1] diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index 78d1e6244e9..9a593528219 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -62,7 +62,7 @@ def define_node( bias_node = node.args[2] # TODO remove this when qnn sdk support - if "scales" in bias_node.meta.get("quant_attrs"): + if "scales" in bias_node.meta.get("quant_attrs", {}): print( f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not support yet." 
) diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py index c159b9bf00e..002dd5bc9b2 100644 --- a/backends/qualcomm/builders/op_log_softmax.py +++ b/backends/qualcomm/builders/op_log_softmax.py @@ -72,5 +72,4 @@ def define_node( PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, {"data": np.uint32(dim)}, ) - # pdb.set_trace() return log_softmax_op diff --git a/backends/qualcomm/builders/op_skip_ops.py b/backends/qualcomm/builders/op_skip_ops.py index 9a1839f604e..837fb84d3ca 100644 --- a/backends/qualcomm/builders/op_skip_ops.py +++ b/backends/qualcomm/builders/op_skip_ops.py @@ -46,5 +46,7 @@ def define_node( raise AssertionError( f"Invalid number of index for {node.name }: {len(node.args[1])}" ) - nodes_to_wrappers[node.name] = nodes_to_wrappers.get(node.args[0].name) + nodes_to_wrappers[node.name] = { + 0: nodes_to_wrappers.get(node.args[0].name).get(node.args[1]) + } return diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 7972fb3dd92..3a294e35486 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -61,7 +61,9 @@ def define_node( ranges = [] for i in range(input_tensor_rank): if i == dim: - ranges.extend([start, end, 1]) + # find step + step = node.args[4] if len(node.args) > 4 else 1 + ranges.extend([start, end, step]) else: ranges.extend([0, input_tensor.shape[i], 1]) diff --git a/backends/qualcomm/builders/op_split.py b/backends/qualcomm/builders/op_split.py new file mode 100644 index 00000000000..00bfb3e556d --- /dev/null +++ b/backends/qualcomm/builders/op_split.py @@ -0,0 +1,85 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
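A note on the define_tensor change above: nodes_to_wrappers is no longer a flat name-to-wrapper map but a name-to-{wrapper_idx: wrapper} map, so ops with several outputs (the new split builder below, and the getitem handling in op_skip_ops) can register one TensorWrapper per output. A minimal sketch of that registry, with plain objects standing in for PyQnnWrapper.TensorWrapper:

```python
from collections import defaultdict

# node name -> {wrapper_idx: wrapper}; the partitioner and qnn_preprocess now
# create this registry with defaultdict(dict).
nodes_to_wrappers = defaultdict(dict)

def define_or_reuse(node_name, wrapper_idx, build_wrapper):
    # Same caching shape as define_tensor: reuse the wrapper registered for
    # this (node, output index) pair, otherwise build and record a new one.
    if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
        return cached
    wrapper = build_wrapper()
    nodes_to_wrappers[node_name][wrapper_idx] = wrapper
    return wrapper

# A three-output split registers three wrappers under one node name,
# and getitem consumers later pick them out by index.
for i in range(3):
    define_or_reuse("aten_split_with_sizes_default", i, object)
```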
+from typing import cast, Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpSplit, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Split(NodeVisitor): + target = ["aten.split_with_sizes.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + split_input_tensors = [input_tensor_wrapper] + + axis = 0 if len(node.args) < 3 else cast(int, node.args[2]) + if axis < 0: + axis = axis % len(input_tensor.shape) + if "axis_order" in node.meta: + axis = node.meta["axis_order"].index(axis) + + # this is not the general case, only a quick workaround here + index = np.arange(1, input_tensor.shape[axis], dtype=np.uint32) + index_shape = [len(index)] + + split_output_tensors = [] + for i in range(input_tensor.shape[axis]): + output_tensor = self.get_tensor(node, node, i) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + wrapper_idx=i, + ) + split_output_tensors.append(output_tensor_wrapper) + + split_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpSplit.op_name, + ) + split_op.AddInputTensors(split_input_tensors) + split_op.AddOutputTensors(split_output_tensors) + + split_op.AddScalarParam( + OpSplit.param_axis, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {"data": np.uint32(axis)}, + ) + split_op.AddTensorParam( + OpSplit.param_split_index, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(index_shape), + index_shape, + index, + True, + ) + + return split_op diff --git a/backends/qualcomm/builders/op_cast.py b/backends/qualcomm/builders/op_sqrt.py similarity index 70% rename from backends/qualcomm/builders/op_cast.py rename to backends/qualcomm/builders/op_sqrt.py index d3096ee27cf..7847d00e8b8 100644 --- a/backends/qualcomm/builders/op_cast.py +++ b/backends/qualcomm/builders/op_sqrt.py @@ -10,12 +10,12 @@ import torch from .node_visitor import NodeVisitor, register_node_visitor -from .qnn_constants import OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW +from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW @register_node_visitor -class Cast(NodeVisitor): - target = ["aten._to_copy.default"] +class SQRT(NodeVisitor): + target = ["aten.sqrt.default"] def __init__(self, *args) -> None: super().__init__(*args) @@ -25,6 +25,7 @@ def define_node( node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], ) -> PyQnnWrapper.PyQnnOpWrapper: + # tensor input input_node = node.args[0] input_tensor = self.get_tensor(input_node, node) @@ -35,23 +36,24 @@ def define_node( nodes_to_wrappers, is_input_tensor=True, ) + sqrt_input_tensors = [input_tensor_wrapper] - output_tensor = self.get_tensor(node, node) - + out_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, - output_tensor, + out_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, 
is_input_tensor=False, ) + sqrt_output_tensors = [output_tensor_wrapper] - cast_op = PyQnnWrapper.PyQnnOpWrapper( + sqrt_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpCast.op_name, + OpSqrt.op_name, ) - cast_op.AddInputTensors([input_tensor_wrapper]) - cast_op.AddOutputTensors([output_tensor_wrapper]) + sqrt_op.AddInputTensors(sqrt_input_tensors) + sqrt_op.AddOutputTensors(sqrt_output_tensors) - return cast_op + return sqrt_op diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py new file mode 100644 index 00000000000..26cc262462e --- /dev/null +++ b/backends/qualcomm/builders/op_sum_int_list.py @@ -0,0 +1,80 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import cast, Dict, List + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import numpy as np +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpReduceSum, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class Sum(NodeVisitor): + target = ["aten.sum.dim_IntList"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + sum_input_tensors = [input_tensor_wrapper] + + # sum dims + sum_dims = cast(List[int], node.args[1]) + sum_dims = [sum_dim % len(input_node.meta["val"].shape) for sum_dim in sum_dims] + if "axis_order" in node.meta: + sum_dims = [node.meta["axis_order"].index(sum_dim) for sum_dim in sum_dims] + sum_dims_shape = [len(sum_dims)] + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + sum_output_tensors = [output_tensor_wrapper] + sum_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReduceSum.op_name, + ) + sum_op.AddInputTensors(sum_input_tensors) + sum_op.AddOutputTensors(sum_output_tensors) + sum_op.AddTensorParam( + OpReduceSum.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(sum_dims_shape), + sum_dims_shape, + np.array(sum_dims, dtype=np.uint32), + True, + ) + + if len(node.args) > 2: + keep_dims = cast(bool, node.args[2]) + sum_op.AddScalarParam( + OpReduceSum.param_keep_dims, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, + {"data": keep_dims}, + ) + return sum_op diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py new file mode 100644 index 00000000000..d2762eb8f6b --- /dev/null +++ b/backends/qualcomm/builders/op_to.py @@ -0,0 +1,104 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
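The Split builder above flags its split_index computation as a quick workaround: np.arange(1, shape[axis]) always cuts the axis into size-1 chunks and ignores the sizes passed to aten.split_with_sizes. Assuming QNN's split_index lists the start offsets of every chunk after the first, a more general derivation from the split sizes could look like this hypothetical helper:

```python
import numpy as np

def split_index_from_sizes(split_sizes, dim_size):
    # aten.split_with_sizes gives chunk sizes, e.g. [2, 3, 5] for a size-10 axis;
    # the cut points along the axis are the running sums of all but the last
    # chunk: [2, 5].
    assert sum(split_sizes) == dim_size
    return np.cumsum(split_sizes[:-1]).astype(np.uint32)

print(split_index_from_sizes([2, 3, 5], 10))    # -> [2 5]
print(split_index_from_sizes([1, 1, 1, 1], 4))  # -> [1 2 3], the current size-1 behaviour for a size-4 axis
```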
+from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class To(NodeVisitor): + target = ["aten._to_copy.default"] + sufixed_8_offset_diff = 128 + sufixed_16_offset_diff = 32768 + epsilon = 1e-6 + sufixed_8 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, + } + sufixed_16 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + } + + def __init__(self, *args) -> None: + super().__init__(*args) + + def is_cast_node(self, node): + input_node = node.args[0] + + # Not a case which has two quant node, no need to consider the convert op + if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]): + return True + + input_tensor = self.get_tensor(input_node, node) + _, inp_qconfs = self.get_quant_encoding_conf(input_node, False) + inp_dtype = self.get_data_type(input_tensor, inp_qconfs) + + output_tensor = self.get_tensor(node, node) + _, out_qconfs = self.get_quant_encoding_conf(node, False) + out_dtype = self.get_data_type(output_tensor, out_qconfs) + is_qparam_castalbe = ( + lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon + and abs(o1 - o2) == diff + ) + + if {inp_dtype, out_dtype} == self.sufixed_8: + return is_qparam_castalbe( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_8_offset_diff, + ) + elif {inp_dtype, out_dtype} == self.sufixed_16: + return is_qparam_castalbe( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_16_offset_diff, + ) + return False + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + output_tensor = self.get_tensor(node, node) + + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + qnn_op = OpCast if self.is_cast_node(node) else OpConvert + op = PyQnnWrapper.PyQnnOpWrapper( + node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name + ) + op.AddInputTensors([input_tensor_wrapper]) + op.AddOutputTensors([output_tensor_wrapper]) + + return op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 82c50046bee..c776fe5a346 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -8,7 +8,6 @@ from enum import IntEnum, unique QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw" -QNN_uint16 = "uint16" # Below constants should be same as those in QNN headers. 
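On the new To builder above: is_cast_node only treats aten._to_copy as a pure Cast when input and output share the same scale (within epsilon) and their offsets differ by exactly 128 (8-bit) or 32768 (16-bit), i.e. the signed and unsigned fixed-point encodings describe the same real-valued grid; everything else becomes a Convert, which actually requantizes. A small check of that identity, assuming QNN's real = scale * (q + offset) convention (the scale and offsets below are illustrative):

```python
import numpy as np

scale = 0.05
uint8_offset = -128                    # illustrative offset (negative zero point)
int8_offset = uint8_offset + 128       # offsets differ by exactly 128, scales equal

q_u = np.arange(0, 256)                # every uint8 code
q_s = q_u.astype(np.int16) - 128       # the same codes shifted into int8 range

# Equal scale + a 128 offset gap -> both encodings decode to identical reals,
# so only the stored integers need re-interpreting (Cast), not requantizing.
assert np.allclose(scale * (q_u + uint8_offset), scale * (q_s + int8_offset))
```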
# Maybe someday we should expose these constants by pybind @@ -40,6 +39,11 @@ class OpConv2d: param_dilation: str = "dilation" +@dataclass(init=False, frozen=True) +class OpConvert: + op_name: str = "Convert" + + @dataclass(init=False, frozen=True) class OpDepthToSpace: op_name: str = "DepthToSpace" @@ -106,6 +110,13 @@ class OpExpandDims: param_axis: str = "axis" +@dataclass(init=False, frozen=True) +class OpReduceSum: + op_name: str = "ReduceSum" + param_axes: str = "axes" + param_keep_dims: str = "keep_dims" + + @dataclass(init=False, frozen=True) class OpFullyConnected: op_name: str = "FullyConnected" @@ -123,6 +134,11 @@ class OpGelu: op_name: str = "Gelu" +@dataclass(init=False, frozen=True) +class OpSqrt: + op_name: str = "ElementWiseSquareRoot" + + @dataclass(init=False, frozen=True) class OpHardSwish: op_name: str = "HardSwish" @@ -247,6 +263,13 @@ class OpSoftmax: param_beta: str = "beta" +@dataclass(init=False, frozen=True) +class OpSplit: + op_name: str = "Split" + param_axis: str = "axis" + param_split_index: str = "split_index" + + @dataclass(init=False, frozen=True) class OpSqueeze: op_name: str = "Squeeze" diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index 92c129f342f..217c840553c 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + import torch from torch._export.utils import get_buffer, get_param, is_buffer, is_param @@ -97,3 +99,20 @@ def is_constant( return tensor.meta["val"].constant is not None return False + + +def deduce_dtype( + tensor: torch.Tensor, quant_infos: Optional[Dict] = None +) -> torch.dtype: + if quant_infos: + quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] + unsigned = quant_infos["quant_min"] >= 0 + if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min + 1: + return torch.uint8 if unsigned else torch.int8 + + elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: + return torch.uint16 if unsigned else torch.int16 + + return quant_infos["dtype"] + + return tensor.dtype diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index b06a5766a63..61935cf3536 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -11,8 +11,10 @@ not_supported_operator = [ exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.full.default, + exir_ops.edge.aten.slice_scatter.default, + exir_ops.edge.aten.index.Tensor, + exir_ops.edge.aten.index_put.default, ] allow_list_operator = [ diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index a704d3a6336..0c5b25284eb 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
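The new deduce_dtype helper above picks the narrowest storage dtype that covers the observed quant_min/quant_max range, using the sign of quant_min to choose between signed and unsigned; the tensor's own dtype is only the unquantized fallback. A few calls traced from the function body (import path as laid out in this tree):

```python
import torch
from executorch.backends.qualcomm.builders.utils import deduce_dtype

t = torch.zeros(1)

# 8-bit ranges
assert deduce_dtype(t, {"quant_min": 0, "quant_max": 255, "dtype": torch.int32}) == torch.uint8
assert deduce_dtype(t, {"quant_min": -128, "quant_max": 127, "dtype": torch.int32}) == torch.int8

# ranges wider than 8 bit but within 16 bit
assert deduce_dtype(t, {"quant_min": 0, "quant_max": 65535, "dtype": torch.int32}) == torch.uint16
assert deduce_dtype(t, {"quant_min": -32768, "quant_max": 32767, "dtype": torch.int32}) == torch.int16

# no quant info -> fall back to the tensor's own dtype
assert deduce_dtype(t, None) == torch.float32
```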
import copy +from collections import defaultdict from typing import Any, Dict, List import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager @@ -49,7 +50,7 @@ def __init__( ) self.skip_node_id_set = skip_node_id_set - self.nodes_to_wrappers = {} + self.nodes_to_wrappers = defaultdict(dict) self.qnn_manager = PyQnnManager.QnnManager( generate_qnn_executorch_option(compiler_specs) ) @@ -95,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}") return supported + def __del__(self): + self.qnn_manager.Destroy() + class QnnPartitioner(Partitioner): def __init__( @@ -144,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu # pop certain keys in meta for not affecting the passes in compilation # TODO: need to put property name in common definitions node.meta.pop("axis_order", "") + del self.op_support_checker return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) diff --git a/backends/qualcomm/passes/build_quant_io.py b/backends/qualcomm/passes/build_quant_io.py new file mode 100644 index 00000000000..44b66592f3c --- /dev/null +++ b/backends/qualcomm/passes/build_quant_io.py @@ -0,0 +1,52 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch + +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import TensorSpec + +from .utils import q_io_key + + +class BuildQuantIo(ExportPass): + + def __init__(self): + super(BuildQuantIo, self).__init__() + + def _make_spec(self, x): + if isinstance(x, torch.Tensor): + return TensorSpec.from_tensor(x) + elif isinstance(x, (int, bool, float)): + return x + else: + return None + + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # forcely update delegate node's meta['spec'] to get correct output + # tensor size in runtime + call_delegate = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.name == "executorch_call_delegate" + ] + assert len(call_delegate) == 1 + spec = [] + for n in graph_module.graph.nodes: + if q_io_key in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[q_io_key]) + if n.op == "call_function" and "getitem" in n.name: + fake_tensor = n.meta["val"] + if q_io_key in n.meta: + fake_tensor = fake_tensor.to(dtype=n.meta[q_io_key]) + spec.append(self._make_spec(fake_tensor)) + + call_delegate[0].meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/fuse_consecutive_transpose.py b/backends/qualcomm/passes/fuse_consecutive_transpose.py new file mode 100644 index 00000000000..740b91dfaac --- /dev/null +++ b/backends/qualcomm/passes/fuse_consecutive_transpose.py @@ -0,0 +1,86 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch +from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class FuseConsecutiveTranspose(ExportPass): + """ + This pass fuses consecutive transpose / permute into one to reduce runtime + overhead + """ + + def __init__(self): + super().__init__() + self.op_map = { + exir_ops.edge.aten.permute_copy.default, + } + self.visited = set() + self.nodes = [] + + def _traverse(self, node): + if node in self.visited or not node.target in self.op_map: + return + + self.nodes.append(node) + self.visited.add(node) + next_users = [n for n in list(node.users) if n.target in self.op_map] + if not next_users: + return + + if len(next_users) == 1: + self._traverse(list(node.users)[0]) + else: + raise NotImplementedError( + f"Check the node {node}, wich encounter mutilple permute output case" + ) + + def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + graph = graph_module.graph + for n in graph_module.graph.nodes: + self._traverse(n) + if len(self.nodes) > 1: + permute_order = [] + input_node, output_node = self.nodes[0].args[0], self.nodes[-1] + input_shape = input_node.meta["val"].shape + axis_order = torch.arange(len(input_shape)).tolist() + for node in self.nodes: + permute_order.append(node.args[1]) + axis_order = [axis_order[i] for i in node.args[1]] + with graph.inserting_after(input_node): + permute_op = exir_ops.edge.aten.permute_copy.default + permute_node = graph.create_node( + "call_function", permute_op, (input_node, axis_order) + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, permute_node) + + # copy metadata + permute_node.meta = output_node.meta + # Without inserted_permute_tag, we might obtain wrong input shape + if any( + [ + pn.meta.get(LayoutTransform.inserted_permute_tag) + for pn in self.nodes + ] + ): + permute_node.meta[LayoutTransform.inserted_permute_tag] = True + + # clear current stack + self.nodes = [] + + def call(self, graph_module: torch.fx.GraphModule): + self._fuse(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/insert_io_qdq.py b/backends/qualcomm/passes/insert_io_qdq.py index 5e6a03799cf..d88ca24fbba 100644 --- a/backends/qualcomm/passes/insert_io_qdq.py +++ b/backends/qualcomm/passes/insert_io_qdq.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import dq_ops, q_ops +from .utils import q_io_key, q_ops class InsertIOQDQ(ExportPass): @@ -107,41 +107,18 @@ def _insert_dequant_node( if user.op == "output": user.replace_input_with(node, inserted_node) - # When having requantization dq/q nodes at the input, - # such as the case: input1 -> dq_node1 -> q_node1 -> node1, - # we should fold the dq_node1 and connect input -> q_node1 -> node1. 
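FuseConsecutiveTranspose above folds a chain of permutes by repeatedly indexing the running axis_order with each node's permutation, which is plain permutation composition; when a layout-inserted permute is followed by its inverse, the chain collapses to the identity order and the fused node becomes trivial. A standalone check of that composition rule:

```python
import torch

def compose(perms, rank):
    # Mirrors the pass: start from the identity order and fold each permute in.
    axis_order = list(range(rank))
    for p in perms:
        axis_order = [axis_order[i] for i in p]
    return axis_order

x = torch.randn(1, 3, 8, 8)
p1, p2 = (0, 2, 3, 1), (0, 3, 1, 2)   # NCHW -> NHWC, then NHWC -> NCHW

fused = compose([p1, p2], x.dim())
assert fused == [0, 1, 2, 3]          # the two transposes cancel
assert torch.equal(x.permute(p1).permute(p2), x.permute(fused))
```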
- def _fold_mix_quantization_dq_node(self, graph_module, input_node): - input_users = list(input_node.users.keys()) - for input_user in input_users: - if input_user.target not in dq_ops: - continue - dq_users = list(input_user.users.keys()) - for dq_user in dq_users: - dq_user.replace_input_with(input_user, input_node) - - # When having requantization dq/q nodes at the output, - # such as the case: node(int32) -> dq(int32) -> q(uint8) -> output(int32), - # we should fold the q node and connect node(int32) -> dq(int32) -> output(int32). - def _fold_mix_quantization_q_node(self, graph_module, node, users): - for user in users: - if user.op == "output": - output_node = user - break - - dq_node = node.args[0] - for out_node in output_node.meta["val"]: - if dq_node.meta["val"].dtype == out_node.dtype: - user.replace_input_with(node, dq_node) - def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: + # do nothing when a node is expected to output a quant tensor + if n.meta.get(q_io_key): + continue + # insert q after input or fold mix_quantization dq if applicable if ( n.op == "placeholder" and n.meta.get("quant_attrs") and not is_parameter(n, self.edge_program) ): - self._fold_mix_quantization_dq_node(graph_module, n) self._insert_quant_node( graph_module, n, n.meta["quant_attrs"]["encoding"] ) @@ -149,10 +126,6 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # insert dq before output or fold mix_quantization q if applicable users = list(n.users.keys()) if n.meta.get("quant_attrs") and any(user.op == "output" for user in users): - if n.target in q_ops: - self._fold_mix_quantization_q_node(graph_module, n, users) - # If q_node is fold, it will have no users, - # so it won't insert dequant node in following function. self._insert_dequant_node( graph_module, n, diff --git a/backends/qualcomm/passes/insert_requantize.py b/backends/qualcomm/passes/insert_requantize.py index d0169ebe357..c41bc16a5f5 100644 --- a/backends/qualcomm/passes/insert_requantize.py +++ b/backends/qualcomm/passes/insert_requantize.py @@ -6,11 +6,13 @@ import torch -from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from .utils import q_io_key -class InsertRequantize(InsertIOQDQ): + +class InsertRequantize(ExportPass): """ This pass inserts dq/q nodes for non-arithmetic operators which have different quantization specs in input and activation @@ -26,10 +28,9 @@ class InsertRequantize(InsertIOQDQ): def __init__( self, edge_program: torch.export.ExportedProgram, - insert_requantize: bool = False, ): - super().__init__(edge_program) - self.insert_requantize = insert_requantize + super(InsertRequantize, self).__init__() + self.edge_program = edge_program # TODO: Implement this function when we have an op with # multiple outputs that requires quant attributes. 
@@ -39,16 +40,21 @@ def _multi_output_annotation(self) -> None: def _single_output_annotation( self, gm: torch.fx.GraphModule, n: torch.fx.node ) -> None: - dq_attr = n.meta["quant_attrs"] - q_attr = n.meta["requantize"] - # insert dq with given quantization attribute in input node - dq = self._insert_quant_node( - gm, n, InsertIOQDQ.q_dq_map[q_attr["encoding"]], dq_attr - ) - dq.meta["quant_attrs"] = dq_attr - # insert q with given quantization attribute in current node - q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr) - q.meta["quant_attrs"] = q_attr + with gm.graph.inserting_after(n): + users = list(n.users.keys()) + inserted_n = gm.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (n,), + ) + + inserted_n.meta["val"] = n.meta["val"] + inserted_n.meta["quant_attrs"] = n.meta.pop("requantize") + if n.meta.get(q_io_key): + inserted_n.meta[q_io_key] = n.meta[q_io_key] + + for user in users: + user.replace_input_with(n, inserted_n) def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: @@ -59,3 +65,9 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: or n.target in self.multi_output_op_ignore_set else self._multi_output_annotation() ) + + def call(self, graph_module: torch.fx.GraphModule): + self._insert(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py index 8c86f1919ad..fbf1431f1a5 100644 --- a/backends/qualcomm/passes/layout_transform.py +++ b/backends/qualcomm/passes/layout_transform.py @@ -52,6 +52,9 @@ class LayoutTransform(ExportPass): exir_ops.edge.aten.bmm.default, exir_ops.edge.aten.full.default, exir_ops.edge.aten.gelu.default, + exir_ops.edge.aten.sqrt.default, + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.aten.pow.Tensor_Scalar, *q_ops, *dq_ops, _operator.getitem, @@ -109,7 +112,10 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool: return node.target in self.layout_sensitive_ops def is_layout_agnostic(self, node: torch.fx.Node) -> bool: - if node.target == exir_ops.edge.aten.mean.dim: + if node.target in [ + exir_ops.edge.aten.mean.dim, + exir_ops.edge.aten.sum.dim_IntList, + ]: # if dimemsion is not kept, we'll have no clue how to do layout transform if len(node.args) < 3 or not node.args[2]: return False diff --git a/backends/qualcomm/passes/utils.py b/backends/qualcomm/passes/utils.py index 49da1929e84..93e8a92a0d4 100755 --- a/backends/qualcomm/passes/utils.py +++ b/backends/qualcomm/passes/utils.py @@ -9,6 +9,8 @@ from executorch.exir.dialects._ops import ops as exir_ops +q_io_key = "q_tensor_io" + q_ops = { exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 1e6275892f0..e95aaa0aaea 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -5,12 +5,16 @@ # LICENSE file in the root directory of this source tree. 
import logging +from collections import defaultdict from typing import final, List import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear +from executorch.backends.qualcomm.passes.fuse_consecutive_transpose import ( + FuseConsecutiveTranspose, +) from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform @@ -47,6 +51,8 @@ def preprocess( InsertRequantize(edge_program), InsertIOQDQ(edge_program), LayoutTransform(edge_program, insert_permute=True), + # please enable this when apply convert_linear_to_conv2d + FuseConsecutiveTranspose(), ] ) @@ -54,7 +60,7 @@ def preprocess( assert pass_result is not None enable_tensor_dump = qnn_manager.IsTensorDump() - nodes_to_wrappers = {} + nodes_to_wrappers = defaultdict(dict) node_visitors = get_node_visitors( edge_program, enable_tensor_dump=enable_tensor_dump ) @@ -88,6 +94,8 @@ def preprocess( ) assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary." qnn_manager.Destroy() + del py_op_wrapper_list + del qnn_manager # For now, debug_handle_map is not used by QNN ExecuTorch return PreprocessResult( processed_bytes=bytes(qnn_context_binary), debug_handle_map={} diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 1414af171a4..69f9ad1d589 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from enum import IntEnum, unique -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Callable, Dict, Optional, Sequence, Set import torch from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid @@ -16,28 +16,25 @@ from executorch.backends.qualcomm.passes.remove_clone import RemoveClone from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer -from torch import Tensor from torch._ops import OpOverload -from torch.ao.quantization.observer import ( - HistogramObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, +from torch.ao.quantization.quantizer import Quantizer +from torch.fx import GraphModule + +from .utils import ( + get_16a4w_qnn_ptq_config, + get_16a8w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + get_ptq_per_channel_weight_config, + OP_ANNOTATOR, + QuantizationConfig, ) -from torch.ao.quantization.quantizer import ( - DerivedQuantizationSpec, - QuantizationSpec, - Quantizer, -) - -from torch.fx import GraphModule, Node - -from .utils import OP_ANNOTATOR, QuantizationConfig __all__ = [ "QnnQuantizer", "QuantDtype", "get_16a4w_qnn_ptq_config", + "get_16a8w_qnn_ptq_config", "get_default_16bit_qnn_ptq_config", "get_default_8bit_qnn_ptq_config", ] @@ -54,205 +51,6 @@ class QuantDtype(IntEnum): use_8a8w = 2 -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=torch.iinfo(torch.uint8).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - 
input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_weight_config( - act_dtype=torch.uint8, weight_dtype=torch.int8 -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. 
Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - class QnnQuantizer(Quantizer): SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys()) diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index ee6eb1608d1..c9e21af767f 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -4,28 +4,33 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch +from torch import Tensor from torch._ops import OpOverload from torch._subclasses import FakeTensor +from torch.ao.quantization.observer import ( + HistogramObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, +) + from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) - from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) from torch.fx import Node -QUANT_ANNOTATION_KEY = "quantization_annotation" -OP_ANNOTATOR: Dict[OpOverload, Callable] = {} - @dataclass(eq=True, frozen=True) class QuantizationConfig: @@ -35,6 +40,255 @@ class QuantizationConfig: bias: Optional[QuantizationSpec | Callable] +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_default_8bit_qnn_ptq_config(act_symmetric: bool = False) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max + 1, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. 
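_derived_bias_quant_spec above derives the bias encoding from its producer op: per-channel bias scale = activation scale * weight scale, zero point 0, stored as int32. A worked example of that arithmetic with made-up scales:

```python
import torch

act_scale = torch.tensor(0.02)                 # per-tensor activation scale
weight_scale = torch.tensor([0.10, 0.05])      # per-channel weight scales
bias = torch.tensor([0.1234, -0.4560])

bias_scale = act_scale * weight_scale          # -> [0.0020, 0.0010]
bias_zero_point = torch.zeros_like(bias_scale).to(torch.int32)

q_bias = torch.round(bias / bias_scale).to(torch.int32)   # -> [62, -456]
print(q_bias, q_bias * bias_scale)             # dequantizing approximates the original bias
```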
+def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max + 1, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return 
quantization_config + + +def get_ptq_per_channel_weight_config( + act_dtype=torch.uint8, weight_dtype=torch.int8 +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +QUANT_ANNOTATION_KEY = "quantization_annotation" +OP_ANNOTATOR: Dict[OpOverload, Callable] = {} + + def register_annotator(ops: List[OpOverload]): def decorator(annotator: Callable): for op in ops: @@ -117,15 +371,12 @@ def annotate_single_in_single_out( assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation - node_tensor = node.meta.get("val") - if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32: - return - - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) + if _is_input_float_tensor(node): + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: @@ -133,7 +384,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None return input_act_qspec = quantization_config.input_activation - output_act_qspec = quantization_config.output_activation + output_act_qspec = ( + quantization_config.output_activation + if node.meta["val"].dtype == torch.float32 + else None + ) input_qspec_map = {} input_act0 = node.args[0] @@ -176,6 +431,11 @@ def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) +@register_annotator([torch.ops.aten.sum.dim_IntList]) +def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_binary(node, quantization_config) + + @register_annotator([torch.ops.aten.ceil.default]) def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> 
None: annotate_single_in_single_out(node, quantization_config) @@ -246,12 +506,16 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N @register_annotator([torch.ops.aten.permute.default]) def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) @register_annotator([torch.ops.aten.view.default]) def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) @register_annotator([torch.ops.aten.pixel_shuffle.default]) @@ -303,6 +567,11 @@ def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.sqrt.default]) +def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.gelu.default]) def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -433,10 +702,8 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None if isinstance(input_act1, Node): # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -464,10 +731,8 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: if isinstance(input_act1, Node): # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. 
if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 77449703c5f..8c787cf7981 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -11,7 +11,7 @@ #include #include #include - +#include #include namespace torch { namespace executor { @@ -20,6 +20,12 @@ using namespace qnn; using namespace qnn_delegate; constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec"; +bool CompareQnnInput(const std::shared_ptr& a, const std::shared_ptr& b) { + int numA = std::stoi(a->GetName().substr(a->GetName().find('_') + 1)); + int numB = std::stoi(b->GetName().substr(b->GetName().find('_') + 1)); + return numA < numB; +} + Result QnnExecuTorchBackend::init( BackendInitContext& context, FreeableBuffer* processed, @@ -187,6 +193,9 @@ Error QnnExecuTorchBackend::execute( qnn_manager->GetGraphOutputs(); std::vector input_tensor_structs; std::vector output_tensor_structs; + // Using the order of the nodes as external_id in AOT + // to extract the right arg from *args at runtime + std::sort(input_tensors.begin(), input_tensors.end(), CompareQnnInput); input_tensor_structs.reserve(input_tensors.size()); for (int i = 0; i < input_tensors.size(); ++i) { diff --git a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp index 0c569ae5ab6..b8ac289133b 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp @@ -87,7 +87,10 @@ QnnBackendCache::QnnBackendCache( state_ = SERIALIZE; QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE."); return; - } else { + } + /*else { + // TODO: need fix on this since qnn context binary could somehow + // pass the check of flatbuffer verifier // check if context binary came from flatbuffer flatbuffers::FlatBufferBuilder builder; flatbuffers::Verifier verifier( @@ -98,7 +101,7 @@ QnnBackendCache::QnnBackendCache( state_ = ONLINE_PREPARE; return; } - } + }*/ if (qnn_sys_impl_.Load() != Error::Ok) { QNN_EXECUTORCH_LOG_ERROR( diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c8379cf0b7a..b2c8e0d61ca 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -71,6 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_SDK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ diff --git a/backends/qualcomm/setup.md b/backends/qualcomm/setup.md index 18ebf412fc0..b78b481e86e 100644 --- a/backends/qualcomm/setup.md +++ b/backends/qualcomm/setup.md @@ -93,7 +93,6 @@ mkdir build_android cd build_android # build executorch & qnn_executorch_backend cmake .. 
\ - -DBUCK2=buck2 \ -DCMAKE_INSTALL_PREFIX=$PWD \ -DEXECUTORCH_BUILD_QNN=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index edc7a469f7b..ba97240455f 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -29,7 +29,7 @@ def __init__(self): super().__init__() def forward(self, x): - return 10.0 + x + return 10 + x class Arange(torch.nn.Module): @@ -122,8 +122,8 @@ def __init__( ) -> None: super().__init__() self.modules = [ - Conv2DSequential(), - Conv2DSequential(), + Conv2dSequential(), + Conv2dSequential(), Add(), Relu(), ] @@ -172,7 +172,7 @@ def forward(self, x, y): return CompositeReferenceModule(self.modules) -class Conv1DSequential(torch.nn.Module): +class Conv1dSequential(torch.nn.Module): def __init__(self): super().__init__() self.first = torch.nn.Conv1d( @@ -210,43 +210,6 @@ def forward(self, x): return x -class Conv2DSequential(torch.nn.Module): - def __init__(self): - super().__init__() - self.first = torch.nn.Conv2d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - self.second = torch.nn.Conv2d( - in_channels=3, - out_channels=2, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - - def forward(self, x): - return self.second(self.first(x)) - - -class Conv2DSingle(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d( - in_channels=1, - out_channels=3, - kernel_size=(3, 3), - padding=1, - bias=True, - ) - - def forward(self, x): - return self.conv(x) - - class Conv2dAvgPool2d(torch.nn.Module): def __init__(self): super().__init__() @@ -321,6 +284,58 @@ def forward(self, x): return self.pool(self.conv(x)) +class Conv2dSequential(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + self.second = torch.nn.Conv2d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return self.second(self.first(x)) + + +class Conv2dSingle(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return self.conv(x) + + +class Conv2dSumReduceDim(torch.nn.Module): + def __init__(self): + super().__init__() + self.first = torch.nn.Conv2d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3), + padding=1, + bias=True, + ) + + def forward(self, x): + return torch.sum(self.first(x), dim=(2, 3), keepdim=False) + + class Div(torch.nn.Module): def __init__(self): super().__init__() @@ -691,7 +706,7 @@ def __init__(self): super().__init__() def forward(self, x): - return x / torch.sqrt(torch.tensor([64])) + return x / torch.sqrt(torch.tensor([64.0])) class Squeeze(torch.nn.Module): @@ -748,6 +763,14 @@ def forward(self, x): return 10 - x +class SumIntList(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sum(x, dim=(2, 3), keepdim=True) + + class Tanh(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index d539827fdb9..3874da9e981 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -95,12 +95,12 @@ def test_qnn_backend_clamp(self): 
self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv1d(self): - module = Conv1DSequential() # noqa: F405 + module = Conv1dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv2d(self): - module = Conv2DSequential() # noqa: F405 + module = Conv2dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) self.lower_module_and_test_output(module, sample_input) @@ -183,11 +183,10 @@ def test_qnn_backend_element_wise_mul(self): self.lower_module_and_test_output(module, sample_input) index += 1 - @unittest.skip("not yet implemented") def test_qnn_backend_element_wise_sqrt(self): modules = [Sqrt(), SqrtConstant()] # noqa: F405 - sample_input = (torch.randn([3, 1]),) for i, module in enumerate(modules): + sample_input = (torch.rand([3, 1]),) with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) @@ -357,6 +356,11 @@ def test_qnn_backend_squeeze(self): sample_input = (torch.randn([1, 3, 3]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_sum_int_list(self): + module = SumIntList() # noqa: F405 + sample_input = (torch.randn([1, 4, 8, 8]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -421,6 +425,11 @@ def test_qnn_backend_conv2d_max_pool2d(self): sample_input = (torch.rand(1, 2, 14, 14),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_sum_reduce_dim(self): + module = Conv2dSumReduceDim() # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_residual_block(self): module = ResidualBlockModule() # noqa: F405 sample_input = (torch.randn(1, 32, 28, 28),) @@ -494,7 +503,7 @@ def setUp(self): ) def test_qnn_backend_16a4w_conv2d(self): - module = Conv2DSingle() # noqa: F405 + module = Conv2dSingle() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) module = self.get_qdq_module( module, sample_input, quant_dtype=QuantDtype.use_16a4w @@ -575,13 +584,13 @@ def test_qnn_backend_clamp(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv1d(self): - module = Conv1DSequential() # noqa: F405 + module = Conv1dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3]),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_conv2d(self): - module = Conv2DSequential() # noqa: F405 + module = Conv2dSequential() # noqa: F405 sample_input = (torch.randn([1, 1, 3, 3]),) module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) @@ -669,11 +678,10 @@ def test_qnn_backend_element_wise_mul(self): self.lower_module_and_test_output(module, sample_input) index += 1 - @unittest.skip("not yet implemented") def test_qnn_backend_element_wise_sqrt(self): modules = [Sqrt(), SqrtConstant()] # noqa: F405 - sample_input = (torch.randn([3, 1]),) for i, module in enumerate(modules): + sample_input = (torch.rand([3, 1]),) with self.subTest(i=i): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) @@ -873,6 +881,12 @@ def test_qnn_backend_stack(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def 
test_qnn_backend_sum_int_list(self): + module = SumIntList() # noqa: F405 + sample_input = (torch.randn([1, 4, 8, 8]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_tanh(self): module = Tanh() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -946,6 +960,12 @@ def test_qnn_backend_conv2d_max_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv2d_sum_reduce_dim(self): + module = Conv2dSumReduceDim() # noqa: F405 + sample_input = (torch.randn([1, 1, 3, 3]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_example_models(self): instances = [ {"module": DeepLabV3ResNet101Model(), "annotation": ()}, @@ -1095,6 +1115,7 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) + @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=True) @@ -1227,6 +1248,7 @@ def test_qnn_backend_multi_contexts_composite(self): exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) + @unittest.expectedFailure def test_qnn_backend_profile_op(self): TestQNN.enable_profile = True backend_options = generate_htp_compiler_spec(use_fp16=False) @@ -1323,6 +1345,40 @@ def test_fbnet(self): self.assertGreaterEqual(msg["top_1"], 60) self.assertGreaterEqual(msg["top_5"], 90) + def test_ssd300_vgg16(self): + if not self.required_envs([self.pretrained_weight, self.oss_repo]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--oss_repo", + self.oss_repo, + "--pretrained_weight", + self.pretrained_weight, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + self.assertGreaterEqual(msg["mAP"], 0.70) + class TestExampleScript(TestQNN): def required_envs(self, conditions=None) -> bool: @@ -1771,6 +1827,11 @@ def setup_environment(): help="Emit log only when error happened", action="store_true", ) + parser.add_argument( + "--oss_repo", + help="Path to open source software model repository", + type=str, + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -1785,6 +1846,7 @@ def setup_environment(): TestQNN.online_prepare = args.online_prepare TestQNN.enable_profile = args.enable_profile TestQNN.error_only = args.error_only + TestQNN.oss_repo = args.oss_repo TestQNN.shared_buffer = args.shared_buffer return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 0a9b7d064d1..bfeb1bb649d 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -61,6 +61,48 @@ def qnn_edge_config() -> exir.EdgeCompileConfig: return exir.EdgeCompileConfig(_check_ir_validity=False) +def convert_linear_to_conv2d(module: torch.nn.Module): + 
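+    # Replaces nn.Linear layers with a mathematically equivalent 1x1 nn.Conv2d:
+    # the feature dimension is moved onto the channel axis around the convolution
+    # and the original Linear weight/bias are reused (weight reshaped to a 1x1
+    # kernel). The wrapper's forward assumes rank-2 or rank-3 activations, and
+    # replace_linear() mutates the given module in place and returns it.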
class Conv2D(torch.nn.Module): + def __init__(self, weight, bias=None): + super().__init__() + use_bias = bias is not None + self.conv = torch.nn.Conv2d( + in_channels=weight.shape[0], + out_channels=weight.shape[1], + kernel_size=1, + padding=0, + bias=use_bias, + ) + self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1)) + if use_bias: + self.conv.bias = torch.nn.Parameter(bias) + + def forward(self, x): + rank = x.dim() + x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = torch.transpose(x, 1, 2) + res = self.conv(x) + res = torch.transpose(res, 1, 2) + res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) + return res + + def replace_linear(module: torch.nn.Module): + attr_strs = dir(module) + if type(module) == torch.nn.ModuleList: + attr_strs += [str(i) for i in range(len(module))] + + for attr_str in attr_strs: + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias)) + + for _, sub_module in module.named_children(): + sub_module = replace_linear(sub_module) + return module + + return replace_linear(module) + + def canonicalize_program(prog: ExportedProgram): # check if user specifies to use multi_contexts # this is a generic approach in case there exists multiple backends @@ -182,6 +224,7 @@ def generate_htp_compiler_spec( # TODO: enable voting mechanism in runtime and make this as an option htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst htp_options.use_multi_contexts = use_multi_contexts + htp_options.max_sf_buf_size = 73859072 htp_options.use_dlbc = use_dlbc return QnnExecuTorchBackendOptions( backend_type=QnnExecuTorchBackendType.kHtpBackend, diff --git a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp index ed7d34aca4d..a517d3157fb 100644 --- a/examples/models/llama2/tokenizer/bpe_tokenizer.cpp +++ b/examples/models/llama2/tokenizer/bpe_tokenizer.cpp @@ -328,7 +328,7 @@ BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) { } delete[] str_buffer; - return Result(tokens); + return Result>(tokens); } } // namespace executor diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index cff5db2a63d..5b8e66e40b2 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -56,6 +56,7 @@ get_filename_component(EXECUTORCH_SOURCE_DIR ABSOLUTE ) set(_qnn_executor_runner__srcs ${_executor_runner__srcs}) +set(_qnn_llama_runner__srcs ${_llama_runner__srcs}) # portable_ops_lib gen_selected_ops("" "" "ON") @@ -74,6 +75,7 @@ target_include_directories(full_portable_ops_lib ${_common_include_directories} ) +# prerpocess executor runner src files list( TRANSFORM _qnn_executor_runner__srcs @@ -92,8 +94,29 @@ list( ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_executor_runner.cpp ) -add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) +# preprocess llama runner src files +list( + TRANSFORM + _qnn_llama_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" +) +list( + FILTER + _qnn_llama_runner__srcs + EXCLUDE REGEX + ".*runner.*$" +) +list( + PREPEND + _qnn_llama_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_llama_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.h +) +# build executor runner +add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) target_include_directories(qnn_executor_runner 
PUBLIC ${_common_include_directories} @@ -109,3 +132,21 @@ target_compile_options(qnn_executor_runner PUBLIC ${_common_compile_options} ) + +# build llama runner +add_executable(qnn_llama_runner ${_qnn_llama_runner__srcs}) +target_include_directories(qnn_llama_runner + PUBLIC + ${_common_include_directories} +) +target_link_libraries(qnn_llama_runner + qnn_executorch_backend + full_portable_ops_lib + extension_data_loader + extension_module + gflags +) +target_compile_options(qnn_llama_runner + PUBLIC + ${_common_compile_options} +) diff --git a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp new file mode 100644 index 00000000000..ab17a551d49 --- /dev/null +++ b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct + * and the portable kernels. + * + * User could specify arguments like desired input data, iterfations, etc. + * Currently we assume that the outputs are all fp32 tensors. + */ + +#include +#include +#include +#include + +#include + +#include +#include +#include + +DEFINE_string( + model_paths, + "qnn_llama2.pte", + "Model serialized in flatbuffer format."); + +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); + +DEFINE_string(input_list_path, "input_list.txt", "Model input list path."); + +DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff."); + +DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); + +DEFINE_double( + temperature, + 0.8f, + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); + +DEFINE_int32( + seq_len, + 128, + "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); + +int main(int argc, char** argv) { + using namespace torch::executor; + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::vector model_path_list; + std::istringstream f(FLAGS_model_paths); + std::string s; + while (getline(f, s, ',')) { + model_path_list.push_back(s); + } + + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); + const char* prompt = FLAGS_prompt.c_str(); + double temperature = FLAGS_temperature; + int32_t seq_len = FLAGS_seq_len; + + // create llama runner + Runner runner(model_path_list, tokenizer_path, temperature); + ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); + + // MethodMeta describes the memory requirements of the method. 
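+  // Here they are only used to size and validate the raw KV-cache buffers that
+  // are loaded from the input list and handed to the runner below.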
+ std::vector> method_metas = runner.get_methods_meta(); + for(auto& method_meta: method_metas){ + ET_CHECK_MSG( + method_meta.ok(), + "Failed to get method_meta 0x%x", + (unsigned int)method_meta.error()); + } + + // Fill in data for input + std::ifstream input_list(FLAGS_input_list_path); + ET_CHECK_MSG(input_list.is_open(), "Failed to open input_list.txt"); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + std::string file_path; + size_t inference_index = 0; + std::vector> freqs_inputs(2); + std::vector>> inputs(method_metas.size()-2); + + for (int i = 1; i < method_metas.size()-1; i++){ + size_t num_inputs = method_metas[i]->num_inputs(); + inputs[i-1].resize(num_inputs); + } + + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + // inputs: [tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, *k_cache, *v_cache] + // tokens are determined by command line arguments + // pos_ids, atten_mask are infered inside runner + + for (int input_index = 2; input_index < 4; ++input_index) { + std::ifstream fin(input_files[input_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + freqs_inputs[input_index-2].resize(file_size / sizeof(float)); + fin.seekg(0, fin.beg); + fin.read(reinterpret_cast(freqs_inputs[input_index-2].data()), file_size); + fin.close(); + } + + std::vector> managed_kv_inputs(method_metas.size()-2); + for (int i = 1; i < method_metas.size()-1; ++i){ + size_t num_inputs = method_metas[i]->num_inputs(); + const int k_caches_end = (num_inputs - 4) / 2; + + // TODO: need to handle batch size != 1 + // k caches init + for (int input_index = 4; input_index < k_caches_end; ++input_index) { + Result tensor_meta = + method_metas[i]->input_tensor_meta(input_index); + int file_index = (i-1) * (num_inputs - 4) + input_index + 1; + std::ifstream fin(input_files[file_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. 
file bytes: %zu, tensor bytes: %zu", + file_index, + file_size, + tensor_meta->nbytes()); + + // to simplify kv_cache update logic, we use (bsz, head_dim+2, seq) + // for fast pointer shifting + // head_dim+1 is the buffer of last word + // head_dim+2 is for output + inputs[i-1][input_index].resize(tensor_meta->nbytes() + 2*(tensor_meta->nbytes()/tensor_meta->sizes()[1])); + fin.close(); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_kv_inputs[i-1].emplace_back(ManagedTensor( + inputs[i-1][input_index].data(), 128, sizes, tensor_meta->scalar_type())); + } + + // v caches init + for (int input_index = k_caches_end; input_index < num_inputs; ++input_index) { + Result tensor_meta = + method_metas[i]->input_tensor_meta(input_index); + int file_index = (i-1) * (num_inputs - 4) + input_index + 1; + std::ifstream fin(input_files[file_index], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == tensor_meta->nbytes(), + "Input(%d) size mismatch. file bytes: %zu, tensor bytes: %zu", + file_index, + file_size, + tensor_meta->nbytes()); + + // to simplify v_cache update logic, we use (bsz, 2*max_seq_len, head_dim) + // for fast pointer shifting + inputs[i-1][input_index].resize(2*tensor_meta->nbytes()); + fin.close(); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_kv_inputs[i-1].emplace_back(ManagedTensor( + inputs[i-1][input_index].data(), 128, sizes, tensor_meta->scalar_type())); + } + } + + // generate tokens + std::string inference_output; + runner.generate( + prompt, seq_len, managed_kv_inputs, freqs_inputs, [&](const std::string& piece) { + inference_output += piece; + }); + + auto output_file_name = FLAGS_output_folder_path + "/output_" + + std::to_string(inference_index++) + "_0.raw"; + std::ofstream fout(output_file_name.c_str()); + fout << inference_output; + fout.close(); + } + + return 0; +} diff --git a/examples/qualcomm/llama2/composite_llama.py b/examples/qualcomm/llama2/composite_llama.py new file mode 100644 index 00000000000..da4d19f7480 --- /dev/null +++ b/examples/qualcomm/llama2/composite_llama.py @@ -0,0 +1,873 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
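+
+# Composite flow: the static llama model is split into an embedding module,
+# `division` decoder-block modules and a final norm/output module. Each piece is
+# optionally quantized with PT2E (or kept fp16) and lowered to its own QNN context
+# binary, and the resulting .pte files are run in sequence by qnn_llama_runner.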
+ +import codecs +import gc +import getpass +import json +import os +import shutil +import stat +import sys +from pathlib import Path + +sys.setrecursionlimit(4096) + +import time +from typing import List, Tuple + +import torch + +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner +from executorch.backends.qualcomm.passes.build_quant_io import BuildQuantIo +from executorch.backends.qualcomm.passes.utils import q_io_key + +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + capture_program, + convert_linear_to_conv2d, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, +) +from executorch.examples.models.llama2.builder import DType +from executorch.examples.models.llama2.llama_transformer import precompute_freqs_cis +from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs +from executorch.examples.qualcomm.scripts.utils import ( + make_output_dir, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass + +from sentencepiece import SentencePieceProcessor +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + + +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: + """ + This function is specific for matmul op 16a8w. 
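+    The matmul itself is annotated with a 16-bit activation / 8-bit weight config,
+    while the ops producing its second input (transpose/permute/cat along the
+    KV-cache path) are re-annotated with an 8-bit config by annotate_matmul_input1.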
+ """ + from typing import Sequence + + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, + ) + from torch.fx import Node + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_cat(node: Node, quantization_config: QuantizationConfig): + input_nodes = node.args[0] + + assert isinstance(input_nodes, Sequence) + + first_input_node = input_nodes[0] + input_qspec_map = {} + assert isinstance(first_input_node, Node) + assert isinstance(node, Node) + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + assert isinstance(input_node, Node) + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + + def annotate_single_in_single_out( + node: Node, quantization_config: QuantizationConfig + ) -> None: + + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_matmul_input1(node: Node): + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + while isinstance(node, Node) and node.op == "call_function": + if node.target in [ + torch.ops.aten.permute.default, + torch.ops.aten.transpose.int, + ]: + annotate_single_in_single_out(node, quantization_config_8a8w) + node = node.args[0] + elif node.target == torch.ops.aten.cat.default: + annotate_cat(node, quantization_config_8a8w) + node = node.args[0][0] + else: + node = node.args[0] + + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + annotate_matmul(node, quantization_config_16a8w) + annotate_matmul_input1(node.args[1]) + + +def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_ptq_per_channel_weight_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import QuantizationAnnotation + from torch.fx import Node + + def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: + input_qspec_map = {} + input_act = node.args[0] + assert isinstance(input_act, Node) + 
input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + weight = node.args[1] + assert isinstance(weight, Node) + input_qspec_map[weight] = quantization_config.weight + + if len(node.args) > 2: + bias = node.args[2] + if isinstance(bias, Node): + if callable(quantization_config.bias): + input_qspec_map[bias] = quantization_config.bias(node) + else: + input_qspec_map[bias] = quantization_config.bias + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + quantization_config_16a8w_per_channel = get_ptq_per_channel_weight_config( + torch.uint16, weight_dtype=torch.int8 + ) + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: + if "nn_module_stack" in node.meta: + module_values_list = list(node.meta["nn_module_stack"].values()) + full_qualified_name = module_values_list[0][0] + if full_qualified_name == "L['self'].llama.output": + annotate_conv2d( + node, quantization_config=quantization_config_16a8w_per_channel + ) + + +def calibrate( + example_inputs, n_heads, layers_per_ctx, modules: List[torch.fx.GraphModule] +): + sp_model = SentencePieceProcessor(model_file="tokenizer.model") + _, _, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + user_prompts = ["Once"] + for prompt in user_prompts: + token_list += sp_model.encode(prompt) + + def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0 + probs_sort /= probs_sort.sum(dim=-1, keepdim=True) + next_token = torch.multinomial(probs_sort, num_samples=1) + return probs_indices.gather(dim=-1, index=next_token) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + hidden_states = modules[0](torch.full((1, 1), token_list[pos])) + input_pos = torch.full((1, 1), pos) + k_caches_o_list = [] + v_caches_o_list = [] + for i, decode_module in enumerate(modules[1:-1]): + offset = i * layers_per_ctx * n_heads + k_caches_i = k_caches[offset : offset + layers_per_ctx * n_heads] + v_caches_i = v_caches[offset : offset + layers_per_ctx * n_heads] + hidden_states, k_caches_o, v_caches_o = decode_module( + hidden_states, + freqs_cos[input_pos][0], + freqs_sin[input_pos][0], + atten_mask, + k_caches_i, + v_caches_i, + ) + k_caches_o_list.extend(k_caches_o) + v_caches_o_list.extend(v_caches_o) + + logits = modules[-1](hidden_states) + # k_caches have been transposed ahead, the shpae is [batch, head_dim, seq-1] + k_caches = [ + torch.cat([k_cache[:, :, 1:], k_caches_o_list[i]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], v_caches_o_list[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + token_list.append(torch.argmax(logits[:, -1], dim=-1).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + +class CompositeLlama: + def __init__(self, division, llama_model) -> None: + super().__init__() + self.division = division + self.layers_per_ctx = llama_model.n_layers // division + self.llama_model = llama_model + 
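+        # quant_dtype stays None for the fp16 flow; quantize() overwrites it and
+        # lowering_modules() checks it to pick fp16 vs. quantized HTP compile specs.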
self.quant_dtype = None + self.split_modules, self.split_inputs = [], [] + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + + def split_llama(self): + def get_block_module(llama, indexes): + class LlamaBlock(torch.nn.Module): + def __init__(self, llama, indexes) -> None: + super().__init__() + self.llama = llama + self.indexes = indexes + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache, output_v_cache = [], [] + for i, ind in enumerate(self.indexes): + offset = i * self.llama.n_heads + k_in = k_caches[offset : offset + self.llama.n_heads] + v_in = v_caches[offset : offset + self.llama.n_heads] + hidden_states, k, v = self.llama.layers[ind]( + x=hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_in, + v_caches=v_in, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + return hidden_states, output_k_cache, output_v_cache + + return LlamaBlock(llama, indexes) + + def get_affine_module(llama): + class LlamaAffine(torch.nn.Module): + def __init__(self, llama) -> None: + super().__init__() + self.llama = llama + + def forward(self, hidden_states): + hidden_states = self.llama.norm(hidden_states) + logits = self.llama.output(hidden_states) + return logits + + return LlamaAffine(llama) + + tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches = ( + self.get_example_inputs() + ) + + with torch.no_grad(): + # embedding + self.split_modules.append(self.llama_model.tok_embeddings) + self.split_inputs.append((tokens,)) + + # attentions + for i in range(self.division): + llama_block = get_block_module( + self.llama_model, + [*range(self.layers_per_ctx * i, self.layers_per_ctx * (i + 1))], + ) + offset = i * self.layers_per_ctx * self.llama_model.n_heads + k_caches_in = k_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + v_caches_in = v_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + self.split_modules.append(llama_block) + self.split_inputs.append( + ( + self.llama_model.tok_embeddings(tokens), + freqs_cos[pos_ids][0], + freqs_sin[pos_ids][0], + atten_mask, + k_caches_in, + v_caches_in, + ) + ) + + # affine layer + affine_block = get_affine_module(self.llama_model) + self.split_modules.append(affine_block) + self.split_inputs.append((self.llama_model.tok_embeddings(tokens),)) + + def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type=torch.float32): + if not self.has_quant_io: + return + + # shape of k caches and v caches + input_cache_shape = { + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + } + for n in gm.graph.nodes: + if ( + n.op == "placeholder" + and len(users := list(n.users)) == 1 + and users[0].meta["val"].size()[-2:] in input_cache_shape + ): + n.meta[q_io_key] = kv_type + elif n.op == "output": + for a in n.args[0]: + if ( + a.meta["val"].flatten().size()[0] + == self.llama_meta["get_head_dim"] + ): + a.meta[q_io_key] = kv_type + + def quantize(self, quant_dtype, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = QnnQuantizer() + quantizer.set_per_channel_linear_quant(True) + quantizer.set_per_channel_conv_quant(True) + + if quant_dtype == QuantDtype.use_8a8w: + pass # 
default setting + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + split_fx_graph_modules = [] + + with torch.no_grad(): + for nn_module, capture_inputs in zip(self.split_modules, self.split_inputs): + fx_graph_module = torch._export.capture_pre_autograd_graph( + nn_module, capture_inputs + ) + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + split_fx_graph_modules.append(fx_graph_module) + print("Quantizing the model...") + calibrate( + self.get_example_inputs(), + self.llama_model.n_heads, + self.layers_per_ctx, + split_fx_graph_modules, + ) + + self.split_modules = [ + convert_pt2e(fx_graph_module) for fx_graph_module in split_fx_graph_modules + ] + del self.llama_model + + def lowering_modules(self, work_space, kv_type=torch.float32): + + executorch_config = ExecutorchBackendConfig( + passes=[ + BuildQuantIo(), + ], + extract_constant_segment=False, + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. + memory_planning_pass=MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + pte_filename_list = [] + index = len(self.split_modules) + with torch.no_grad(): + while index > 0: + # backend option + backend_options = generate_htp_compiler_spec( + use_fp16=True if self.quant_dtype is None else False, + use_multi_contexts=True, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=QcomChipset.SM8650, + backend_options=backend_options, + # saver=True if index==5 else False + ) + partitioner = QnnPartitioner(compiler_specs) + pte_filename = f"llama2_qnn_{index-1}" + edge_prog = capture_program( + self.split_modules[index - 1], self.split_inputs[index - 1] + ) + self._tag_kv_ios( + edge_prog.exported_program.graph_module, kv_type=kv_type + ) + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=self.llama_meta, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + del edge_prog + del edge_prog_mgr + del exec_prog_mgr + self.split_modules.pop() + self.split_inputs.pop() + gc.collect(generation=2) + pte_filename_list.insert(0, f"{work_space}/{pte_filename}.pte") + index -= 1 + return pte_filename_list + + def get_example_inputs(self): + tokens, pos_ids, atten_mask, k_caches, v_caches = ( + self.llama_model.get_example_inputs() + ) + freqs_cos, freqs_sin = precompute_freqs_cis( + self.llama_model.dim // self.llama_model.n_heads, + self.llama_model.max_seq_len, + self.llama_model.rope_freq_base, + ) + return (tokens, pos_ids, freqs_cos, freqs_sin, atten_mask, k_caches, v_caches) + + def get_export_inputs(self): + tokens, pos_ids, atten_mask, k_caches, v_caches = ( + self.llama_model.get_export_inputs() + ) + freqs_cos, freqs_sin = precompute_freqs_cis( 
+ self.llama_model.dim // self.llama_model.n_heads, + self.llama_model.max_seq_len, + self.llama_model.rope_freq_base, + ) + export_inputs = [tokens, pos_ids, freqs_cos, freqs_sin, atten_mask] + for i in range(self.division): + offset = i * self.layers_per_ctx * self.llama_model.n_heads + k_caches_in = k_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + v_caches_in = v_caches[ + offset : offset + self.layers_per_ctx * self.llama_model.n_heads + ] + export_inputs.append(k_caches_in) + export_inputs.append(v_caches_in) + + return tuple(export_inputs) + + +def create_device_inputs(example_inputs, kv_input_numel, kv_type=torch.float32): + # TODO: support batch inputs if necessary + input_list = "" + inputs, flat_inputs = [], [] + for input in example_inputs: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + + for i, data in enumerate(flat_inputs): + input_list += f"input_0_{i}.raw " + if data.flatten().shape[0] == kv_input_numel: + data = data.to(dtype=kv_type) + inputs.append(data) + + input_list += "\n" + return tuple(inputs), input_list + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./llama2_qnn", + default="./llama2_qnn", + type=str, + ) + + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.", + default="16a4w", + ) + + parser.add_argument( + "--checkpoint", + help="Pass llama2 checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--params", + help="Pass llama2 params json file.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="Pass llama2 tokenizer binary.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama2 tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama2.", + required=True, + type=str, + ) + + parser.add_argument( + "--seq_len", + help="Ouput sequence length for llama2.", + default=128, + type=int, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama2.", + default=0.8, + type=float, + ) + + parser.add_argument( + "-d", + "--dtype-override", + default="fp32", + type=str, + choices=["fp32", "fp16"], + help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32", + ) + + parser.add_argument( + "--pre_gen_pte", + help="Pre-generated llama2.", + type=str, + ) + + args = parser.parse_args() + division = 4 + # ensure the working directory exist. 
+ os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + with open(args.params) as f: + config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + config.max_batch_size = 1 + config.max_seq_len = 1024 + device = "cpu" + state_dict = torch.load(args.checkpoint, map_location=device, mmap=True) + end_load_ts = time.time() + print("torch.load checkpoint", end_load_ts - start_ts) + llama_instance = None + with torch.device("meta"): + llama_instance = LlamaModel(config, output_new_cache_only=True) + if "model" in state_dict: + state_dict = state_dict["model"] + llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) + end_load_state_dict_ts = time.time() + print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts) + + for l in llama_instance.layers: + if getattr(l.attention, "prepare_sha", None): + l.attention.prepare_sha() + kv_type = torch.uint8 + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + ) + + if args.use_fp16: + quant_dtype = None + else: + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + if args.dtype_override is not None: + dtype_override = DType[args.dtype_override] + llama_instance = llama_instance.to(dtype_override.to_torch_dtype()) + + llama_instance = convert_linear_to_conv2d(llama_instance) + + composite_llama = CompositeLlama(division, llama_instance.eval()) + kv_input_numel = ( + composite_llama.llama_meta["get_max_seq_len"] - 1 + ) * composite_llama.llama_meta["get_head_dim"] + start_split_ts = time.time() + inputs, input_list = create_device_inputs( + composite_llama.get_export_inputs(), kv_input_numel, kv_type + ) + pte_filename_list = [] + if args.pre_gen_pte is None: + composite_llama.split_llama() + end_split_ts = time.time() + print("composite_llama.split_llama()", end_split_ts - start_split_ts) + + if quant_dtype is not None: + composite_llama.quantize( + quant_dtype, + custom_annotations=( + annotate_matmul_16a8w, + annotate_linear_16a8w_in_affine_layer, + ), + ) + end_quantize_ts = time.time() + print( + "composite_llama.quantize(quant_dtype)", end_quantize_ts - end_split_ts + ) + del llama_instance + pte_filename_list = composite_llama.lowering_modules( + args.artifact, kv_type=kv_type + ) + assert len(pte_filename_list) != 0, "Failed to save pte file." + end_lowering_ts = time.time() + print("Complete Compile", end_lowering_ts - end_quantize_ts) + else: + for i in range(division + 2): + pte_filename = f"llama2_qnn_{i}" + pte_filename_list.append(f"{args.pre_gen_pte}/{pte_filename}.pte") + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/composite_llama" + pte_filenames = [Path(pte_filename).name for pte_filename in pte_filename_list] + + runner_args = " ".join( + [ + f"--model_paths {','.join(pte_filenames)}", + "--output_folder_path outputs", + "--input_list_path input_list.txt", + f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", + f"--prompt {args.prompt}", + f"--seq_len {args.seq_len}", + f"--temperature {args.temperature}", + ] + ) + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. 
&&", + f"./qnn_llama_runner {runner_args}", + ] + ) + + if not args.compile_only: + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + artifact_path=f"{args.build_folder}", + pte_path=pte_filename_list, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/qnn_llama_runner", + ) + adb.push(inputs=[inputs], input_list=input_list, files=[args.tokenizer_bin]) + adb.execute(custom_runner_cmd=runner_cmd) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + with codecs.open( + os.path.join(output_data_folder, f), + "r", + encoding="utf-8", + errors="replace", + ) as fdata: + outputs.append(fdata.read()) + + adb.pull(output_path=args.artifact, callback=post_process) + + for idx, output in enumerate(outputs): + print(f"Results[{idx}]:\n{output}") + + else: + compile_only_dir = os.path.join(args.artifact, args.artifact) + to_device_dir = os.path.join(compile_only_dir, "to_device") + os.makedirs(to_device_dir, exist_ok=True) + # input_list + input_list_file = os.path.join(to_device_dir, "input_list.txt") + with open(input_list_file, "w") as f: + f.write(input_list) + + # write inputs + for idx, data in enumerate([inputs]): + flat_inputs = [] + for d in data: + if isinstance(d, list): + for dd in d: + flat_inputs.append(dd) + else: + flat_inputs.append(d) + for i, d in enumerate(flat_inputs): + filename = os.path.join(to_device_dir, f"input_{idx}_{i}.raw") + d.detach().numpy().tofile(filename) + + # binaries + arch_table = { + "SM8650": "75", + "SM8550": "73", + "SM8475": "69", + "SM8450": "69", + } + dsp_arch = arch_table[args.model] + qnn_sdk_root = os.getenv("QNN_SDK_ROOT") + + on_device_files = [ + os.path.join(qnn_sdk_root, "lib", "aarch64-android", "libQnnHtp.so"), + os.path.join( + qnn_sdk_root, + "lib", + f"hexagon-v{dsp_arch}", + "unsigned", + f"libQnnHtpV{dsp_arch}Skel.so", + ), + os.path.join( + qnn_sdk_root, "lib", "aarch64-android", f"libQnnHtpV{dsp_arch}Stub.so" + ), + os.path.join(qnn_sdk_root, "lib", "aarch64-android", "libQnnSystem.so"), + os.path.join(args.build_folder, "examples", "qualcomm", "qnn_llama_runner"), + os.path.join( + args.build_folder, + "backends", + "qualcomm", + "libqnn_executorch_backend.so", + ), + ] + pte_filename_list + + for on_device_file in on_device_files: + shutil.copy2(on_device_file, to_device_dir) + + # tokenizer + shutil.copy2(args.tokenizer_bin, to_device_dir) + + run_sh_lines = [ + "set -e", + 'SOURCEDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"', + f'adb_cmd="adb -s {args.device} -H {args.host}"', + f'${{adb_cmd}} shell "rm -rf {workspace} && mkdir -p {workspace}/outputs"', + f"${{adb_cmd}} push ${{SOURCEDIR}}/to_device/* {workspace}", + f'${{adb_cmd}} shell "{runner_cmd}"', + "echo", + "echo ----- output_0_0.raw -----", + "echo", + f'${{adb_cmd}} shell "cat {workspace}/outputs/output_0_0.raw"', + "", + ] + + run_sh_file = os.path.join(compile_only_dir, "run.sh") + with open(run_sh_file, "w") as fp: + fp.write("\n".join(run_sh_lines)) + + os.chmod(run_sh_file, stat.S_IRWXU | stat.S_IRWXG) + + print("Zipping files.....") + shutil.make_archive( + compile_only_dir, + "zip", + root_dir=args.artifact, + base_dir=os.path.relpath(compile_only_dir, args.artifact), + ) + + print(f"Compile only mode, necessary files are written to 
{compile_only_dir}") + print(f"And it's zipped as {compile_only_dir}.zip") diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py new file mode 100644 index 00000000000..8c31d372f75 --- /dev/null +++ b/examples/qualcomm/llama2/llama.py @@ -0,0 +1,296 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import codecs +import json +import os +import sys + +from functools import partial + +import torch + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.utils.utils import convert_linear_to_conv2d +from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs +from executorch.examples.qualcomm.scripts.utils import ( + build_executorch_binary, + make_output_dir, + setup_common_args_and_variables, + SimpleADB, +) + +from sentencepiece import SentencePieceProcessor +from torch.ao.quantization.observer import MinMaxObserver + + +def create_device_inputs(example_inputs): + # TODO: support batch inputs if necessary + input_list = "" + inputs, flat_inputs = [], [] + for input in example_inputs: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + + for i, data in enumerate(flat_inputs): + input_list += f"input_0_{i}.raw " + inputs.append(data) + + input_list += "\n" + return tuple(inputs), input_list + + +def calibrate(example_inputs, module: torch.fx.GraphModule): + sp_model = SentencePieceProcessor(model_file="tokenizer.model") + _, _, atten_mask, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + user_prompts = ["Once"] + for prompt in user_prompts: + token_list += sp_model.encode(prompt) + + def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0 + probs_sort /= probs_sort.sum(dim=-1, keepdim=True) + next_token = torch.multinomial(probs_sort, num_samples=1) + return probs_indices.gather(dim=-1, index=next_token) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + logits, new_k_caches, new_v_caches = module( + torch.full((1, 1), token_list[pos]), + torch.full((1, 1), pos), + atten_mask, + *k_caches, + *v_caches, + ) + k_caches = [ + torch.cat([k_cache[:, 1:, :], new_k_caches[i]], dim=1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + probs = torch.softmax(logits[:, -1] / 0.8, dim=-1) + token_list.append(sample_top_p(probs, 0.9).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. 
Default ./llama2_qnn", + default="./llama2_qnn", + type=str, + ) + + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. Support 8a8w and 16a4w.", + default="16a4w", + ) + + parser.add_argument( + "--checkpoint", + help="Pass llama2 checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--params", + help="Pass llama2 params json file.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="Pass llama2 tokenizer binary.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama2 tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama2.", + required=True, + type=str, + ) + + parser.add_argument( + "--seq_len", + help="Ouput sequence length for llama2.", + default=128, + type=int, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama2.", + default=0.8, + type=float, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Pre-generated llama2.", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + with open(args.params) as f: + config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + config.max_batch_size = 1 + + state_dict = torch.load(args.checkpoint) + if "model" in state_dict: + state_dict = state_dict["model"] + with torch.device("meta"): + instance = LlamaModel(config) + instance.load_state_dict(state_dict, strict=False, assign=True) + + inputs, input_list = create_device_inputs(instance.get_export_inputs()) + pte_filename = "llama2_qnn" + + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + ) + + if args.use_fp16: + quant_dtype = None + else: + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + # prepare sha if the function is provided + for l in instance.layers: + if getattr(l.attention, "prepare_sha", None): + l.attention.prepare_sha() + + if args.pre_gen_pte is None: + build_executorch_binary( + # try this if you want: convert_linear_to_conv2d(instance.eval()), + instance.eval(), + inputs, + args.model, + f"{args.artifact}/{pte_filename}", + partial(calibrate, instance.get_example_inputs()), + custom_annotations=(), + quant_dtype=quant_dtype, + per_channel_linear=True, + shared_buffer=args.shared_buffer, + metadata=instance.get_metadata(), + direct_io=True, + act_observer=MinMaxObserver, + ) + + if args.compile_only: + sys.exit(0) + + # build custom commands for qnn_llama_runner + pte_path = ( + f"{args.artifact}/{pte_filename}.pte" + if args.pre_gen_pte is None + else args.pre_gen_pte + ) + workspace = f"/data/local/tmp/executorch/{pte_filename}" + runner_args = " ".join( + [ + f"--model_path {pte_filename}.pte", + "--output_folder_path outputs", + "--input_list_path input_list.txt", + f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", + f"--prompt {args.prompt}", + f"--seq_len {args.seq_len}", + f"--temperature {args.temperature}", + ] + ) + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "export ADSP_LIBRARY_PATH=. 
&&", + "export LD_LIBRARY_PATH=. &&", + f"./qnn_llama_runner {runner_args}", + ] + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + artifact_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/qnn_llama_runner", + ) + adb.push(inputs=[inputs], input_list=input_list, files=[args.tokenizer_bin]) + adb.execute(custom_runner_cmd=runner_cmd) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + with codecs.open( + os.path.join(output_data_folder, f), + "r", + encoding="utf-8", + errors="replace", + ) as fdata: + outputs.append(fdata.read()) + + adb.pull(output_path=args.artifact, callback=post_process) + + for idx, output in enumerate(outputs): + print(f"Results[{idx}]:\n{output}") diff --git a/examples/qualcomm/llama2/model/static_llama.py b/examples/qualcomm/llama2/model/static_llama.py new file mode 100644 index 00000000000..3b98700b220 --- /dev/null +++ b/examples/qualcomm/llama2/model/static_llama.py @@ -0,0 +1,353 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + +import torch +import torch.nn as nn + +from executorch.examples.models.llama2.llama_transformer import ( + FeedForward, + ModelArgs, + precompute_freqs_cis, + RMSNorm, +) + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class LlamaAttention(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=False): + super().__init__() + self.dim = config.dim + self.n_heads = config.n_heads + self.head_dim = config.dim // config.n_heads + self.n_kv_heads = config.n_kv_heads + self.num_key_value_groups = config.n_heads // self.n_kv_heads + self.max_seq_len = config.max_seq_len + self.output_new_cache_only = output_new_cache_only + + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.attn_softmax = torch.nn.Softmax(dim=-1) + + scale = float(self.head_dim) ** -0.5 + scale_tensor = torch.tensor( + [scale], dtype=torch.float32, requires_grad=False, device="cpu" + ).view(1, 1, 1) + self.register_buffer("scale_tensor", scale_tensor, False) + + def prepare_sha(self): + self.wq_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wv_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + + self.forward_mha = self.forward + self.forward = self.forward_sha + + for i in range(self.n_heads): + 
self.wq_sha[i].weight.data.copy_( + self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wk_sha[i].weight.data.copy_( + self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv_sha[i].weight.data.copy_( + self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + + def forward_sha( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seqlen, _ = hidden_states.shape + + q = [wq_sha(hidden_states) for wq_sha in self.wq_sha] + k = [wk_sha(hidden_states) for wk_sha in self.wk_sha] + v = [wv_sha(hidden_states) for wv_sha in self.wv_sha] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + + output_kh, output_vh, output_y = [], [], [] + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat([k_caches[i], k[i]], dim=-1) + vh = torch.cat([v_caches[i], v[i]], dim=1) + + attn = q[i] @ kh + attn = attn * self.scale_tensor + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + if self.output_new_cache_only: + output_kh.append(k[i]) + output_vh.append(v[i]) + else: + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + return y, output_kh, output_vh + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + bsz, seqlen, _ = hidden_states.shape + + q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) + q = q.view(bsz, seqlen, self.n_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 1) + + output_kh, output_vh, output_y = [], [], [] + + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat( + [k_caches[i], k[:, :, :, i]], dim=-1 + ) # TODO verify the correctness + vh = torch.cat([v_caches[i], v[:, :, i, :]], dim=1) + + attn = q[:, :, i, :] @ kh.permute(0, 2, 1) + attn = attn * self.scale_tensor + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + + return y, output_kh, output_vh + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=False): + super().__init__() + self.dim = config.dim + self.attention = LlamaAttention( + config=config, output_new_cache_only=output_new_cache_only + ) + self.feed_forward = FeedForward(config) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, k_cache, v_cache = self.attention( + hidden_states=self.attention_norm(x), + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + 
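+            # atten_mask is additive: valid positions hold 0, unused cache slots
+            # hold -255.0, which the attention softmax treats as ~zero weight.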
atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + h = x + h + output = h + self.feed_forward(self.ffn_norm(h)) + return output, k_cache, v_cache + + +class LlamaModel(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=True): + super().__init__() + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.max_batch_size = config.max_batch_size + self.max_seq_len = config.max_seq_len + self.n_heads = config.n_heads + self.n_kv_heads = config.n_kv_heads + self.n_layers = config.n_layers + self.vocab_size = config.vocab_size + self.rope_freq_base = config.rope_freq_base + self.output_new_cache_only = output_new_cache_only + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, self.output_new_cache_only) + for _ in range(config.n_layers) + ] + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + freqs_cos, freqs_sin = precompute_freqs_cis( + config.dim // config.n_heads, + config.max_seq_len, + config.rope_freq_base, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + atten_mask: torch.Tensor, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = self.freqs_cos[input_pos][0] + freqs_sin = self.freqs_sin[input_pos][0] + + hidden_states = self.tok_embeddings(tokens) + for ind, decoder_layer in enumerate(self.layers): + offset_k = ind * self.n_heads + offset_v = self.n_layers * self.n_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_heads] + v_caches = args[offset_v : offset_v + self.n_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + hidden_states = self.norm(hidden_states) + logits = self.output(hidden_states) + + return logits, output_k_cache, output_v_cache + + def get_example_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + k_cache, v_cache = [], [] + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.head_dim, + self.max_seq_len - 1, + ) + ) + v_cache.append( + torch.zeros( + self.max_batch_size, + self.max_seq_len - 1, + self.head_dim, + ) + ) + return ( + tokens, + pos_ids, + atten_mask, + k_cache, + v_cache, + ) + + def get_export_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + # this is important for torch.export not to take it as dummy input + k_cache, v_cache = [], [] + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + 
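+                    # Seeded with randn here (zeros in get_example_inputs) so
+                    # torch.export keeps the caches as real graph inputs; k is
+                    # stored pre-transposed as (batch, head_dim, seq_len - 1) to
+                    # avoid runtime permutes.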
torch.randn( + self.max_batch_size, + self.head_dim, + self.max_seq_len - 1, + ) + ) + v_cache.append( + torch.randn( + self.max_batch_size, + self.max_seq_len - 1, + self.head_dim, + ) + ) + return ( + tokens, + pos_ids, + atten_mask, + k_cache, + v_cache, + ) + + def get_metadata(self): + # TODO: modify this when enabling LLAMA 7B + return { + "get_bos_id": 1, + "get_eos_id": 2, + "get_head_dim": self.dim // self.n_heads, + "get_max_batch_size": self.max_batch_size, + "get_max_seq_len": self.max_seq_len, + "get_n_bos": 1, + "get_n_eos": 1, + "get_n_kv_heads": self.n_heads, + "get_n_layers": self.n_layers, + "get_vocab_size": self.vocab_size, + } diff --git a/examples/qualcomm/llama2/runner/runner.cpp b/examples/qualcomm/llama2/runner/runner.cpp new file mode 100644 index 00000000000..379c4d2e41b --- /dev/null +++ b/examples/qualcomm/llama2/runner/runner.cpp @@ -0,0 +1,551 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace executor { + +namespace { +static constexpr auto kTopp = 0.9f; +void printReport(const Runner::Stats& stats); +std::string statsToJsonString(const Runner::Stats& stats); +} // namespace + +Runner::Runner( + const std::vector& model_path_list, + const std::string& tokenizer_path, + const float temperature) + : tokenizer_path_(tokenizer_path), + temperature_(temperature) { + for(auto& model_path : model_path_list){ + modules_.emplace_back(std::make_unique( + model_path, + Module::MlockConfig::UseMlockIgnoreErrors)); + ET_LOG( + Info, + "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", + model_path.c_str(), + tokenizer_path.c_str()); + } +} + +bool Runner::is_loaded() const { + bool loaded = true; + for(auto& module : modules_){ + loaded &= module->is_loaded(); + } + return loaded && tokenizer_ && sampler_; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + stats_.model_load_start_ms = util::time_in_ms(); + for(auto& module : modules_){ + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); + + + // Read out metadata from the model + ET_LOG(Info, "Reading metadata from model"); + const auto method_names = module->method_names(); + ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); + model_methods_ = method_names.get(); + vocab_size_ = getMetadataHelper(module.get(), "get_vocab_size", 32000); + bos_id_ = getMetadataHelper(module.get(), "get_bos_id", 1); + eos_id_ = getMetadataHelper(module.get(), "get_eos_id", 2); + n_bos_ = getMetadataHelper(module.get(), "get_n_bos", 1); + n_eos_ = getMetadataHelper(module.get(), "get_n_eos", 1); + max_seq_len_ = getMetadataHelper(module.get(), "get_max_seq_len", 128); + head_dim_ = getMetadataHelper(module.get(), "get_head_dim", 32); + } + // Load tokenizer + tokenizer_ = std::make_unique(vocab_size_, bos_id_, eos_id_); + tokenizer_->load(tokenizer_path_); + if (tokenizer_->bos_tok() != bos_id_) { + ET_LOG( + Error, + "Tokenizer's BOS id %lu does not match model's BOS id %d, will override tokenizer's BOS.", + tokenizer_->bos_tok(), + bos_id_); + } + 
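+  // bos_id_/eos_id_ come from the exported metadata methods (get_bos_id /
+  // get_eos_id in static_llama.py); a mismatch with the tokenizer is logged
+  // and the metadata values are the ones used during generation.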
if (tokenizer_->eos_tok() != eos_id_) { + ET_LOG( + Error, + "Tokenizer's EOS id %lu does not match model's EOS id %d, will override tokenizer's EOS.", + tokenizer_->eos_tok(), + eos_id_); + } + // Create sampler + sampler_ = std::make_unique( + vocab_size_, + temperature_, + kTopp, + static_cast(std::time(nullptr))); + stats_.model_load_end_ms = util::time_in_ms(); + + return Error::Ok; +} + +template +T Runner::getMetadataHelper(Module* module, std::string method_name, T default_val) { + T res = default_val; + if (model_methods_.count(method_name)) { + Result> outputs = module->execute(method_name); + if (outputs.ok()) { + std::vector outs = outputs.get(); + if (outs.size() > 0) { + res = outs[0].to(); + } + } + } else { + ET_LOG( + Info, + "The model does not contain %s method, using default value %lld", + method_name.c_str(), + (long long)default_val); + } + ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); + return res; +} + +template +int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { + T* logits = logits_tensor.mutable_data_ptr(); + + // Since the logits are for all tokens, get the last token probabilities + T* logits_last = logits; + return sampler_->sample(logits_last); +} + + +// Given an input token. Set up the inputs for the model and execute a single +// step. Returning the logits tensor. +Result Runner::run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + Tensor& atten_mask, + Tensor& freqs_cos, + Tensor& freqs_sin, + std::vector>& kv_tensors, + std::vector>& kv_outputs) { + token.mutable_data_ptr()[0] = input_token; + // embedding + std::vector inputs = {token}; + Result> outputs_res = modules_[0]->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + EValue hidden_states = outputs_res.get()[0]; + + // llama block + std::vector>> llama_block_results; + for(int i = 1; i < modules_.size() - 1; ++i){ + inputs = {hidden_states, freqs_cos, freqs_sin, atten_mask}; + inputs.insert(inputs.end(), kv_tensors[i-1].begin(), kv_tensors[i-1].end()); + Result> llama_block_outputs_res = modules_[i]->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(llama_block_outputs_res.error()); + hidden_states = llama_block_outputs_res.get()[0]; + } + + // TODO: need to handle batch size != 1 + // update k_cache + size_t v_offset = kv_outputs[0][0].nbytes(); + size_t el_size = kv_outputs[0][0].element_size(); + size_t k_input_step = (max_seq_len_-1) * el_size; + for (int i = 1; i < modules_.size() - 1; ++i) { + int k_tensors_end = kv_tensors[i].size() / 2; + //update k caches + for (int j = 0, index = i-1; j < k_tensors_end; ++j) { + char *input_addr = static_cast(kv_tensors[index][j].mutable_data_ptr()); + char *output_addr = static_cast(kv_outputs[index][j].mutable_data_ptr()); + + // fill the output k values back + #pragma omp parallel for + for (int src = 0, dst = k_input_step; src < kv_outputs[index][j].nbytes(); src+=el_size, dst+=k_input_step) { + memcpy(input_addr+dst, output_addr+src, el_size); + } + + // inputs + ET_CHECK_MSG( + internal::set_tensor_data(kv_tensors[index][j], input_addr + kv_tensors[index][j].element_size(), kv_tensors[index][j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating kv_cache"); + } + // update v caches + for (int j = k_tensors_end, index = i-1; j < kv_tensors[index].size(); ++j) { + // inputs + char *input_addr = static_cast(kv_tensors[index][j].mutable_data_ptr()) + v_offset; + ET_CHECK_MSG( + internal::set_tensor_data(kv_tensors[index][j], input_addr, 
kv_tensors[index][j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating kv_cache"); + + // outputs + char *output_addr = static_cast(kv_outputs[index][j].mutable_data_ptr()) + v_offset; + ET_CHECK_MSG( + internal::set_tensor_data(kv_outputs[index][j], output_addr, kv_outputs[index][j].nbytes()) == Error::Ok, + "Failed to set output tensor when updating kv_cache"); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[index][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } + } + + // affine module + inputs = {hidden_states}; + Result> logits_outputs_res = modules_[modules_.size()-1]->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(logits_outputs_res.error()); + + // Bump start_pos by 1 + start_pos.mutable_data_ptr()[0]++; + + // update atten_mask + atten_mask.mutable_data_ptr()[atten_mask.numel() - 1 - start_pos.const_data_ptr()[0]] = 0; + + return logits_outputs_res.get()[0].toTensor(); +} +// TODO: add overloaded method for on-device tokenize +Error Runner::generate( + const std::string& prompt, + int32_t seq_len, + std::vector>& managed_kv_inputs, + std::vector>& freqs_inputs, + std::function token_callback, + std::function stats_callback) { + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + ET_CHECK_MSG(is_loaded(), "Please invoke load method first"); + + // First token time only measures the time it takes to encode the prompt and + // return a response token. + stats_.inference_start_ms = util::time_in_ms(); + shouldStop_ = false; + + // Set the sequence length to the max seq length if not provided + seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_; + + Result> encode_res = + tokenizer_->encode(prompt, n_bos_, 0); + + ET_CHECK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + + ET_CHECK_MSG( + num_prompt_tokens < max_seq_len_, + "Max seq length exceeded - please increase max seq len value in static_llama.py"); + + ET_CHECK_MSG( + num_prompt_tokens < seq_len, + "Sequence length exceeded - please increase the seq_len value passed to generate()"); + + int32_t pos = 0, prev_token, cur_token = prompt_tokens[0]; + std::vector token_data = {1}; + std::vector token_shape = {1, 1}; + + std::vector start_pos_data = {0}; + std::vector start_pos_shape = {1, 1}; + + std::vector atten_mask_data(max_seq_len_); + std::fill(atten_mask_data.begin(), atten_mask_data.end()-1, -255.0); + atten_mask_data.back() = 0; + + std::vector freqs_cos_data(head_dim_/2); + std::fill(freqs_cos_data.begin(), freqs_cos_data.end(), 0.0); + + std::vector freqs_sin_data(head_dim_/2); + std::fill(freqs_sin_data.begin(), freqs_sin_data.end(), 0.0); + + std::vector freqs_cos_shape = {1, head_dim_/2}; + + std::vector freqs_sin_shape = {1, head_dim_/2}; + + std::vector atten_mask_shape = {1, max_seq_len_}; + + std::vector logits_data_shape = {1, vocab_size_}; + + // initialize tensor wrappers + ManagedTensor managed_token( + token_data.data(), 128, token_shape, ScalarType::Int); + ManagedTensor managed_pos_id( + start_pos_data.data(), 128, start_pos_shape, ScalarType::Int); + ManagedTensor managed_atten_mask( + atten_mask_data.data(), 128, atten_mask_shape, ScalarType::Float); + ManagedTensor managed_freqs_cos( + freqs_cos_data.data(), 128, freqs_cos_shape, ScalarType::Float); + ManagedTensor managed_freqs_sin( + freqs_sin_data.data(), 128, 
freqs_sin_shape, ScalarType::Float); + + + Tensor token = managed_token.get_aliasing_tensor(); + Tensor atten_mask = managed_atten_mask.get_aliasing_tensor(); + Tensor start_pos = managed_pos_id.get_aliasing_tensor(); + Tensor freqs_cos = managed_freqs_cos.get_aliasing_tensor(); + Tensor freqs_sin = managed_freqs_sin.get_aliasing_tensor(); + + // embedding + std::vector embedding_logits_data(vocab_size_); + ManagedTensor embedding_managed_logits( + embedding_logits_data.data(), 128, logits_data_shape, ScalarType::Float); + Tensor embedding_logits = embedding_managed_logits.get_aliasing_tensor(); + ET_CHECK_MSG( + modules_[0]->set_output_data_ptr(embedding_logits, 0) == Error::Ok, + "Failed to set output tensor for embedding module - logits"); + + // llama block + std::vector> llama_block_logit_tensor_data(modules_.size()-2); + std::vector llama_block_logit_tensors, kv_outputs_managed; + std::vector> kv_tensors(modules_.size()-2), kv_outputs(modules_.size()-2); + std::vector> methods_meta = get_methods_meta(); + + for (int i = 1; i < modules_.size() - 1; ++i){ + Result &cur_meta = methods_meta[i]; + std::vector logits_data(vocab_size_); + llama_block_logit_tensor_data.push_back(logits_data); + llama_block_logit_tensors.emplace_back(ManagedTensor( + logits_data.data(), 128, logits_data_shape, ScalarType::Float)); + Tensor logits = llama_block_logit_tensors.back().get_aliasing_tensor(); + const int k_caches_end = managed_kv_inputs[i-1].size()/2; + + // k caches init + for (int j = 0; j < k_caches_end; ++j) { + kv_tensors[i-1].push_back(managed_kv_inputs[i-1][j].get_aliasing_tensor()); + Result out_tensor_meta = cur_meta->output_tensor_meta(j+1); + auto tensor_shape = out_tensor_meta->sizes(); + std::vector out_tensor_shape( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + int output_offset = (out_tensor_meta->nbytes()+kv_tensors[i-1][j].element_size()) * (max_seq_len_-1); + char *output_addr = static_cast(kv_tensors[i-1][j].mutable_data_ptr()) + output_offset; + + kv_outputs_managed.push_back(ManagedTensor( + output_addr, 128, out_tensor_shape, kv_tensors[i-1][j].scalar_type())); + kv_outputs[i-1].push_back(kv_outputs_managed.back().get_aliasing_tensor()); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[i-1][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } + // v caches init + for (int j = k_caches_end; j < managed_kv_inputs[i-1].size(); ++j) { + kv_tensors[i-1].push_back(managed_kv_inputs[i-1][j].get_aliasing_tensor()); + char *output_addr = static_cast(kv_tensors[i-1][j].mutable_data_ptr()) + + (max_seq_len_-1)*head_dim_*kv_tensors[i-1][j].element_size(); + + Result out_tensor_meta = cur_meta->output_tensor_meta(j+1); + auto tensor_shape = out_tensor_meta->sizes(); + std::vector out_tensor_shape( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + kv_outputs_managed.push_back(ManagedTensor( + output_addr, 128, out_tensor_shape, kv_tensors[i-1][j].scalar_type())); + kv_outputs[i-1].push_back(kv_outputs_managed.back().get_aliasing_tensor()); + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(kv_outputs[i-1][j], j+1) == Error::Ok, + "Failed to set output tensor for llama block"); + } + ET_CHECK_MSG( + modules_[i]->set_output_data_ptr(logits, 0) == Error::Ok, + "Failed to set output tensor for llama block - logits"); + } + + // affine layer + std::vector affine_logits_data(vocab_size_); + ManagedTensor affine_managed_logits( + affine_logits_data.data(), 128, logits_data_shape, ScalarType::Float); + Tensor 
affine_logits = affine_managed_logits.get_aliasing_tensor(); + ET_CHECK_MSG( + modules_[modules_.size()-1]->set_output_data_ptr(affine_logits, 0) == Error::Ok, + "Failed to set output tensor for affine module - logits"); + + // Start consuming user's prompts and generating new tokens + std::string final_output; + while (pos < seq_len - 1) { + for(int i = 0; i < head_dim_/2; i++){ + freqs_cos.mutable_data_ptr()[i] = freqs_inputs[0][pos*(head_dim_/2)+i]; + freqs_sin.mutable_data_ptr()[i] = freqs_inputs[1][pos*(head_dim_/2)+i]; + } + + // Run the model + Result logits_res = + run_model_step(cur_token, token, start_pos, atten_mask, freqs_cos, freqs_sin, kv_tensors, kv_outputs); + if (pos == num_prompt_tokens) { + stats_.first_token_ms = util::time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats_.prompt_eval_end_ms = util::time_in_ms(); + } + + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); + exec_aten::Tensor& logits_tensor = logits_res.get(); + prev_token = cur_token; + long sample_start_time_ms = util::time_in_ms(); + + cur_token = logitsToToken(logits_tensor); + stats_.aggregate_sampling_time_ms += + util::time_in_ms() - sample_start_time_ms; + + // advance the state machine + if (pos < num_prompt_tokens - 1) { + // prefill, force the next token to be the next prompt token + cur_token = prompt_tokens[pos + 1]; + } + pos++; + + // print the token as string, decode it with the Tokenizer object + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + + if (token_callback) { + token_callback(piece_res.get()); + } + + if (shouldStop_) { + break; + } + + // data-dependent terminating condition: we have n_eos_ number of EOS + if (pos >= num_prompt_tokens && cur_token == eos_id_) { + ET_LOG(Info, "Reached to the end of generation"); + break; + } + } + stats_.inference_end_ms = util::time_in_ms(); + + if (pos == seq_len) { + ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len); + } + + stats_.num_prompt_tokens = num_prompt_tokens; + stats_.num_generated_tokens = pos - num_prompt_tokens; + printReport(stats_); + if (stats_callback) { + stats_callback(stats_); + } + + return Error::Ok; +} + +namespace { +void printReport(const Runner::Stats& stats) { + printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); + + ET_LOG( + Info, + "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, + stats.num_prompt_tokens, + stats.num_generated_tokens); + + ET_LOG( + Info, + "\tModel Load Time:\t\t%f (seconds)", + ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + double inference_time_ms = + (double)(stats.inference_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, + + (stats.num_generated_tokens) / + (double)(stats.inference_end_ms - stats.inference_start_ms) * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + double prompt_eval_time = + (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + (stats.num_prompt_tokens) / prompt_eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + double eval_time = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + ET_LOG( + Info, + "\t\tGenerated %" PRIu64 + " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + 
stats.num_generated_tokens, + eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + stats.num_generated_tokens / eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + // Time to first token is measured from the start of inference, excluding + // model load time. + ET_LOG( + Info, + "\tTime to first generated token:\t%f (seconds)", + ((double)(stats.first_token_ms - stats.inference_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", + stats.num_prompt_tokens + stats.num_generated_tokens, + (double)stats.aggregate_sampling_time_ms / + stats.SCALING_FACTOR_UNITS_PER_SECOND); +} + +std::string statsToJsonString(const Runner::Stats& stats) { + std::stringstream ss; + ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," + << "\"generated_tokens\":" << stats.num_generated_tokens << "," + << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," + << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," + << "\"inference_start_ms\":" << stats.inference_start_ms << "," + << "\"inference_end_ms\":" << stats.inference_end_ms << "," + << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," + << "\"first_token_ms\":" << stats.first_token_ms << "," + << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms + << "," + << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" + << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; + return ss.str(); +} +} // namespace + +void Runner::stop() { + shouldStop_ = true; +} + +std::vector> Runner::get_methods_meta() { + std::vector> tmp; + for (auto& module : modules_){ + tmp.push_back(module->method_meta("forward")); + } + return tmp; +} + +// explicit instantiation of template methods +template int64_t Runner::getMetadataHelper( + Module* module, + std::string method_name, + int64_t default_val); +template bool Runner::getMetadataHelper( + Module* module, + std::string method_name, + bool default_val); + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/llama2/runner/runner.h b/examples/qualcomm/llama2/runner/runner.h new file mode 100644 index 00000000000..ffda2eb37cb --- /dev/null +++ b/examples/qualcomm/llama2/runner/runner.h @@ -0,0 +1,108 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace executor { + +class Runner { + public: + explicit Runner( + const std::vector& model_path_list, + const std::string& tokenizer_path, + const float temperature = 0.8f); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Start of model loading. + long model_load_start_ms; + // model_load_end_ms: End of model loading. + long model_load_end_ms; + // inference_start_ms: Immediately after the model is loaded (or we check + // for model load), measure the inference time. 
+ long inference_start_ms; + // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right + // before the inference loop starts + long prompt_eval_end_ms; + // first_token: Timestamp when the first generated token is emitted + long first_token_ms; + // inference_end_ms: End of inference/generation. + long inference_end_ms; + // Keep a running total of the time spent in sampling. + long aggregate_sampling_time_ms; + // Token count from prompt + int64_t num_prompt_tokens; + // Token count from generated (total - prompt) + int64_t num_generated_tokens; + }; + + bool is_loaded() const; + Error load(); + Error generate( + const std::string& prompt, + int32_t seq_len, + std::vector>& managed_kv_inputs, + std::vector>& freqs_inputs, + std::function token_callback = {}, + std::function stats_callback = {}); + void stop(); + std::vector> get_methods_meta(); + + private: + // metadata + template + T getMetadataHelper(Module*, std::string method_name, T default_val); + template + int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); + Result run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + Tensor& atten_mask, + Tensor& freqs_cos, + Tensor& freqs_sin, + std::vector>& kv_tensors, + std::vector>& kv_outputs); + // metadata + int32_t vocab_size_; + int64_t bos_id_; + int64_t eos_id_; + int32_t n_bos_; + int32_t n_eos_; + int32_t max_seq_len_; + int32_t head_dim_; + std::unordered_set model_methods_; + std::vector> modules_; + std::string tokenizer_path_; + float temperature_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + bool shouldStop_{false}; + Stats stats_; +}; + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py index f8c28371619..6e43835bfdf 100755 --- a/examples/qualcomm/scripts/utils.py +++ b/examples/qualcomm/scripts/utils.py @@ -6,11 +6,12 @@ import argparse import os +import re import subprocess import sys from pathlib import Path -from typing import Optional +from typing import Callable, List, Optional import numpy as np @@ -30,9 +31,11 @@ generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, ) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torch.ao.quantization.observer import MovingAverageMinMaxObserver from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -48,14 +51,15 @@ def __init__( host_id=None, error_only=False, shared_buffer=False, + runner="examples/qualcomm/qnn_executor_runner", ): self.qnn_sdk = qnn_sdk self.artifact_path = artifact_path - self.pte_path = pte_path + self.pte_path = pte_path if isinstance(pte_path, list) else [pte_path] self.workspace = workspace self.device_id = device_id self.host_id = host_id - self.working_dir = Path(self.pte_path).parent.absolute() + self.working_dir = Path(self.pte_path[0]).parent.absolute() self.input_list_filename = "input_list.txt" self.etdump_path = f"{self.workspace}/etdump.etdp" self.output_folder = f"{self.workspace}/outputs" @@ -68,6 +72,7 @@ def __init__( self.soc_model = arch_table[soc_model] self.error_only = error_only self.shared_buffer = shared_buffer + self.runner = runner def _adb(self, cmd): if not self.host_id: @@ -80,7 +85,7 @@ def _adb(self, cmd): cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout ) - 
def push(self, inputs, input_list): + def push(self, inputs, input_list, files=None): self._adb(["shell", f"rm -rf {self.workspace}"]) self._adb(["shell", f"mkdir -p {self.workspace}"]) @@ -92,7 +97,6 @@ def push(self, inputs, input_list): # necessary artifacts for artifact in [ - f"{self.pte_path}", f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtp.so", ( f"{self.qnn_sdk}/lib/hexagon-v{self.soc_model}/" @@ -104,39 +108,60 @@ def push(self, inputs, input_list): ), f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so", f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so", - f"{self.artifact_path}/examples/qualcomm/qnn_executor_runner", + f"{self.artifact_path}/{self.runner}", f"{self.artifact_path}/backends/qualcomm/libqnn_executorch_backend.so", input_list_file, - ]: + ] + self.pte_path: self._adb(["push", artifact, self.workspace]) # input data for idx, data in enumerate(inputs): - for i, d in enumerate(data): + # print("[Warning] inputs push are is skip") + # break + flat_inputs = [] + for input in data: + if isinstance(input, list): + for inp in input: + flat_inputs.append(inp) + else: + flat_inputs.append(input) + for i, d in enumerate(flat_inputs): file_name = f"{self.working_dir}/input_{idx}_{i}.raw" d.detach().numpy().tofile(file_name) self._adb(["push", file_name, self.workspace]) - def execute(self): + # extra files + if files is not None: + for f in files: + self._adb(["push", f, self.workspace]) + + def execute(self, custom_runner_cmd=None): self._adb(["shell", f"mkdir -p {self.output_folder}"]) # run the delegation - qnn_executor_runner_args = " ".join( - [ - f"--model_path {os.path.basename(self.pte_path)}", - f"--output_folder_path {self.output_folder}", - f"--input_list_path {self.input_list_filename}", - f"--etdump_path {self.etdump_path}", - "--shared_buffer" if self.shared_buffer else "", - ] - ) - qnn_executor_runner_cmds = " ".join( - [ - f"cd {self.workspace} &&", - "export ADSP_LIBRARY_PATH=. &&", - "export LD_LIBRARY_PATH=. &&", - f"./qnn_executor_runner {qnn_executor_runner_args}", - ] - ) + if custom_runner_cmd is None: + pte_path_str = ",".join( + [os.path.basename(pte_path) for pte_path in self.pte_path] + ) + qnn_executor_runner_args = " ".join( + [ + f"--model_paths {pte_path_str}", + f"--output_folder_path {self.output_folder}", + f"--input_list_path {self.input_list_filename}", + f"--etdump_path {self.etdump_path}", + "--shared_buffer" if self.shared_buffer else "", + ] + ) + qnn_executor_runner_cmds = " ".join( + [ + f"cd {self.workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. 
&&", + f"./qnn_executor_runner {qnn_executor_runner_args}", + ] + ) + else: + qnn_executor_runner_cmds = custom_runner_cmd + self._adb(["shell", f"{qnn_executor_runner_cmds}"]) def pull(self, output_path, callback=None): @@ -156,25 +181,34 @@ def build_executorch_binary( inputs, # noqa: B006 soc_model, file_name, - dataset, + dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None], custom_annotations=(), skip_node_id_set=None, skip_node_op_set=None, quant_dtype: Optional[QuantDtype] = None, + per_channel_linear=False, # TODO: remove this once QNN fully supports linear + direct_io=False, # TODO: temporal workaround for llama shared_buffer=False, + metadata=None, + act_observer=MovingAverageMinMaxObserver, ): - if quant_dtype: + if quant_dtype is not None: quantizer = QnnQuantizer() quantizer.add_custom_quant_annotations(custom_annotations) + quantizer.set_per_channel_linear_quant(per_channel_linear) if quant_dtype == QuantDtype.use_8a8w: pass # default setting elif quant_dtype == QuantDtype.use_16a16w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config()) + quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=act_observer) + ) elif quant_dtype == QuantDtype.use_16a4w: quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config()) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=act_observer) + ) quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") else: raise AssertionError(f"No support for QuantDtype {quant_dtype}.") @@ -183,8 +217,11 @@ def build_executorch_binary( annotated_model = prepare_pt2e(captured_model, quantizer) print("Quantizing the model...") # calibration - for data in dataset: - annotated_model(*data) + if callable(dataset): + dataset(annotated_model) + else: + for data in dataset: + annotated_model(*data) quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs) @@ -208,29 +245,45 @@ def build_executorch_binary( debug=False, saver=False, shared_buffer=shared_buffer, + profile=False, ), skip_node_id_set, skip_node_op_set, ) - edge_prog.exported_program = to_backend(edge_prog.exported_program, qnn_partitioner) - edge_prog.exported_program.graph_module.graph.print_tabular() - exec_prog = edge_prog.to_executorch( - config=ExecutorchBackendConfig( - extract_constant_segment=False, - # For shared buffer, user must pass the memory address - # which is allocated by RPC memory to executor runner. - # Therefore, won't want to pre-allocate - # by memory manager in runtime. - memory_planning_pass=MemoryPlanningPass( - memory_planning_algo="greedy", - alloc_graph_input=not shared_buffer, - alloc_graph_output=not shared_buffer, - ), - extract_delegate_segments=True, - ) + + executorch_config = ExecutorchBackendConfig( + extract_constant_segment=False, + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=not shared_buffer and not direct_io, + alloc_graph_output=not shared_buffer and not direct_io, + ), + extract_delegate_segments=True, ) - with open(f"{file_name}.pte", "wb") as file: - file.write(exec_prog.buffer) + + if metadata is None: + edge_prog.exported_program = to_backend( + edge_prog.exported_program, qnn_partitioner + ) + edge_prog.exported_program.graph_module.graph.print_tabular() + exec_prog = edge_prog.to_executorch(config=executorch_config) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog.buffer) + else: + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=metadata, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog_mgr.buffer) def make_output_dir(path: str): diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 83ded144469..015c6cfc68e 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -156,4 +156,11 @@ Result> Module::execute( return outputs; } +Error Module::set_output_data_ptr(Tensor& output_tensor, size_t output_index) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method("forward")); + auto& method = methods_.at("forward").method; + return method->set_output_data_ptr( + output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); +} + } // namespace torch::executor diff --git a/extension/module/module.h b/extension/module/module.h index fb70cb08417..83fff368db8 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -191,6 +191,16 @@ class Module final { return event_tracer_.get(); } + /** + * Set output data pointer for forward method. + * + * @param[in] output_tensor A Tensor for the output of 'forward' method. + * @param[in] output_index Index of the output in 'forward' method. + * + * @returns An Error to indicate success or failure of the loading process. + */ + Error set_output_data_ptr(Tensor& output_tensor, size_t output_index); + private: struct MethodHolder { std::vector> planned_buffers;
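The Module::set_output_data_ptr hook added above is what lets qnn_llama_runner write logits and KV-cache outputs straight into caller-owned buffers instead of the memory planner's pools. Below is a minimal usage sketch modeled on runner.cpp; the ManagedTensor header path, the capacity argument (128, mirrored from runner.cpp), and the "llama2_qnn.pte" / 32000 vocab-size values are illustrative assumptions, not fixed API requirements.

#include <executorch/extension/module/module.h>
#include <executorch/extension/runner_util/managed_tensor.h>  // assumed path

#include <vector>

using namespace torch::executor;

int main() {
  // Load the exported program; the file name is a placeholder.
  Module module("llama2_qnn.pte");

  // Caller-owned logits buffer (1 x vocab_size, here the default 32000).
  std::vector<float> logits_data(32000);
  std::vector<exec_aten::SizesType> logits_shape = {1, 32000};
  ManagedTensor managed_logits(
      logits_data.data(), 128, logits_shape, ScalarType::Float);
  Tensor logits = managed_logits.get_aliasing_tensor();

  // Bind output 0 of "forward" to logits_data; subsequent forward() calls
  // then write their logits into this buffer without an extra copy.
  if (module.set_output_data_ptr(logits, 0) != Error::Ok) {
    return 1;
  }

  // ... assemble input EValues and call module.forward(inputs), as the
  // llama runner does for each decoding step.
  return 0;
}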