diff --git a/backends/qualcomm/builders/op_dequantize.py b/backends/qualcomm/builders/op_dequantize.py
index 9c351103949..f80103b4b89 100644
--- a/backends/qualcomm/builders/op_dequantize.py
+++ b/backends/qualcomm/builders/op_dequantize.py
@@ -56,20 +56,16 @@ def define_node(
 
 
 @register_node_visitor
-class PerTensorDequantizeDefault(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_tensor.default"]
+class PerTensorDequantize(DequantizeOpBase):
+    target = [
+        "quantized_decomposed.dequantize_per_tensor.default",
+        "quantized_decomposed.dequantize_per_tensor.tensor",
+    ]
 
 
 @register_node_visitor
-class PerTensorDequantizeTensor(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
-
-
-@register_node_visitor
-class PerChannelDequantizeDefault(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_channel.default"]
-
-
-@register_node_visitor
-class PerChannelDequantizeTensor(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_channel.tensor"]
+class PerChannelDequantize(DequantizeOpBase):
+    target = [
+        "quantized_decomposed.dequantize_per_channel.default",
+        "quantized_decomposed.dequantize_per_channel.tensor",
+    ]
diff --git a/backends/qualcomm/passes/convert_hardsigmoid.py b/backends/qualcomm/passes/convert_hardsigmoid.py
index dc0044da392..68fb8e11094 100644
--- a/backends/qualcomm/passes/convert_hardsigmoid.py
+++ b/backends/qualcomm/passes/convert_hardsigmoid.py
@@ -25,6 +25,10 @@ def call(self, graph_module: torch.fx.GraphModule):
         partitions = get_source_partitions(graph, [torch.nn.Hardsigmoid])
         for _, src_partitions in partitions.items():
             for src_partition in src_partitions:
+                if exir_ops.edge.aten.hardswish.default in [
+                    node.target for node in src_partition.nodes
+                ]:
+                    continue
                 if self.quantization_capture:
                     # only one hardsigmoid op will be seen
                     input_nodes = src_partition.input_nodes
@@ -34,8 +38,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 else:
                     in_ops_target = exir_ops.edge.aten.add.Tensor
                     out_ops_target = exir_ops.edge.aten.div.Tensor
-                    # see the reverse engineering logic hardswish
-                    # https://shorturl.at/pACEL
                     input_nodes = [
                         n for n in src_partition.nodes if n.target is in_ops_target
                     ]
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index fc879431307..b15a876a1f8 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -6,8 +6,10 @@
 import json
 import subprocess
 import sys
+import tempfile
 import unittest
 from multiprocessing.connection import Listener
+from pathlib import Path
 
 import torch
 from executorch.backends.qualcomm.tests.utils import (
@@ -1102,6 +1104,19 @@ def test_qnn_backend_shared_buffer(self):
             expected_partitions=1,
         )
 
+    def test_qnn_backend_online_prepare(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=True)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.arch_table[TestQNN.model],
+            backend_options=backend_options,
+            debug=False,
+            saver=False,
+            online_prepare=True,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        self.lower_module_and_test_output(module, sample_input)
+
 
 class TestQNNQuantizedUtils(TestQNN):
     # TODO: refactor to support different backends
@@ -1223,6 +1238,20 @@ def test_qnn_backend_shared_buffer(self):
             expected_partitions=1,
         )
 
+    def test_qnn_backend_online_prepare(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.arch_table[TestQNN.model],
+            backend_options=backend_options,
+            debug=False,
+            saver=False,
+            online_prepare=True,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
 
 class TestExampleOssScript(TestQNN):
     def required_envs(self, conditions=None) -> bool:
@@ -1640,6 +1669,29 @@ def test_ptq_mobilebert(self):
         for k, v in cpu.items():
             self.assertLessEqual(abs(v[0] - htp[k][0]), 5)
 
+    def test_export_example(self):
+        if not self.required_envs([self.model_name]):
+            self.skipTest("missing required envs")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            cmds = [
+                "python",
+                "qualcomm/scripts/export_example.py",
+                "--model_name",
+                self.model_name,
+                "--output_folder",
+                "{}/".format(tmp_dir),
+                "--generate_etrecord",
+            ]
+
+            p = subprocess.Popen(
+                cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples"
+            )
+            p.communicate()
+            self.assertTrue(
+                Path("{0}/{1}.pte".format(tmp_dir, self.model_name)).exists()
+            )
+
 
 def setup_environment():
     parser = setup_common_args_and_variables()
@@ -1669,6 +1721,12 @@ def setup_environment():
         default="",
         type=str,
     )
+    parser.add_argument(
+        "-n",
+        "--model_name",
+        help="Input the model to export",
+        type=str,
+    )
     parser.add_argument(
         "-o",
         "--online_prepare",
@@ -1697,6 +1755,7 @@
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
     TestQNN.pretrained_weight = args.pretrained_weight
+    TestQNN.model_name = args.model_name
    TestQNN.online_prepare = args.online_prepare
    TestQNN.enable_profile = args.enable_profile
    TestQNN.error_only = args.error_only
diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py
index b6792b5d70b..0a9b7d064d1 100644
--- a/backends/qualcomm/utils/utils.py
+++ b/backends/qualcomm/utils/utils.py
@@ -19,6 +19,7 @@
     ConvertBinaryOpsWithScalar,
 )
 from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
+from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
 from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
     ConvertInterpolateWithUpsample2D,
 )
@@ -103,6 +104,7 @@ def _transform(edge_program: ExportedProgram) -> None:
     graph_module = edge_program.graph_module
     RemoveClone()(graph_module)
     ConvertToLinear()(graph_module)
+    ConvertHardsigmoid()(graph_module)
     ConvertBmmToMatmul()(graph_module)
     ConvertInterpolateWithUpsample2D()(graph_module)
     I64toI32(edge_program)(graph_module)
diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py
index cdb84f6e8c6..a6d2e6d1a3e 100644
--- a/examples/qualcomm/scripts/export_example.py
+++ b/examples/qualcomm/scripts/export_example.py
@@ -40,6 +40,14 @@
         help="Generate ETRecord metadata to link with runtime results (used for profiling)",
     )
 
+    parser.add_argument(
+        "-f",
+        "--output_folder",
+        type=str,
+        default="",
+        help="The folder to store the exported program",
+    )
+
     args = parser.parse_args()
 
     if args.model_name not in MODEL_NAME_TO_MODEL:
@@ -92,7 +100,7 @@
     )
 
     if args.generate_etrecord:
-        etrecord_path = "etrecord.bin"
+        etrecord_path = args.output_folder + "etrecord.bin"
         generate_etrecord(etrecord_path, edge_copy, executorch_program)
 
-    save_pte_program(executorch_program, args.model_name)
+    save_pte_program(executorch_program, args.model_name, args.output_folder)