Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from __future__ import annotations

import logging
import tempfile
from pathlib import Path

import onnx

from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed
from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model
from ....tools.remove_initializer_from_input import remove_initializer_from_input
from ...fusions import FusionGelu, FusionLayerNormalization
from ...onnx_model import ONNXModel
from ...quant_utils import save_and_reload_model_with_shape_infer
from .fusion_lpnorm import FusionLpNormalization
from .fusion_spacetodepth import FusionSpaceToDepth

Expand Down Expand Up @@ -93,7 +93,7 @@ def qnn_preprocess_model(
"""
modified = False
model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
model = save_and_reload_model_with_shape_infer(model)
model = save_and_reload_optimize_model(model, shape_infer=True)
onnx_model = ONNXModel(model)

# Optionally, fix the dynamic input shapes.
Expand Down Expand Up @@ -178,6 +178,24 @@ def qnn_preprocess_model(
return modified


def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto:
    """Round-trip a model through a temporary directory, optimizing it on the way.

    The model is written to disk with external data, optionally passed through
    ONNX shape inference, handed to ``optimize_model``, and the result is loaded
    back into memory. The reloaded model is tagged with an ``onnx.infer``
    metadata property; any metadata properties already present on the reloaded
    model are kept and take precedence over that default.

    Args:
        model: The in-memory model to process. It is not modified.
        shape_infer: If True, run ``onnx.shape_inference.infer_shapes_path`` on
            the saved model and feed the inferred model to ``optimize_model``.

    Returns:
        A new ``onnx.ModelProto`` loaded from the optimized output file.
    """
    with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir:
        tmp_dir = Path(qnn_preproc_tmp_dir)

        # Saving with save_as_external_data=True rewrites the proto it is given
        # (tensors are marked as external and their raw data is cleared), which
        # would leave the caller's model pointing at files inside a directory
        # that is deleted when this function returns. Save a copy instead.
        model_copy = onnx.ModelProto()
        model_copy.CopyFrom(model)
        model_in_path = tmp_dir.joinpath("qnn_proc_input.onnx")
        onnx.save_model(model_copy, model_in_path, save_as_external_data=True)
        del model_copy  # Release the duplicate before loading further models.

        if shape_infer:
            model_infer_path = tmp_dir.joinpath("qnn_proc_infer.onnx")
            onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path))
            # From here on, the shape-inferred model is the optimizer's input.
            model_in_path = model_infer_path

        model_out_path = tmp_dir.joinpath("qnn_proc_output.onnx")
        optimize_model(model_in_path, model_out_path)

        # Must be loaded before the temporary directory (and the external data
        # files inside it) is removed on exit from the ``with`` block.
        ret_model = onnx.load_model(model_out_path)

        # metadata_props is a repeated field of key/value entries, not a
        # mapping, so it has to be converted explicitly before merging.
        ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"}
        ret_metaprops.update({prop.key: prop.value for prop in ret_model.metadata_props})
        onnx.helper.set_model_props(ret_model, ret_metaprops)
        return ret_model


class InputOutputNameMap:
def __init__(
self,
Expand Down
Loading