diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py
index b63a5583b10..8adbfde0b92 100644
--- a/backends/qualcomm/builders/__init__.py
+++ b/backends/qualcomm/builders/__init__.py
@@ -41,8 +41,10 @@
     op_skip_ops,
     op_slice_copy,
     op_softmax,
+    op_sqrt,
     op_squeeze,
     op_sub,
+    op_sum_int_list,
     op_tanh,
     op_transpose,
     op_unsqueeze,
@@ -86,7 +88,9 @@
     op_slice_copy,
     op_softmax,
     op_squeeze,
+    op_sqrt,
     op_sub,
+    op_sum_int_list,
     op_tanh,
     op_transpose,
     op_unsqueeze,
diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py
index 78d1e6244e9..9a593528219 100644
--- a/backends/qualcomm/builders/op_linear.py
+++ b/backends/qualcomm/builders/op_linear.py
@@ -62,7 +62,7 @@ def define_node(
         bias_node = node.args[2]

         # TODO remove this when qnn sdk support
-        if "scales" in bias_node.meta.get("quant_attrs"):
+        if "scales" in bias_node.meta.get("quant_attrs", {}):
             print(
                 f"[WARNING] Fallback linear bias, {bias_node}. per channel bias quantization is not supported yet."
             )
diff --git a/backends/qualcomm/builders/op_log_softmax.py b/backends/qualcomm/builders/op_log_softmax.py
index c159b9bf00e..002dd5bc9b2 100644
--- a/backends/qualcomm/builders/op_log_softmax.py
+++ b/backends/qualcomm/builders/op_log_softmax.py
@@ -72,5 +72,4 @@ def define_node(
             PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
             {"data": np.uint32(dim)},
         )
-        # pdb.set_trace()
         return log_softmax_op
diff --git a/backends/qualcomm/builders/op_sqrt.py b/backends/qualcomm/builders/op_sqrt.py
new file mode 100644
index 00000000000..7847d00e8b8
--- /dev/null
+++ b/backends/qualcomm/builders/op_sqrt.py
@@ -0,0 +1,59 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpSqrt, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class SQRT(NodeVisitor):
+    target = ["aten.sqrt.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+        # tensor input
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        sqrt_input_tensors = [input_tensor_wrapper]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+        sqrt_output_tensors = [output_tensor_wrapper]
+
+        sqrt_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpSqrt.op_name,
+        )
+        sqrt_op.AddInputTensors(sqrt_input_tensors)
+        sqrt_op.AddOutputTensors(sqrt_output_tensors)
+
+        return sqrt_op
diff --git a/backends/qualcomm/builders/op_sum_int_list.py b/backends/qualcomm/builders/op_sum_int_list.py
new file mode 100644
index 00000000000..26cc262462e
--- /dev/null
+++ b/backends/qualcomm/builders/op_sum_int_list.py
@@ -0,0 +1,80 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import cast, Dict, List
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+
+import numpy as np
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpReduceSum, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class Sum(NodeVisitor):
+    target = ["aten.sum.dim_IntList"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=True,
+        )
+        sum_input_tensors = [input_tensor_wrapper]
+
+        # sum dims
+        sum_dims = cast(List[int], node.args[1])
+        sum_dims = [sum_dim % len(input_node.meta["val"].shape) for sum_dim in sum_dims]
+        if "axis_order" in node.meta:
+            sum_dims = [node.meta["axis_order"].index(sum_dim) for sum_dim in sum_dims]
+        sum_dims_shape = [len(sum_dims)]
+
+        output_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            output_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+            is_input_tensor=False,
+        )
+        sum_output_tensors = [output_tensor_wrapper]
+        sum_op = PyQnnWrapper.PyQnnOpWrapper(
+            node.name,
+            QNN_OP_PACKAGE_NAME_QTI_AISW,
+            OpReduceSum.op_name,
+        )
+        sum_op.AddInputTensors(sum_input_tensors)
+        sum_op.AddOutputTensors(sum_output_tensors)
+        sum_op.AddTensorParam(
+            OpReduceSum.param_axes,
+            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
+            len(sum_dims_shape),
+            sum_dims_shape,
+            np.array(sum_dims, dtype=np.uint32),
+            True,
+        )
+
+        if len(node.args) > 2:
+            keep_dims = cast(bool, node.args[2])
+            sum_op.AddScalarParam(
+                OpReduceSum.param_keep_dims,
+                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
+                {"data": keep_dims},
+            )
+        return sum_op
diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py
index 82c50046bee..2adb3102357 100644
--- a/backends/qualcomm/builders/qnn_constants.py
+++ b/backends/qualcomm/builders/qnn_constants.py
@@ -106,6 +106,13 @@ class OpExpandDims:
     param_axis: str = "axis"


+@dataclass(init=False, frozen=True)
+class OpReduceSum:
+    op_name: str = "ReduceSum"
+    param_axes: str = "axes"
+    param_keep_dims: str = "keep_dims"
+
+
 @dataclass(init=False, frozen=True)
 class OpFullyConnected:
     op_name: str = "FullyConnected"
@@ -123,6 +130,11 @@ class OpGelu:
     op_name: str = "Gelu"


+@dataclass(init=False, frozen=True)
+class OpSqrt:
+    op_name: str = "ElementWiseSquareRoot"
+
+
 @dataclass(init=False, frozen=True)
 class OpHardSwish:
     op_name: str = "HardSwish"
diff --git a/backends/qualcomm/passes/layout_transform.py b/backends/qualcomm/passes/layout_transform.py
index 8c86f1919ad..fbf1431f1a5 100644
--- a/backends/qualcomm/passes/layout_transform.py
+++ b/backends/qualcomm/passes/layout_transform.py
@@ -52,6 +52,9 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.bmm.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.sqrt.default,
+        exir_ops.edge.aten.sum.dim_IntList,
+        exir_ops.edge.aten.pow.Tensor_Scalar,
         *q_ops,
         *dq_ops,
         _operator.getitem,
@@ -109,7 +112,10 @@ def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
         return node.target in self.layout_sensitive_ops

     def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
-        if node.target == exir_ops.edge.aten.mean.dim:
+        if node.target in [
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.sum.dim_IntList,
+        ]:
             # if dimension is not kept, we'll have no clue how to do layout transform
             if len(node.args) < 3 or not node.args[2]:
                 return False
diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py
index a6a8118d0b8..ac741b7dc14 100644
--- a/backends/qualcomm/quantizer/utils.py
+++ b/backends/qualcomm/quantizer/utils.py
@@ -42,6 +42,7 @@ def decorator(annotator: Callable):
     return decorator


+
 def _is_input_float_tensor(node: Node):
     """Check if the input is not a float tensor, so that we can skip quantization for the node
     since observers only work with float Tensors
@@ -175,6 +176,11 @@ def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_binary(node, quantization_config)


+@register_annotator([torch.ops.aten.sum.dim_IntList])
+def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_binary(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.ceil.default])
 def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
@@ -302,6 +308,11 @@ def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)


+@register_annotator([torch.ops.aten.sqrt.default])
+def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
+    annotate_single_in_single_out(node, quantization_config)
+
+
 @register_annotator([torch.ops.aten.gelu.default])
 def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
     annotate_single_in_single_out(node, quantization_config)
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index edc7a469f7b..812380ae6af 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -122,8 +122,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.modules = [
-            Conv2DSequential(),
-            Conv2DSequential(),
+            Conv2dSequential(),
+            Conv2dSequential(),
             Add(),
             Relu(),
         ]
@@ -172,7 +172,7 @@ def forward(self, x, y):
         return CompositeReferenceModule(self.modules)


-class Conv1DSequential(torch.nn.Module):
+class Conv1dSequential(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.first = torch.nn.Conv1d(
@@ -210,43 +210,6 @@ def forward(self, x):
         return x


-class Conv2DSequential(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.first = torch.nn.Conv2d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=(3, 3),
-            padding=1,
-            bias=True,
-        )
-        self.second = torch.nn.Conv2d(
-            in_channels=3,
-            out_channels=2,
-            kernel_size=(3, 3),
-            padding=1,
-            bias=True,
-        )
-
-    def forward(self, x):
-        return self.second(self.first(x))
-
-
-class Conv2DSingle(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv = torch.nn.Conv2d(
-            in_channels=1,
-            out_channels=3,
-            kernel_size=(3, 3),
-            padding=1,
-            bias=True,
-        )
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Conv2dAvgPool2d(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -321,6 +284,58 @@ def forward(self, x):
         return self.pool(self.conv(x))


+class Conv2dSequential(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3),
+            padding=1,
+            bias=True,
+        )
+        self.second = torch.nn.Conv2d(
+            in_channels=3,
+            out_channels=2,
+            kernel_size=(3, 3),
+            padding=1,
+            bias=True,
+        )
+
+    def forward(self, x):
+        return self.second(self.first(x))
+
+
+class Conv2dSingle(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3),
+            padding=1,
+            bias=True,
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Conv2dSumReduceDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.first = torch.nn.Conv2d(
+            in_channels=1,
+            out_channels=3,
+            kernel_size=(3, 3),
+            padding=1,
+            bias=True,
+        )
+
+    def forward(self, x):
+        return torch.sum(self.first(x), dim=(2, 3), keepdim=False)
+
+
 class Div(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -691,7 +706,7 @@ def __init__(self):
         super().__init__()

     def forward(self, x):
-        return x / torch.sqrt(torch.tensor([64]))
+        return x / torch.sqrt(torch.tensor([64.0]))


 class Squeeze(torch.nn.Module):
@@ -748,6 +763,14 @@ def forward(self, x):
         return 10 - x


+class SumIntList(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return torch.sum(x, dim=(2, 3), keepdim=True)
+
+
 class Tanh(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index d539827fdb9..3874da9e981 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -95,12 +95,12 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1DSequential()  # noqa: F405
+        module = Conv1dSequential()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2DSequential()  # noqa: F405
+        module = Conv2dSequential()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
         self.lower_module_and_test_output(module, sample_input)

@@ -183,11 +183,10 @@ def test_qnn_backend_element_wise_mul(self):
             self.lower_module_and_test_output(module, sample_input)
             index += 1

-    @unittest.skip("not yet implemented")
     def test_qnn_backend_element_wise_sqrt(self):
         modules = [Sqrt(), SqrtConstant()]  # noqa: F405
-        sample_input = (torch.randn([3, 1]),)
         for i, module in enumerate(modules):
+            sample_input = (torch.rand([3, 1]),)
             with self.subTest(i=i):
                 self.lower_module_and_test_output(module, sample_input)

@@ -357,6 +356,11 @@ def test_qnn_backend_squeeze(self):
         sample_input = (torch.randn([1, 3, 3]),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_sum_int_list(self):
+        module = SumIntList()  # noqa: F405
+        sample_input = (torch.randn([1, 4, 8, 8]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_tanh(self):
         module = Tanh()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -421,6 +425,11 @@ def test_qnn_backend_conv2d_max_pool2d(self):
         sample_input = (torch.rand(1, 2, 14, 14),)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_conv2d_sum_reduce_dim(self):
+        module = Conv2dSumReduceDim()  # noqa: F405
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_residual_block(self):
         module = ResidualBlockModule()  # noqa: F405
         sample_input = (torch.randn(1, 32, 28, 28),)
         self.lower_module_and_test_output(module, sample_input)
@@ -494,7 +503,7 @@ def setUp(self):
         )

     def test_qnn_backend_16a4w_conv2d(self):
-        module = Conv2DSingle()  # noqa: F405
+        module = Conv2dSingle()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
         module = self.get_qdq_module(
             module, sample_input, quant_dtype=QuantDtype.use_16a4w
         )
@@ -575,13 +584,13 @@ def test_qnn_backend_clamp(self):
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv1d(self):
-        module = Conv1DSequential()  # noqa: F405
+        module = Conv1dSequential()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3]),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

     def test_qnn_backend_conv2d(self):
-        module = Conv2DSequential()  # noqa: F405
+        module = Conv2dSequential()  # noqa: F405
         sample_input = (torch.randn([1, 1, 3, 3]),)
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
@@ -669,11 +678,10 @@ def test_qnn_backend_element_wise_mul(self):
             self.lower_module_and_test_output(module, sample_input)
             index += 1

-    @unittest.skip("not yet implemented")
     def test_qnn_backend_element_wise_sqrt(self):
         modules = [Sqrt(), SqrtConstant()]  # noqa: F405
-        sample_input = (torch.randn([3, 1]),)
         for i, module in enumerate(modules):
+            sample_input = (torch.rand([3, 1]),)
             with self.subTest(i=i):
                 module = self.get_qdq_module(module, sample_input)
                 self.lower_module_and_test_output(module, sample_input)
@@ -873,6 +881,12 @@ def test_qnn_backend_stack(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_sum_int_list(self):
+        module = SumIntList()  # noqa: F405
+        sample_input = (torch.randn([1, 4, 8, 8]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_tanh(self):
         module = Tanh()  # noqa: F405
         sample_input = (torch.randn(2, 5, 1, 3),)
@@ -946,6 +960,12 @@ def test_qnn_backend_conv2d_max_pool2d(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)

+    def test_qnn_backend_conv2d_sum_reduce_dim(self):
+        module = Conv2dSumReduceDim()  # noqa: F405
+        sample_input = (torch.randn([1, 1, 3, 3]),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     def test_qnn_backend_example_models(self):
         instances = [
             {"module": DeepLabV3ResNet101Model(), "annotation": ()},
@@ -1095,6 +1115,7 @@ def test_qnn_backend_multi_contexts_composite(self):
         exec_prog = edge_prog.to_executorch()
         self.verify_output(module.get_reference_module(), sample_input, exec_prog)

+    @unittest.expectedFailure
     def test_qnn_backend_profile_op(self):
         TestQNN.enable_profile = True
         backend_options = generate_htp_compiler_spec(use_fp16=True)
@@ -1227,6 +1248,7 @@ def test_qnn_backend_multi_contexts_composite(self):
         exec_prog = edge_prog.to_executorch()
         self.verify_output(module.get_reference_module(), sample_input, exec_prog)

+    @unittest.expectedFailure
     def test_qnn_backend_profile_op(self):
         TestQNN.enable_profile = True
         backend_options = generate_htp_compiler_spec(use_fp16=False)
@@ -1323,6 +1345,40 @@ def test_fbnet(self):
         self.assertGreaterEqual(msg["top_1"], 60)
         self.assertGreaterEqual(msg["top_5"], 90)

+    def test_ssd300_vgg16(self):
+        if not self.required_envs([self.pretrained_weight, self.oss_repo]):
+            self.skipTest("missing required envs")
+
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/ssd300_vgg16.py",
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--oss_repo",
+            self.oss_repo,
+            "--pretrained_weight",
+            self.pretrained_weight,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            self.assertGreaterEqual(msg["mAP"], 0.70)
+

 class TestExampleScript(TestQNN):
     def required_envs(self, conditions=None) -> bool:
@@ -1771,6 +1827,11 @@ def setup_environment():
         help="Emit log only when error happened",
         action="store_true",
     )
+    parser.add_argument(
+        "--oss_repo",
+        help="Path to open source software model repository",
+        type=str,
+    )

     args, ns_args = parser.parse_known_args(namespace=unittest)
     TestQNN.host = args.host
@@ -1785,6 +1846,7 @@ def setup_environment():
     TestQNN.online_prepare = args.online_prepare
     TestQNN.enable_profile = args.enable_profile
     TestQNN.error_only = args.error_only
+    TestQNN.oss_repo = args.oss_repo
     TestQNN.shared_buffer = args.shared_buffer
     return sys.argv[:1] + ns_args
diff --git a/examples/qualcomm/oss_scripts/ssd300_vgg16.py b/examples/qualcomm/oss_scripts/ssd300_vgg16.py
new file mode 100644
index 00000000000..6457b68f7d6
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/ssd300_vgg16.py
@@ -0,0 +1,277 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import sys
+from multiprocessing.connection import Client
+from pprint import PrettyPrinter
+
+import numpy as np
+import torch
+
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.examples.qualcomm.scripts.utils import (
+    build_executorch_binary,
+    make_output_dir,
+    parse_skip_delegation_node,
+    setup_common_args_and_variables,
+    SimpleADB,
+)
+
+
+def create_data_lists(voc07_path, data_size):
+    """
+    Create lists of images, the bounding boxes and labels of the objects in these images, and save these to file.
+
+    :param voc07_path: path to the 'VOC2007' folder
+    :param data_size: maximum number of test images to collect
+    """
+    from utils import parse_annotation
+
+    voc07_path = os.path.abspath(voc07_path)
+
+    # Test data
+    test_images = []
+    test_objects = []
+    n_objects = 0
+
+    # Find IDs of images in the test data
+    with open(os.path.join(voc07_path, "ImageSets/Main/test.txt")) as f:
+        ids = f.read().splitlines()
+
+    for index, id in enumerate(ids):
+        if index >= data_size:
+            break
+        # Parse annotation's XML file
+        objects = parse_annotation(os.path.join(voc07_path, "Annotations", id + ".xml"))
+        if len(objects) == 0:
+            continue
+        test_objects.append(objects)
+        n_objects += len(objects)
+        test_images.append(os.path.join(voc07_path, "JPEGImages", id + ".jpg"))
+
+    assert len(test_objects) == len(test_images)
+
+    # TEST_images.json stores the file name of the images, and TEST_objects.json stores info such as boxes, labels, and difficulties
+    with open(os.path.join(voc07_path, "TEST_images.json"), "w") as j:
+        json.dump(test_images, j)
+    with open(os.path.join(voc07_path, "TEST_objects.json"), "w") as j:
+        json.dump(test_objects, j)
+
+    print(
+        "\nThere are %d test images containing a total of %d objects. Files have been saved to %s."
+        % (len(test_images), n_objects, os.path.abspath(voc07_path))
+    )
+
+
+def get_dataset(data_size, dataset_dir, download):
+    from datasets import PascalVOCDataset
+    from torchvision import datasets
+
+    if download:
+        datasets.VOCSegmentation(
+            root=os.path.join(dataset_dir, "voc_image"),
+            year="2007",
+            image_set="test",
+            download=True,
+        )
+    voc07_path = os.path.join(dataset_dir, "voc_image", "VOCdevkit", "VOC2007")
+    create_data_lists(voc07_path, data_size)
+
+    # voc07_path is where the data and ground truth json file will be stored
+    test_dataset = PascalVOCDataset(voc07_path, split="test", keep_difficult=True)
+
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, shuffle=True, collate_fn=test_dataset.collate_fn
+    )
+
+    inputs, input_list = [], ""
+    true_boxes = []
+    true_labels = []
+    true_difficulties = []
+    for index, (images, boxes, labels, difficulties) in enumerate(test_loader):
+        if index >= data_size:
+            break
+        inputs.append((images,))
+        input_list += f"input_{index}_0.raw\n"
+        true_boxes.extend(boxes)
+        true_labels.extend(labels)
+        true_difficulties.extend(difficulties)
+
+    return inputs, input_list, true_boxes, true_labels, true_difficulties
+
+
+def SSD300VGG16(pretrained_weight_model):
+    from model import SSD300
+
+    model = SSD300(n_classes=21)
+    checkpoint = torch.load(pretrained_weight_model, map_location="cpu")
+    model.load_state_dict(checkpoint["model"].state_dict())
+
+    return model.eval()
+
+
+if __name__ == "__main__":
+    parser = setup_common_args_and_variables()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="Path for storing artifacts generated by this example. Default: ./ssd300_vgg16",
+        default="./ssd300_vgg16",
+        type=str,
+    )
+
+    parser.add_argument(
+        "-d",
+        "--download",
+        help="If specified, download the VOCSegmentation dataset via the torchvision API",
+        action="store_true",
+        default=False,
+    )
+
+    parser.add_argument(
+        "--oss_repo",
+        help=(
+            "Repository that contains the model backbone and score calculation. "
+            "e.g., --oss_repo ./a-PyTorch-Tutorial-to-Object-Detection. "
+            "Please clone the repository from https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection"
+        ),
+        type=str,
+        required=True,
+    )
+
+    parser.add_argument(
+        "-p",
+        "--pretrained_weight",
+        help=(
+            "Location of the pretrained model weights. "
+ "e.g., -p ./checkpoint_ssd300.pth.tar" + "Pretrained model can be found in the link https://github.com/sgrvinod/a-PyTorch-Tutorial-to-Object-Detection, under the Training Section" + ), + type=str, + required=True, + ) + + args = parser.parse_args() + + sys.path.insert(0, args.oss_repo) + + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + data_num = 100 + inputs, input_list, true_boxes, true_labels, true_difficulties = get_dataset( + data_size=data_num, dataset_dir=args.artifact, download=args.download + ) + + pte_filename = "ssd300_vgg16_qnn" + model = SSD300VGG16(args.pretrained_weight) + + sample_input = (torch.randn((1, 3, 300, 300)),) + build_executorch_binary( + model, + sample_input, + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=QuantDtype.use_8a8w, + ) + + if args.compile_only: + sys.exit(0) + + # setup required paths accordingly + # qnn_sdk : QNN SDK path setup in environment variable + # artifact_path : path where artifacts were built + # pte_path : path where executorch binary was stored + # device_id : serial number of android device + # workspace : folder for storing artifacts on android device + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + artifact_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + det_boxes = [] + det_labels = [] + det_scores = [] + + def post_process(): + from utils import calculate_mAP + + np.set_printoptions(threshold=np.inf) + + # output_xxx_0.raw is output of boxes, and output_xxx_1.raw is output of classes + for file_index in range(data_num): + boxes_filename = os.path.join( + output_data_folder, f"output_{file_index}_0.raw" + ) + category_filename = os.path.join( + output_data_folder, f"output_{file_index}_1.raw" + ) + + predicted_locs = np.fromfile(boxes_filename, dtype=np.float32).reshape( + [1, 8732, 4] + ) + predicted_locs = torch.tensor(predicted_locs) + + predicted_scores = np.fromfile(category_filename, dtype=np.float32).reshape( + [1, 8732, 21] + ) + predicted_scores = torch.tensor(predicted_scores) + + det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects( + predicted_locs, + predicted_scores, + min_score=0.01, + max_overlap=0.45, + top_k=200, + ) + + det_boxes.extend(det_boxes_batch) + det_labels.extend(det_labels_batch) + det_scores.extend(det_scores_batch) + + pp = PrettyPrinter() + # Calculate mAP + APs, mAP = calculate_mAP( + det_boxes, + det_labels, + det_scores, + true_boxes, + true_labels, + true_difficulties, + ) + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"mAP": float(mAP)})) + else: + print("\nMean Average Precision (mAP): %.3f" % mAP) + pp.pprint(APs) + + adb.pull(output_path=args.artifact, callback=post_process)