diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
index 191edc4c6390d..a12aca47f5b65 100644
--- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
+++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py
@@ -6,15 +6,15 @@
 from __future__ import annotations
 
 import logging
+import tempfile
 from pathlib import Path
 
 import onnx
 
-from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
+from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model
 from ....tools.remove_initializer_from_input import remove_initializer_from_input
 from ...fusions import FusionGelu, FusionLayerNormalization
 from ...onnx_model import ONNXModel
-from ...quant_utils import save_and_reload_model_with_shape_infer
 from .fusion_lpnorm import FusionLpNormalization
 from .fusion_spacetodepth import FusionSpaceToDepth
 
@@ -93,7 +93,7 @@ def qnn_preprocess_model(
     """
     modified = False
     model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
-    model = save_and_reload_model_with_shape_infer(model)
+    model = save_and_reload_optimize_model(model, shape_infer=True)
     onnx_model = ONNXModel(model)
 
     # Optionally, fix the dynamic input shapes.
@@ -178,6 +178,24 @@ def qnn_preprocess_model(
     return modified
 
 
+def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto:
+    with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir:
+        model_in_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_input.onnx")
+        onnx.save_model(model, model_in_path, save_as_external_data=True)
+        if shape_infer:
+            model_infer_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_infer.onnx")
+            onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path))
+            model_in_path = model_infer_path
+        model_out_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_output.onnx")
+        optimize_model(model_in_path, model_out_path)
+        ret_model = onnx.load_model(model_out_path)
+        ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"}
+        if ret_model.metadata_props:
+            ret_metaprops.update({prop.key: prop.value for prop in ret_model.metadata_props})
+        onnx.helper.set_model_props(ret_model, ret_metaprops)
+        return ret_model
+
+
 class InputOutputNameMap:
     def __init__(
         self,