diff --git a/backends/qualcomm/builders/op_avg_pool2d.py b/backends/qualcomm/builders/op_avg_pool2d.py
index f4762e8bb5a..bd7094cc382 100644
--- a/backends/qualcomm/builders/op_avg_pool2d.py
+++ b/backends/qualcomm/builders/op_avg_pool2d.py
@@ -23,6 +23,12 @@ class AvgPool2d(NodeVisitor):
     def __init__(self, *args) -> None:
         super().__init__(*args)
 
+    def _get_filter_size(self, node):
+        filter_size = cast(List[int], node.args[1])
+        if len(filter_size) == 1:
+            filter_size = filter_size + filter_size
+        return filter_size
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -46,31 +52,44 @@ def define_node(
             PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
             nodes_to_wrappers,
         )
+
+        pt_ceil_mode = node.args[4] if len(node.args) > 4 else False
+
         # kernel info
-        filter_size = cast(List[int], node.args[1])
-        if len(filter_size) == 1:
-            filter_size = filter_size + filter_size
+        input_shape = input_node.meta["val"].shape
+        input_h, input_w = input_shape[2], input_shape[3]
+        filter_size = self._get_filter_size(node)
+        if pt_ceil_mode:
+            # filter_size might be larger than input_h, input_w; use the min of them
+            filter_size = [min(filter_size[0], input_h), min(filter_size[1], input_w)]
         filter_size_shape = [len(filter_size)]
 
-        # stride info - default to kernel_size if not given
-        stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
-        if len(stride) == 1:
-            stride = stride + stride
-        stride_shape = [len(stride)]
-
         padding = [0, 0]
         if len(node.args) > 3:
             padding = cast(List[int], node.args[3])
             if len(padding) == 1:
                 padding = padding + padding
+        if pt_ceil_mode:
+            ori_filter_h, ori_filter_w = self._get_filter_size(node)
+            padding = [
+                0 if ori_filter_h > input_h else padding[0],
+                0 if ori_filter_w > input_w else padding[1],
+            ]
+
         padding_shape = [len(padding), len(padding)]
 
         # if ceil mode is True, use ceil instead of floor to compute the output shape
-        mode = OpPoolAvg2d.RoundingMode.FLOOR
-        if len(node.args) > 4:
-            ceil_mode = cast(bool, node.args[4])
-            if ceil_mode:
-                mode = OpPoolAvg2d.RoundingMode.CEIL
+        mode = (
+            OpPoolAvg2d.RoundingMode.CEIL
+            if pt_ceil_mode
+            else OpPoolAvg2d.RoundingMode.FLOOR
+        )
+
+        # stride info - default to kernel_size if not given
+        stride = cast(List[int], node.args[2]) if len(node.args) > 2 else filter_size
+        if len(stride) == 1:
+            stride = stride + stride
+        stride_shape = [len(stride)]
 
         count_include_pad = True
         if len(node.args) > 5:
diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py
index 7cf661a0e01..ae28076d270 100644
--- a/backends/qualcomm/quantizer/annotators.py
+++ b/backends/qualcomm/quantizer/annotators.py
@@ -970,6 +970,7 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None:
 @register_annotator(
     [
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv2d.padding,
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv_transpose2d.input,
         torch.ops.aten.conv_transpose1d.default,
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 23f9e8fd79c..025c0bee171 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -147,12 +147,13 @@ def forward(self, x, y):
 
 
 class AvgPoolModule(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, kernel_size, stride, padding, ceil_mode):
         super().__init__()
         self.avgPool = torch.nn.AvgPool2d(
-            kernel_size=(2, 2),
-            padding=(1, 1),
-            stride=(1, 1),
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            ceil_mode=ceil_mode,
             count_include_pad=False,
         )
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index c2526ad9aa8..5a173d5d4d4 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -163,9 +163,19 @@ def test_qnn_backend_argmin(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_avg_pool2d(self):
-        module = AvgPoolModule()  # noqa: F405
-        sample_input = (torch.randn(1, 3, 2, 2),)
-        self.lower_module_and_test_output(module, sample_input)
+        modules = [
+            AvgPoolModule((2, 2), (1, 1), (1, 1), False),  # noqa: F405
+            AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True),  # noqa: F405
+            AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 3, 2, 2),),
+            (torch.randn(1, 1280, 7, 7),),
+            (torch.randn(1, 1280, 7, 7),),
+        ]
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_batch_norm(self):
         modules = [BatchNorm(32), BatchNorm(32, False)]  # noqa: F405
@@ -1271,10 +1281,20 @@ def test_qnn_backend_argmin(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_avg_pool2d(self):
-        module = AvgPoolModule()  # noqa: F405
-        sample_input = (torch.randn(1, 3, 2, 2),)
-        module = self.get_qdq_module(module, sample_input)
-        self.lower_module_and_test_output(module, sample_input)
+        modules = [
+            AvgPoolModule((2, 2), (1, 1), (1, 1), False),  # noqa: F405
+            AvgPoolModule((1280, 1280), (1280, 1280), (0, 0), True),  # noqa: F405
+            AvgPoolModule((1280, 1280), (1280, 1280), (320, 320), True),  # noqa: F405
+        ]
+        sample_inputs = [
+            (torch.randn(1, 3, 2, 2),),
+            (torch.randn(1, 1280, 7, 7),),
+            (torch.randn(1, 1280, 7, 7),),
+        ]
+        for i, module in enumerate(modules):
+            with self.subTest(i=i):
+                module = self.get_qdq_module(module, sample_inputs[i])
+                self.lower_module_and_test_output(module, sample_inputs[i])
 
     def test_qnn_backend_batch_norm(self):
         modules = [BatchNorm(32), BatchNorm(32, False)]  # noqa: F405
@@ -3864,6 +3884,41 @@ def test_dino_v2(self):
         self.assertGreaterEqual(msg["top_1"], 70)
         self.assertGreaterEqual(msg["top_5"], 85)
 
+    def test_efficientnet(self):
+        if not self.required_envs([self.image_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientnet.py",
+            "--dataset",
+            self.image_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["top_1"], 70)
+                self.assertGreaterEqual(msg["top_5"], 85)
+
     def test_efficientSAM(self):
         if not self.required_envs(
             [self.image_dataset, self.pretrained_weight, self.oss_repo]
diff --git a/examples/qualcomm/oss_scripts/efficientnet.py b/examples/qualcomm/oss_scripts/efficientnet.py
new file mode 100644
index 00000000000..b11ad7abc47
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/efficientnet.py
@@ -0,0 +1,145 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import logging
+import os
+from multiprocessing.connection import Client
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.examples.qualcomm.utils import (
+    build_executorch_binary,
+    get_imagenet_dataset,
+    make_output_dir,
+    parse_skip_delegation_node,
+    setup_common_args_and_variables,
+    SimpleADB,
+    topk_accuracy,
+)
+from transformers import AutoModelForImageClassification
+
+
+def main(args):
+    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
+
+    # ensure the working directory exists
+    os.makedirs(args.artifact, exist_ok=True)
+
+    if not args.compile_only and args.device is None:
+        raise RuntimeError(
+            "device serial is required if not compile only. "
+            "Please specify a device serial by -s/--device argument."
+        )
+
+    data_num = 100
+    if args.ci:
+        inputs = [(torch.rand(1, 3, 224, 224),)]
+        logging.warning(
+            "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy."
+        )
+    else:
+        inputs, targets, input_list = get_imagenet_dataset(
+            dataset_path=f"{args.dataset}",
+            data_size=data_num,
+            image_shape=(256, 256),
+            crop_size=224,
+        )
+
+    module = (
+        AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")
+        .eval()
+        .to("cpu")
+    )
+    pte_filename = "efficientnet_qnn_q16"
+    build_executorch_binary(
+        module.eval(),
+        inputs[0],
+        args.model,
+        f"{args.artifact}/{pte_filename}",
+        inputs,
+        skip_node_id_set=skip_node_id_set,
+        skip_node_op_set=skip_node_op_set,
+        quant_dtype=QuantDtype.use_16a16w,
+        shared_buffer=args.shared_buffer,
+    )
+
+    if args.compile_only:
+        return
+
+    adb = SimpleADB(
+        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+        build_path=f"{args.build_folder}",
+        pte_path=f"{args.artifact}/{pte_filename}.pte",
+        workspace=f"/data/local/tmp/executorch/{pte_filename}",
+        device_id=args.device,
+        host_id=args.host,
+        soc_model=args.model,
+        shared_buffer=args.shared_buffer,
+    )
+    adb.push(inputs=inputs, input_list=input_list)
+    adb.execute()
+
+    # collect output data
+    output_data_folder = f"{args.artifact}/outputs"
+    make_output_dir(output_data_folder)
+
+    adb.pull(output_path=args.artifact)
+
+    # top-k analysis
+    predictions = []
+    for i in range(data_num):
+        predictions.append(
+            np.fromfile(
+                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
+            )
+        )
+
+    k_val = [1, 5]
+    topk = [topk_accuracy(predictions, targets, k).item() for k in k_val]
+    if args.ip and args.port != -1:
+        with Client((args.ip, args.port)) as conn:
+            conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)}))
+    else:
+        for i, k in enumerate(k_val):
+            print(f"top_{k}->{topk[i]}%")
+
+
+if __name__ == "__main__":
+    parser = setup_common_args_and_variables()
+
+    parser.add_argument(
+        "-d",
+        "--dataset",
+        help=(
+            "path to the validation folder of ImageNet dataset. "
+            "e.g. --dataset imagenet-mini/val "
+            "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)"
+        ),
+        type=str,
+        required=False,
+    )
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing generated artifacts by this example. "
+        "Default ./efficientnet",
+        default="./efficientnet",
+        type=str,
+    )
+
+    args = parser.parse_args()
+    try:
+        main(args)
+    except Exception as e:
+        if args.ip and args.port != -1:
+            with Client((args.ip, args.port)) as conn:
+                conn.send(json.dumps({"Error": str(e)}))
+        else:
+            raise Exception(e)