diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index c4fbdeae14b..d3bf98bae72 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -23,6 +23,8 @@
     op_hardsigmoid,
     op_hardswish,
     op_hardtanh,
+    op_index,
+    op_index_put,
     op_layer_norm,
     op_linear,
     op_log_softmax,
@@ -75,6 +77,8 @@
     op_hardswish,
     op_hardtanh,
     op_hardsigmoid,
+    op_index,
+    op_index_put,
     op_layer_norm,
     op_linear,
     op_log_softmax,
diff --git a/backends/qualcomm/builders/op_index.py b/backends/qualcomm/builders/op_index.py
new file mode 100644
index 00000000000..6f8dc558fe5
--- /dev/null
+++ b/backends/qualcomm/builders/op_index.py
@@ -0,0 +1,83 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Index(NodeVisitor):
+    # schema = aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+    target = ["aten.index.Tensor"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        if len(node.args[1]) > 1:
+            # TODO: consider implementing this recursively.
+            raise NotImplementedError("Tuples of index tensors are not supported.")
+
+        indices_node = node.args[1][0]
+        indices_tensor = self.get_tensor(indices_node, node).to(torch.int32)
+        assert indices_tensor.size(0) != 0, "Empty indices list is not supported"
+
+        indices_tensor_wrapper = self.define_tensor(
+            indices_node,
+            indices_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+
+        gather_input_tensors = [input_tensor_wrapper, indices_tensor_wrapper]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+        gather_output_tensors = [output_tensor_wrapper]
+
+        gather_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpGather.op_name,
+        )
+        gather_op.AddInputTensors(gather_input_tensors)
+        gather_op.AddOutputTensors(gather_output_tensors)
+
+        # If tuples of index tensors become supported, derive the axis from the tuple length.
+        gather_op.AddScalarParam(
+            OpGather.param_axis,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
+            {"data": np.int32(0)},
+        )
+
+        return gather_op
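For context on the lowering above: with a single index tensor, aten.index.Tensor selects whole slices along dim 0, which is exactly what QNN's Gather with axis=0 computes. A minimal sketch in plain PyTorch (shapes borrowed from the tests below; not part of the diff):

import torch

x = torch.randn(8, 172, 64)
idx = torch.tensor([[0, 1], [2, 3], [4, 5]])

# x[idx] gathers rows of x along dim 0; the result takes idx's shape
# followed by the remaining dims of x: (3, 2, 172, 64).
gathered = x.index_select(0, idx.flatten()).reshape(3, 2, 172, 64)
assert torch.equal(x[idx], gathered)
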
diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py
new file mode 100644
index 00000000000..af5311dfb2a
--- /dev/null
+++ b/backends/qualcomm/builders/op_index_put.py
@@ -0,0 +1,83 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpScatterNd, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class IndexPutVisitor(NodeVisitor):
+    target = ["aten.index_put.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        indices_node = node.args[1]
+        indices_list = [
+            self.get_tensor(idx, idx) for idx in indices_node if idx is not None
+        ]
+
+        # Unpack the tuple of index tensors
+        indices_unpacked = [torch.flatten(idx) for idx in indices_list]
+
+        # Convert to a 2-D tensor, as expected by ScatterNd
+        indices_qnn = torch.cat(indices_unpacked).unsqueeze(0)
+        index_nodes = [n for n in indices_node if isinstance(n, torch.fx.Node)]
+        # TODO: consider writing a pass to combine the indices into one input tensor
+        assert len(index_nodes) == 1, "Multiple index tensors are not supported"
+
+        indices_tensor_wrapper = self.define_tensor(
+            index_nodes[0],
+            indices_qnn,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        value_node = node.args[2]
+
+        value_tensor = self.get_tensor(value_node, node)
+
+        value_tensor_wrapper = self.define_tensor(
+            value_node,
+            value_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+
+        index_put_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpScatterNd.op_name,
+        )
+        index_put_op.AddInputTensors(
+            [input_tensor_wrapper, indices_tensor_wrapper, value_tensor_wrapper]
+        )
+        index_put_op.AddOutputTensors([output_tensor_wrapper])
+
+        return index_put_op
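The ScatterNd lowering above targets the KV-cache update pattern exercised by the tests: index_put_ with a [None, input_pos] index list scatters k_val into the cache along dim 1. A minimal sketch in plain PyTorch (not part of the diff) of the semantics being lowered:

import torch

k_cache = torch.zeros(1, 1024, 12, 64)
input_pos = torch.tensor([2])
k_val = torch.randn(1, 1, 12, 64)

# In-place scatter of k_val at position 2 along dim 1 of the cache.
out = torch.ops.aten.index_put_(k_cache, [None, input_pos], k_val)

assert torch.equal(out[:, 2], k_val[:, 0])
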
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index dca47ebeec6..4a87e5dbbb3 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -124,13 +124,6 @@ class OpExpandDims:
     param_axis: str = "axis"
 
 
-@dataclass(init=False, frozen=True)
-class OpReduceSum:
-    op_name: str = "ReduceSum"
-    param_axes: str = "axes"
-    param_keep_dims: str = "keep_dims"
-
-
 @dataclass(init=False, frozen=True)
 class OpFullyConnected:
     op_name: str = "FullyConnected"
@@ -144,13 +137,14 @@ class OpGather:
 
 
 @dataclass(init=False, frozen=True)
-class OpGelu:
-    op_name: str = "Gelu"
+class OpGatherND:
+    op_name: str = "GatherNd"
+    param_batch_dims: str = "batch_dims"
 
 
 @dataclass(init=False, frozen=True)
-class OpSqrt:
-    op_name: str = "ElementWiseSquareRoot"
+class OpGelu:
+    op_name: str = "Gelu"
 
 
 @dataclass(init=False, frozen=True)
@@ -246,6 +240,13 @@ class OpReduceMean:
     param_keep_dims: str = "keep_dims"
 
 
+@dataclass(init=False, frozen=True)
+class OpReduceSum:
+    op_name: str = "ReduceSum"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpRelu:
     op_name: str = "Relu"
@@ -277,6 +278,12 @@ class OpResizeNearestNeighbor:
     param_half_pixel_centers: str = "half_pixel_centers"
 
 
+@dataclass(init=False, frozen=True)
+class OpScatterNd:
+    op_name: str = "ScatterNd"
+    param_reduction: str = "reduction"
+
+
 @dataclass(init=False, frozen=True)
 class OpSigmoid:
     op_name: str = "Sigmoid"
@@ -307,6 +314,11 @@ class OpSplit:
     param_split_index: str = "split_index"
 
 
+@dataclass(init=False, frozen=True)
+class OpSqrt:
+    op_name: str = "ElementWiseSquareRoot"
+
+
 @dataclass(init=False, frozen=True)
 class OpSqueeze:
     op_name: str = "Squeeze"
diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py
index 61935cf3536..c60afc2dd33 100644
--- a/backends/qualcomm/partition/common_defs.py
+++ b/backends/qualcomm/partition/common_defs.py
@@ -13,8 +13,7 @@
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
-    exir_ops.edge.aten.index.Tensor,
-    exir_ops.edge.aten.index_put.default,
+    exir_ops.edge.aten.copy.default,
 ]
 
 allow_list_operator = [
diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index f2265daf325..d31b4753a3d 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -784,6 +784,38 @@ def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> N
     )
 
 
+@register_annotator([torch.ops.aten.index.Tensor])
+def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_in_out_obs_sharing_op(node, quantization_config)
+    if not _is_annotated([node]):
+        input_qspec_map = {}
+        input = node.args[0]
+        input_qspec_map[input] = quantization_config.input_activation
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=SharedQuantizationSpec((input, node)),
+            _annotated=True,
+        )
+
+
+@register_annotator(
+    [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
+)
+def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
+    input = node.args[0]
+    value = node.args[2]
+
+    input_qspec_map = {}
+    input_qspec_map[input] = quantization_config.input_activation
+    input_qspec_map[value] = SharedQuantizationSpec((input, node))
+
+    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+        input_qspec_map=input_qspec_map,
+        output_qspec=SharedQuantizationSpec((input, node)),
+        _annotated=True,
+    )
+
+
 @register_annotator([torch.ops.aten.expand.default])
 def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_in_out_obs_sharing_op(node, quantization_config)
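On the annotations above: the value tensor shares the cache input's quantization spec, and the output shares it as well, so the quantized scatter never requantizes. A small numeric sketch of why that matters (assuming symmetric per-tensor int8 parameters; not part of the diff):

import torch

scale, zero_point = 0.02, 0
cache_q = torch.zeros(4, dtype=torch.int8)

# With value and cache on the same scale/zero-point, scattering is a
# plain integer copy; no rescaling is needed.
value_q = torch.quantize_per_tensor(
    torch.tensor([0.5]), scale, zero_point, torch.qint8
).int_repr()
cache_q[2] = value_q
assert abs(cache_q[2].item() * scale - 0.5) < 1e-9
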
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index fe72b1e8930..ff52fc61b57 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -443,6 +443,29 @@ def forward(self, x):
         return self.hardtanh(x)
 
 
+class Index(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]])
+        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]])
+
+    def forward(self, x):
+        return x[self.idx0] + x[self.idx1]
+
+
+class IndexPut(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer(
+            "k_cache",
+            torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
+        )
+
+    def forward(self, input_pos, k_val):
+        k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
+        return k_out
+
+
 class LayerNorm(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 508a027da68..83757c5eaf9 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -256,6 +256,19 @@ def test_qnn_backend_hardtanh(self):
         sample_input = (torch.randn([2, 5, 1, 3]),)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_index(self):
+        module = Index()  # noqa: F405
+        sample_input = (torch.randn([8, 172, 64]),)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_index_put(self):
+        module = IndexPut()  # noqa: F405
+        sample_input = (
+            torch.tensor([2], dtype=torch.int32),
+            torch.randn([1, 1, 12, 64]),
+        )
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_interpolate_bilinear_2d(self):
         module = ResizeBilinear2D()  # noqa: F405
         sample_input = (torch.randn(2, 3, 4, 5),)
@@ -827,6 +840,21 @@ def test_qnn_backend_hardtanh(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_index(self):
+        module = Index()  # noqa: F405
+        sample_input = (torch.randn([8, 172, 64]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_index_put(self):
+        module = IndexPut()  # noqa: F405
+        sample_input = (
+            torch.tensor([2], dtype=torch.int32),
+            torch.randn([1, 1, 12, 64]),
+        )
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_interpolate_bilinear_2d(self):
         module = ResizeBilinear2D()  # noqa: F405
         sample_input = (torch.randn(2, 3, 4, 5),)
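A quick eager-mode smoke check of the two new test models, mirroring the sample inputs used in the tests above (the import path executorch.backends.qualcomm.tests.models is assumed from the file location; not part of the diff):

import torch

from executorch.backends.qualcomm.tests.models import Index, IndexPut

out = Index()(torch.randn(8, 172, 64))
assert out.shape == (3, 2, 172, 64)  # idx shape (3, 2) + remaining dims of x

# int32 position tensor matches the sample input in test_qnn_backend_index_put.
k_out = IndexPut()(torch.tensor([2], dtype=torch.int32), torch.randn(1, 1, 12, 64))
assert k_out.shape == (1, 1024, 12, 64)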