diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index d7cb2d5ea0d0f..5a831a106ae08 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -13,8 +13,8 @@ # Header paths find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) - if(OpenVINO_VERSION VERSION_LESS 2024.5) - message(FATAL_ERROR "OpenVINO 2024.5 and newer are supported. Please, use latest OpenVINO release") + if(OpenVINO_VERSION VERSION_LESS 2025.0) + message(FATAL_ERROR "OpenVINO 2025.0 and newer are supported. Please, use latest OpenVINO release") endif() if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) @@ -49,7 +49,7 @@ endif() add_dependencies(onnxruntime_providers_openvino onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_include_directories(onnxruntime_providers_openvino SYSTEM PUBLIC ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${OpenVINO_INCLUDE_DIR} ${OPENVINO_INCLUDE_DIR_LIST} ${PYTHON_INCLUDE_DIRS} $ENV{OPENCL_INCS} $ENV{OPENCL_INCS}/../../cl_headers/) - target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS} Eigen3::Eigen) + target_link_libraries(onnxruntime_providers_openvino ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 ${OPENVINO_LIB_LIST} ${ABSEIL_LIBS} Eigen3::Eigen onnx_proto) target_compile_definitions(onnxruntime_providers_openvino PRIVATE FILE_NAME=\"onnxruntime_providers_openvino.dll\") diff --git a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc index 1841dfa2791e0..7f214e656e0ab 100644 --- a/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc +++ b/onnxruntime/core/optimizer/double_qdq_pairs_remover.cc @@ -52,6 +52,7 @@ static void ApplyNewInputValue(Graph& graph, Node& node, QDQ::InputIndex index, input_init.ToProto(new_input_tensor); auto new_name = graph.GenerateNodeArgName("DoubleQDQRemoved_" + 
node.InputDefs()[index]->Name()); new_input_tensor.set_name(new_name); + new_input_tensor.add_dims(1); NodeArg& new_input = graph_utils::AddInitializerWithExternalData(graph, new_input_tensor); graph_utils::ReplaceNodeInput(node, index, new_input); } diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 041d9c07e41fe..28804d2f76492 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -20,6 +20,7 @@ #include "core/providers/openvino/ov_interface.h" #include "core/providers/openvino/ov_versions/capability.h" #include "core/providers/openvino/qdq_transformations/qdq_stripping.h" +#include "core/providers/openvino/qdq_transformations/qdq_scales_fix.h" namespace onnxruntime { namespace openvino_ep { @@ -117,7 +118,9 @@ BackendManager::BackendManager(SessionContext& session_context, LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && (session_context_.device_type.find("CPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos)) || + session_context_.device_type.find("GPU") != std::string::npos || + (session_context_.device_type.find("NPU") != std::string::npos && + session_context_.enable_causallm))) || (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. 
" << "Creating backend Dynamic Shapes"; @@ -429,8 +432,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU - if ((session_context_.device_type.find("NPU") != std::string::npos || - session_context_.device_type.find("GPU") != std::string::npos) && + if ((session_context_.device_type.find("NPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); @@ -440,6 +442,17 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; + } else if ((session_context_.device_type.find("GPU") != std::string::npos) && + enable_ovep_qdq_optimizer) { + // Create a copy of the model + std::unique_ptr model; + Status status = qdq_scales_fix::Transform(subgraph, logger, model); + auto model_proto = model->ToProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + print_model_proto_duration(); + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index df75f84a5fee0..61235ef2138b5 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -36,10 +36,6 @@ 
BasicBackend::BasicBackend(std::unique_ptr& model_pr if (ValidateSubgraph(const_outputs_map_)) return; - // Pre-requisite is provider_option "context" must be set - auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) || - (session_context_.OpenVINO_Version.at(0) >= 2024 && - session_context_.OpenVINO_Version.at(1) > 2)); ov::AnyMap device_config; SetOVDeviceConfiguration(device_config); if (subgraph_context_.is_ep_ctx_graph) { @@ -81,42 +77,46 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr ORT_THROW(msg); } // Delete stream after it is no longer needed } else { + std::shared_ptr ov_model; std::string model = model_proto->SerializeAsString(); if (!subgraph_context.has_dynamic_input_shape) { model_proto.reset(); } + bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && + !session_context_.so_disable_cpu_ep_fallback && + !subgraph_context_.is_ep_ctx_graph; +#if defined(OPENVINO_DISABLE_NPU_FALLBACK) + eligible_for_cpu_fallback = false; +#endif + auto auto_unified_compile = (hw_target.find("AUTO") == std::string::npos); + + // Unified compile is efficient with cahce_dir cached model loading that bypass Read Model + // Does not support model with exteral weights, dynamic input shape, Epctx onnx cached model, + // reshape, enable_causallm, and for NPU CPU fallback + + auto is_unified_compile = (!session_context_.has_external_weights && + !subgraph_context_.has_dynamic_input_shape && + !session_context_.so_context_enable && + session_context_.reshape.empty() && + !enable_causallm && + !eligible_for_cpu_fallback && + auto_unified_compile); try { - // SetOVDeviceConfiguration(device_config); - if (!session_context_.has_external_weights && - !subgraph_context_.has_dynamic_input_shape && - !session_context_.so_context_enable && - session_context_.reshape.empty() && - !enable_causallm && - auto_unified_compile) { - // Unified OV compile_model is efficient when ov model caching is enabled - // Unified 
OV compile_model API is supported with AUTO from version 2024.3 and above - // Inputs with static dimensions - // Not enabled for models with external weights and when ep context is set. - + if (is_unified_compile) { exe_network_ = OVCore::Get()->CompileModel(model, hw_target, device_config, subgraph_context_.subgraph_name); } else { // For all other types use ov::ov_core read_model() to generate OV IR // followed by ov::ov_core compile_model() - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); + ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); - bool eligible_for_cpu_fallback = session_context_.device_type.find("NPU") != std::string::npos && - !session_context_.so_disable_cpu_ep_fallback && - !subgraph_context_.is_ep_ctx_graph; -#if defined(OPENVINO_DISABLE_NPU_FALLBACK) - eligible_for_cpu_fallback = false; -#endif + if (eligible_for_cpu_fallback) { LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." 
<< "Falling back to OV CPU for execution"; @@ -125,8 +125,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr device_config.clear(); SetOVDeviceConfiguration(device_config); try { - // Recreate the model with CPU device type - auto ov_model = CreateOVModel(std::move(model), session_context_, const_outputs_map_); exe_network_ = OVCore::Get()->CompileModel( ov_model, hw_target, device_config, enable_causallm, subgraph_context_.subgraph_name); } catch (std::string const& msg) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index fb1757199698b..f6bc5ad599e18 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -133,6 +133,22 @@ class OVInferRequest { auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); SetTensor(name, tensor_ptr); cached_binding = {tensor_ptr, ort_ptr}; + } else if (ort_ptr == nullptr) { + // a null ort_ptr is expected for a tensor that has 0 elements. + // for example, a tensor of shape=[1, 8, 0, 64], which is valid. + // So, we check to see if at least one shape entry is 0. + auto contains_zero = [](const ov::Shape& shape) { + for (auto& s : shape) + if (s == 0) return true; + return false; + }; + if (contains_zero(shape)) { + // if there are zero elements (i.e. at least one shape entry is 0), + // then create and set the tensor anyway. 
+ auto tensor_ptr = std::make_shared(type, shape); + SetTensor(name, tensor_ptr); + cached_binding = {tensor_ptr, ort_ptr}; + } } } diff --git a/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp b/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp new file mode 100644 index 0000000000000..e28330e0bd433 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_protobuf_utils.cpp @@ -0,0 +1,24 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_protobuf_utils.h" + +#include "core/graph/onnx_protobuf.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace openvino_ep { +float get_float_initializer_data(const void* initializer) { + const auto* tp = reinterpret_cast(initializer); + ORT_ENFORCE((tp->has_data_type() && (tp->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT))); + // ORT_ENFORCE(initializer.dims_size() == 1); + return tp->float_data(0); +} +void set_float_initializer_data(const void* initializer, float data) { + auto* tp = (ONNX_NAMESPACE::TensorProto*)(initializer); + ORT_ENFORCE((tp->has_data_type() && (tp->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT))); + // ORT_ENFORCE(initializer.dims_size() == 1); + tp->set_float_data(0, data); +} +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_protobuf_utils.h b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h new file mode 100644 index 0000000000000..ba8f910cd9218 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h @@ -0,0 +1,10 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once +namespace onnxruntime { +namespace openvino_ep { +float get_float_initializer_data(const void* initializer); +void set_float_initializer_data(const void* initializer, float data); +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc 
b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 88ddde8610c6e..2309ff3de751b 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -41,16 +41,14 @@ GetCapability::GetCapability(const EPCtxHandler& ep_ctx_handler, npu_qdq_optimizer_enabled = true; // see data_ops.cc ~615 where we check for int16 types for gpu, this may change to a better approach later } -#if OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 5 - data_ops_ = new DataOps(graph_viewer_, V_2024_5, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2024 && OPENVINO_VERSION_MINOR == 6 - data_ops_ = new DataOps(graph_viewer_, V_2024_6, device_type_, npu_qdq_optimizer_enabled); -#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 +#if OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 0 data_ops_ = new DataOps(graph_viewer_, V_2025_0, device_type_, npu_qdq_optimizer_enabled); #elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 1 data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); +#elif OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR == 2 + data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #else - data_ops_ = new DataOps(graph_viewer_, V_2025_1, device_type_, npu_qdq_optimizer_enabled); + data_ops_ = new DataOps(graph_viewer_, V_2025_2, device_type_, npu_qdq_optimizer_enabled); #endif } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 84001c1161efc..336b294117cba 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -229,6 +229,7 @@ std::vector supported_op_mode = { {"Sigmoid", V_2020_4, {"CPU", "GPU"}}, {"Sign", V_2020_4, {"CPU"}}, {"Sign", V_2022_1, 
{"GPU"}}, + {"SimplifiedLayerNormalization", V_2025_2, {"CPU", "GPU"}}, {"Sin", V_2022_1, {"CPU", "GPU"}}, {"Sinh", V_2020_4, {"CPU"}}, {"Size", V_2022_1, {"CPU", "GPU"}}, @@ -402,7 +403,7 @@ void DataOps::populate_op_mode_supported() { // populate unsupportedmode_t { - UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + UnsupportedOpMode obj = {{V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // If the Input of ReduceMax op is UINT8, it is rejected (Due to output mismatch) for (size_t i = 0; i < node->InputDefs().size(); i++) { @@ -418,7 +419,8 @@ void DataOps::populate_op_mode_supported() { } { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, - V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, + V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { const auto& input_args = node->InputDefs(); const auto& input_arg = (input_args.size() > 1) ? input_args[1] : input_args[0]; @@ -437,7 +439,8 @@ void DataOps::populate_op_mode_supported() { } { UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, - V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1}, + V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, + V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // If the operator is unsqueeze // If axes is an input, then we cannot produce a static graph. 
@@ -452,8 +455,8 @@ void DataOps::populate_op_mode_supported() { op_list_.insert({"Unsqueeze", obj}); } { - UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, V_2024_6, - V_2025_0, V_2025_1}, + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, V_2024_3, V_2024_4, V_2024_5, + V_2024_6, V_2025_0, V_2025_1, V_2025_2}, [this](const Node* node, const InitializedTensorSet&) { // check for attributes auto& upsample_attr = node->GetAttributes(); diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h index cf7d834d6cfc7..95905e010541e 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.h +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.h @@ -35,7 +35,8 @@ enum versionNum { V_2024_5, V_2024_6, V_2025_0, - V_2025_1 + V_2025_1, + V_2025_2 }; using VersionNum = enum versionNum; diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp new file mode 100644 index 0000000000000..c1e4815c206a2 --- /dev/null +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -0,0 +1,944 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "qdq_scales_fix.h" +#include "core/providers/openvino/ov_protobuf_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace openvino_ep { + +namespace qdq_scales_fix { + +namespace fs = std::filesystem; +using NodeRef = std::reference_wrapper; +struct GraphNode; +float get_initializer_value(const Graph& graph, const std::string& initializer_name); + +template +bool contains(V&& begin, V&& end, const T& val) { + for (V& iter = begin; iter != end; iter.operator++()) { + if (iter->Name() == val) { + return true; + } + } + 
return false; +} + +template +bool contains(const R& vec, const T& val) { + for (auto iter = vec.begin(); iter != vec.end(); iter++) { + if ((*iter)->Name() == val) { + return true; + } + } + return false; +} + +bool contains(const std::vector& container, const std::string& value) { + return std::find(container.begin(), container.end(), value) != container.end(); +} + +struct GraphNode { + GraphNode() = delete; + + template + GraphNode(const N& node, const std::string& op_type = {}) { + node_name = node.Name(); + if constexpr (std::is_same_v) { + node_ptr = &node; + this->op_type = node.OpType(); + for (const auto iter : node.InputDefs()) { + node_input_name.push_back(iter->Name()); + } + for (const auto iter : node.OutputDefs()) { + node_output_name.push_back(iter->Name()); + } + } else { + this->op_type = op_type; + //** node_input_name = [] + //** node_output_name = [] + } + + if (op_type == "output") { + down_to_output = true; + } + } + + bool operator==(const GraphNode&) const = default; + + void add_edge_to(GraphNode& dst_node) { + to_node.push_back(&dst_node); + } + + void add_edge_from(GraphNode& src_node) { + from_node.push_back(&src_node); + } + + std::vector apply_scale_to_graph(float scale_adj) { + std::vector affected_dq; + + auto extend = [&affected_dq, scale_adj](const std::vector& new_nodes) { + affected_dq.insert(affected_dq.end(), new_nodes.begin(), new_nodes.end()); + }; + + if (op_type == "DequantizeLinear") { + scale_factor *= scale_adj; + affected_dq.push_back(this); + } else if ((op_type == "Add") || (op_type == "QuantizeLinear")) { + for (auto node : from_node) { + extend(node->apply_scale_to_graph(scale_adj)); + } + } else if (op_type == "Conv") { + // just adjust w&b for conv&mul, stop propagate + for (auto node : from_node) { + if (node->op_type == "DequantizeLinear") { + extend(node->apply_scale_to_graph(scale_adj)); + } + } + } else if ((op_type == "Mul") || (op_type == "MatMul")) { + bool find_dq{false}; + for (auto node : from_node) { 
+ if (node->op_type == "DequantizeLinear" && !find_dq) { + find_dq = true; + extend(node->apply_scale_to_graph(scale_adj)); + } + } + if (!find_dq) { + // cannot scale dq from here, choose input 0 to propagate + extend(from_node.back()->from_node[0]->apply_scale_to_graph(scale_adj)); + } + } else { + ORT_THROW("Unknown case, node: %s", ToString().data()); + } + + return affected_dq; + } + + std::vector down_propagate_scale() { + std::vector affected_nodes; + + if (processed) { + return affected_nodes; + } + + if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + // pass + } else if (op_type == "Add") { + auto up_new_nodes = up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), up_new_nodes.begin(), up_new_nodes.end()); + + for (auto node : to_node) { + auto down_new_nodes = node->down_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), down_new_nodes.begin(), down_new_nodes.end()); + } + } else { + affected_nodes.push_back(this); + processed = true; + + for (auto node : to_node) { + auto new_nodes = node->down_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } + return affected_nodes; + } + + std::vector up_propagate_scale() { + std::vector affected_nodes; + + if (processed) { + return affected_nodes; + } + + if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + ORT_THROW("Cannot propagate up through norm layer"); + } else if (op_type == "Conv") { + affected_nodes.push_back(this); + processed = true; + + for (auto node : from_node) { + if (node->op_type == "DequantizeLinear") { + affected_nodes.push_back(node); + } + } + } else if ((op_type == "Mul") || (op_type == "MatMul")) { + affected_nodes.push_back(this); + processed = true; + bool find_dq{false}; + + for (auto node : from_node) { + if ((node->op_type == "DequantizeLinear") && !find_dq) { + find_dq = true; + affected_nodes.push_back(node); + } + } + if (!find_dq) { + auto new_nodes = 
from_node.back()->from_node[0]->up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } else { + affected_nodes.push_back(this); + processed = true; + + for (auto node : from_node) { + auto new_nodes = node->up_propagate_scale(); + affected_nodes.insert(affected_nodes.end(), new_nodes.begin(), new_nodes.end()); + } + } + + return affected_nodes; + } + + bool down_propagate_to_output() { + if (down_to_output.has_value()) { + return down_to_output.value(); + } + + bool local_down_to_output{false}; + if (op_type == "output") { + local_down_to_output = true; + } else if ((op_type == "InstanceNormalization") || (op_type == "Softmax")) { + local_down_to_output = false; + } else { + for (auto node : to_node) { + local_down_to_output = local_down_to_output || node->down_propagate_to_output(); + } + } + + down_to_output = local_down_to_output; + return local_down_to_output; + } + + std::string ToString() const { + // auto string = std::format("op={} name={} queued={} visited={} scale_factor={}", + // op_type, + // node_name, + // queued, + // visited, + // scale_factor); + auto print_node_vector = [](const std::vector& nodes) -> std::string { + // auto comp = [](const GraphNode* left, const GraphNode* right) -> bool { + // return left->node_name < right->node_name; + // }; + // std::sort(nodes.begin(), nodes.end(), comp); + std::string ret = "["; + for (size_t i = 0, size = nodes.size(); auto pnode : nodes) { + if (pnode->node_name.size() == 0) continue; + ret += pnode->node_name; + if (++i < size) { + ret += ", "; + } + } + ret += "]"; + return ret; + }; + std::string from_node_str = print_node_vector(from_node); + std::string to_node_str = print_node_vector(to_node); + + auto print_string_vector = [](const std::vector& nodes) -> std::string { + // std::sort(nodes.begin(), nodes.end()); + std::string ret = "["; + for (size_t i = 0, size = nodes.size(); const auto& node : nodes) { + ret += node; + if (++i < size) { + 
ret += ", "; + } + } + ret += "]"; + return ret; + }; + std::string node_input_name_str = print_string_vector(node_input_name); + std::string node_output_name_str = print_string_vector(node_output_name); + + auto print_bool = [](bool val) -> std::string { + return (val) ? "True" : "False"; + }; + + auto print_opt_bool = [print_bool](std::optional val) -> std::string { + return (val.has_value()) ? print_bool(val.value()) : "None"; + }; + + auto string = std::format("node_name={} op_type={} scale_factor={:.2f} visited={} queued={} down_to_output={} processed={} from_node={} to_node={} node_input_name={} node_output_name={}", + node_name, + op_type, + scale_factor, + visited, + print_bool(queued), + print_opt_bool(down_to_output), + print_bool(processed), + from_node_str, + to_node_str, + node_input_name_str, + node_output_name_str); + return string; + } + + const Node* node_ptr{nullptr}; + std::string node_name; + std::string op_type; + std::vector node_input_name; + std::vector node_output_name; + std::vector from_node; + std::vector to_node; + float scale_factor{1.f}; + int visited{0}; + bool queued{false}; + std::optional down_to_output; + bool processed{false}; +}; + +struct CustomGraph { + CustomGraph() = delete; + CustomGraph(Graph& graph) : original_graph{graph} {} + + void sort() { + auto comp_node = [](const GraphNode& left, const GraphNode& right) -> bool { + return left.node_name < right.node_name; + }; + nodes.sort(comp_node); + + for (auto& node : nodes) { + auto comp_pnode = [](const GraphNode* left, const GraphNode* right) -> bool { + return left->node_name < right->node_name; + }; + std::sort(node.from_node.begin(), node.from_node.end(), comp_pnode); + std::sort(node.to_node.begin(), node.to_node.end(), comp_pnode); + } + } + + void add_node(const GraphNode& node) { + nodes.push_back(node); + } + + void add_edge(GraphNode& src, GraphNode& dst) { + src.add_edge_to(dst); + dst.add_edge_from(src); + } + + auto get_start_nodes() { + std::list start_nodes; 
+ + for (auto& node : nodes) { + if (node.from_node.empty()) { + start_nodes.push_back(&node); + node.queued = true; + } + } + return start_nodes; + } + + void initailize_search(float threshold = 1.f, bool scale_output = false) { + remove_qdq(threshold, scale_output); + for (auto& node : nodes) { + node.visited = 0; + node.queued = false; + } + } + + void init_propagate() { + for (auto& node : nodes) { + node.processed = false; + } + } + + void remove_qdq_pair(const GraphNode& node, std::list& removed) { + auto& q = node; + InlinedVector dq_ptrs; + + for (auto& child : q.to_node) { + if (child->node_ptr && child->node_ptr->OpType() == "DequantizeLinear") { + dq_ptrs.push_back(child); + } + } + + if (dq_ptrs.empty()) { + return; + } + + for (std::size_t i = 1; i < dq_ptrs.size(); ++i) { + if (dq_ptrs[i]->node_input_name[1] != dq_ptrs[0]->node_input_name[1] || + dq_ptrs[i]->node_input_name[2] != dq_ptrs[0]->node_input_name[2]) { + return; + } + } + + auto& prev = *node.from_node[0]; + const auto& q_node = *q.node_ptr; + + bool is_prev_input = (prev.node_ptr == nullptr); + std::string prev_output_name = is_prev_input ? 
prev.node_name : prev.node_output_name[0]; + + InlinedVector> output_replacements; + for (auto dq_ptr : dq_ptrs) { + for (auto dst_node : dq_ptr->to_node) { + for (auto& scr_node : dst_node->from_node) { + if (*dq_ptr == *scr_node) { + scr_node = &prev; + } + } + + auto it = std::find(dst_node->node_input_name.begin(), dst_node->node_input_name.end(), dq_ptr->node_output_name[0]); + if (it != dst_node->node_input_name.end()) { + *it = prev_output_name; + } + } + for (auto& output : original_graph.GetOutputs()) { + if (output->Name() == dq_ptr->node_output_name[0]) { + const NodeArg* replacement_arg = nullptr; + if (!is_prev_input) { + replacement_arg = prev.node_ptr->OutputDefs()[0]; + } else { + replacement_arg = original_graph.GetNodeArg(prev.node_name); + ORT_ENFORCE(replacement_arg != nullptr, "Input not found: " + prev.node_name); + } + output_replacements.emplace_back(output, replacement_arg); + } + } + } + + prev.to_node.erase(std::remove(prev.to_node.begin(), prev.to_node.end(), &q), prev.to_node.end()); + for (auto dq_ptr : dq_ptrs) { + for (auto dst_node : dq_ptr->to_node) { + auto it = std::find(prev.to_node.begin(), prev.to_node.end(), dst_node); + if (it == prev.to_node.end()) { + prev.to_node.push_back(dst_node); + } + } + } + auto q_iter = std::find(nodes.begin(), nodes.end(), q); + if (q_iter != nodes.end()) { + removed.splice(removed.end(), nodes, q_iter); + } + + for (auto dq_ptr : dq_ptrs) { + auto dq_iter = std::find(nodes.begin(), nodes.end(), *dq_ptr); + if (dq_iter != nodes.end()) { + removed.splice(removed.end(), nodes, dq_iter); + } + } + + auto remove_edge = [this](const Node& src, const Node& dst, int src_arg, int dst_arg) { + original_graph.RemoveEdge(src.Index(), dst.Index(), src_arg, dst_arg); + }; + + auto in_edge = q_node.InputEdgesBegin(); + ORT_ENFORCE(in_edge != q_node.InputEdgesEnd(), "Q node must have an input edge"); + const int prev_output_index = in_edge->GetSrcArgIndex(); + + if (in_edge != q_node.InputEdgesEnd()) { + 
remove_edge(in_edge->GetNode(), q_node, + in_edge->GetSrcArgIndex(), in_edge->GetDstArgIndex()); + } + for (auto dq_ptr : dq_ptrs) { + auto& dq_node_ref = *dq_ptr->node_ptr; + + for (auto edge_it = dq_node_ref.InputEdgesBegin(); edge_it != dq_node_ref.InputEdgesEnd(); ++edge_it) { + if (edge_it->GetNode().Index() == q_node.Index()) { + remove_edge(edge_it->GetNode(), dq_node_ref, edge_it->GetSrcArgIndex(), edge_it->GetDstArgIndex()); + break; + } + } + + std::vector> output_edges; // (dst_node_index, src_arg, dst_arg) + for (auto out_edge_it = dq_node_ref.OutputEdgesBegin(); out_edge_it != dq_node_ref.OutputEdgesEnd(); ++out_edge_it) { + output_edges.emplace_back(out_edge_it->GetNode().Index(), + out_edge_it->GetSrcArgIndex(), + out_edge_it->GetDstArgIndex()); + } + + for (const auto& edge : output_edges) { + original_graph.RemoveEdge(dq_node_ref.Index(), std::get<0>(edge), + std::get<1>(edge), std::get<2>(edge)); + } + + if (!is_prev_input) { + for (const auto& edge : output_edges) { + original_graph.AddEdge(prev.node_ptr->Index(), + std::get<0>(edge), + prev_output_index, + std::get<2>(edge)); + } + } + } + + if (!output_replacements.empty()) { + auto outputs = original_graph.GetOutputs(); + for (auto& output : outputs) { + for (const auto& replacement : output_replacements) { + if (output == replacement.first) { + output = replacement.second; + break; + } + } + } + original_graph.SetOutputs(outputs); + } + + original_graph.RemoveNode(q_node.Index()); + for (auto dq_ptr : dq_ptrs) { + original_graph.RemoveNode(dq_ptr->node_ptr->Index()); + } + } + + std::list remove_qdq(float threshold = 1.f, bool scale_output = false) { + std::list removed; + std::vector nodes_copy; + std::for_each(nodes.begin(), nodes.end(), [&nodes_copy](GraphNode& node) { nodes_copy.push_back(&node); }); + for (auto node : nodes_copy) { + if (std::find(nodes.begin(), nodes.end(), *node) == nodes.end()) { + continue; + } + + if ((node->op_type == "QuantizeLinear") && + 
(node->to_node[0]->op_type == "DequantizeLinear")) { + const auto& zero_point_name = node->node_input_name[2]; + const auto p_initializer = original_graph.GetConstantInitializer(zero_point_name, false); + bool is_16_bit = p_initializer->has_data_type() && + (p_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16 || + p_initializer->data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16); + if (!is_16_bit) + continue; + if (!scale_output && node->down_propagate_to_output()) { + remove_qdq_pair(*node, removed); + continue; + } + + auto scale_name = node->node_input_name[1]; // Scale + auto scale_value = get_initializer_value(original_graph, scale_name); + if (scale_value / node->scale_factor < threshold) { + remove_qdq_pair(*node, removed); + } + } + } + + // Reconnect graph outputs if disconnected + bool update_outputs{false}; + auto outputs = original_graph.GetOutputs(); + for (auto output : outputs) { + bool found{false}; + for (auto node : original_graph.Nodes()) { + if (contains(node->OutputNodesBegin(), node->OutputNodesEnd(), output->Name())) { + found = true; + break; + } + } + + if (!found) { + // Connect the last valid node to the graph output + for (auto node : std::ranges::reverse_view(original_graph.Nodes())) { + if (!node->OutputDefs().empty()) { + const auto& name = (*node->OutputDefs().begin())->Name(); + auto& node_arg = original_graph.GetOrCreateNodeArg(name, output->TypeAsProto()); + output = &node_arg; + update_outputs = true; + } + } + } + } + + if (update_outputs) { + original_graph.SetOutputs(outputs); + } + + return removed; + } + + void dump_custom_graph(fs::path path) { + if (auto file = std::ofstream(path)) { + std::vector node_ref; + for (auto& node : nodes) { + node_ref.emplace_back(&node); + } + + for (const auto& node : node_ref) { + std::string node_str = node->ToString(); + file << node_str << "\n"; + } + } + } + + std::list nodes; + std::list removed_nodes; + Graph& original_graph; +}; + +template +T* 
get_mutable_initializer_data(const Graph& graph, const std::string& name) { + auto initializer = graph.GetConstantInitializer(name, true); + if (!initializer) return nullptr; + + if constexpr (std::is_same_v) { + if (initializer->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) + return nullptr; + } + + return reinterpret_cast(const_cast(initializer->raw_data().data())); +} + +std::size_t get_initializer_size(const Graph& graph, const std::string& name) { + auto initializer = graph.GetConstantInitializer(name, true); + if (!initializer) return 0; + + std::size_t size = 1; + if (!initializer->dims_size()) + return 0; + for (int i = 0; i < initializer->dims_size(); ++i) { + size *= initializer->dims()[i]; + } + return size; +} + +float get_initializer_value(const Graph& graph, const std::string& initializer_name) { + const auto p_initializer = graph.GetConstantInitializer(initializer_name, false); + + if (p_initializer->has_raw_data()) { + auto raw_data = get_mutable_initializer_data(graph, initializer_name); + auto size = get_initializer_size(graph, initializer_name); + ORT_ENFORCE(size == 1, "Expected an initializer to be of size 1"); + return raw_data[0]; + } else + return get_float_initializer_data(p_initializer); +} + +void update_initializer_value(Graph& graph, const std::string& initializer_name, const float new_value) { + const auto p_initializer = graph.GetConstantInitializer(initializer_name, false); + + if (p_initializer == nullptr) { + return; + } + + const auto& initializer = *p_initializer; + + // Verify 1D tensor + ORT_ENFORCE(initializer.dims_size() == 1); + ORT_ENFORCE(initializer.data_type() == onnx::TensorProto_DataType_FLOAT); + + // Create new tensor with updated value + auto new_tensor = onnx::TensorProto::Create(); + new_tensor->copy_from(p_initializer); + *(float*)new_tensor->mutable_raw_data()->data() = new_value; + graph.RemoveInitializedTensor(initializer_name); + graph.AddInitializedTensor(*new_tensor); +} + +CustomGraph 
generate_graph_from_onnx(Graph& graph) {
  // Builds the CustomGraph mirror: one GraphNode per real node, plus pseudo
  // nodes for graph inputs/outputs, with edges discovered by name matching.
  CustomGraph gen_graph{graph};

  for (auto pnode : graph.Nodes()) {
    if (pnode->NodeType() == Node::Type::Fused) continue;
    gen_graph.nodes.emplace_back(*pnode);
  }

  // O(n^2) edge discovery: connect producer -> consumer wherever an output
  // name equals an input name.
  for (auto& src_node : gen_graph.nodes) {
    for (auto& dst_node : gen_graph.nodes) {
      if (src_node == dst_node) {
        continue;
      }

      for (auto& src_output : src_node.node_output_name) {
        if (contains(dst_node.node_input_name, src_output)) {
          gen_graph.add_edge(src_node, dst_node);
        }
      }
    }
  }

  // Pseudo "input" nodes. NOTE(review): emplacing into gen_graph.nodes while
  // ranging over it is safe only because it is a std::list (stable iterators);
  // the inner loop also visits the node just added — assumed harmless because
  // an input pseudo-node consumes nothing. Confirm GraphNode's ctor semantics.
  for (auto& input_node : graph.GetInputs()) {
    auto& cur_input = gen_graph.nodes.emplace_back(*input_node, "input");
    for (auto& dst_node : gen_graph.nodes) {
      for (const auto& dst_output : dst_node.node_input_name) {
        if (dst_output == input_node->Name()) {
          gen_graph.add_edge(cur_input, dst_node);
        }
      }
    }
  }

  // Pseudo "output" nodes, wired from whichever node produces the output name.
  for (auto& output_node : graph.GetOutputs()) {
    auto& cur_output = gen_graph.nodes.emplace_back(*output_node, "output");
    for (auto& src_node : gen_graph.nodes) {
      for (const auto& dst_outputs : src_node.node_output_name) {
        if (dst_outputs == output_node->Name()) {
          gen_graph.add_edge(src_node, cur_output);
        }
      }
    }
  }

  gen_graph.sort();
  return gen_graph;
}

// Topological sweep that rescales out-of-range Q->DQ pairs: when a pair's
// effective scale exceeds `threshold`, the upstream constant DQs are scaled
// down by `scale_adj` and the factor is propagated through the affected
// subgraph; pairs made redundant are removed. Returns true when at least one
// Q->DQ pair was seen, i.e. a second (coarser) pass is worthwhile.
bool scale_graph(CustomGraph& gen_graph,
                 float threshold = 1.f,
                 float ratio = 10,
                 bool scale_output = false) {
  bool needs_second_run = false;
  gen_graph.initailize_search(threshold, scale_output);  // [sic] project API spelling
  auto q = gen_graph.get_start_nodes();
  // Deterministic visit order: sort the worklist by node name.
  auto pred = [](const GraphNode* left, const GraphNode* right) -> bool {
    return left->node_name < right->node_name;
  };
  q.sort(pred);

  while (!q.empty()) {
    auto cur_node = q.front();
    q.pop_front();
    // Requeue later until every predecessor has been visited (Kahn-style gate).
    // NOTE(review): cast target was lost in transit; reconstructed as size_t
    // to match from_node.size() — confirm against the original source.
    if (static_cast<std::size_t>(cur_node->visited) < cur_node->from_node.size()) {
      cur_node->queued = false;
    } else {
      if (cur_node->op_type == "QuantizeLinear" &&
          cur_node->to_node[0]->op_type == "DequantizeLinear") {
        needs_second_run = true;
        auto scale_name = *std::next(cur_node->node_input_name.begin());  // input[1] = scale
        auto scale_value = get_initializer_value(gen_graph.original_graph, scale_name);

        // QDQ pair with scale over 1
        if (scale_value / cur_node->scale_factor > threshold) {
          gen_graph.init_propagate();
          // adjust previous op scale to threshold / 10
          auto scale_adj = scale_value / cur_node->scale_factor / threshold * ratio;

          // find related const dq to scale down
          auto affected_dq = cur_node->apply_scale_to_graph(scale_adj);
          std::vector<GraphNode*> affected_nodes;

          // then propage to graph to update scale
          for (auto& dq : affected_dq) {
            auto cur_affected = dq->down_propagate_scale();
            affected_nodes.insert(affected_nodes.end(), cur_affected.begin(), cur_affected.end());
          }

          // Nodes reached by propagation (but not the seed DQs themselves)
          // accumulate the adjustment factor.
          for (auto& node : affected_nodes) {
            bool found = std::find(affected_dq.begin(), affected_dq.end(), node) != affected_dq.end();
            if (!found) {
              node->scale_factor *= scale_adj;
            }
          }

          // Purge pairs made redundant, and drop their (now dangling) pointers
          // from the worklist. NOTE(review): std::list::remove does not throw —
          // the try/catch appears to be dead weight; kept as-is.
          auto removed_qdq = gen_graph.remove_qdq(threshold, scale_output);
          for (auto& qdq : removed_qdq) {
            try {
              q.remove(&qdq);
            } catch (...) {
            }
          }

          // Keep removed nodes alive for the rest of the sweep.
          gen_graph.removed_nodes.splice(gen_graph.removed_nodes.end(), removed_qdq);

          // Continue the walk from the paired DQ.
          cur_node = cur_node->to_node[0];
        }
      }

      for (auto dst : cur_node->to_node) {
        dst->visited += 1;
        if (!dst->queued) {
          dst->queued = true;
          q.push_back(dst);
        }
      }
    }
  }

  // Final pass: fold each DQ's accumulated scale_factor into its scale
  // initializer, then reset the factor.
  for (auto& node : gen_graph.nodes) {
    if (node.op_type == "DequantizeLinear" && node.scale_factor != 1.0f) {
      const auto& scale_name = node.node_input_name[1];

      // NOTE(review): explicit <float> restored (lost in transit) — required
      // for template-argument deduction; confirm against the original source.
      auto scale_data = get_mutable_initializer_data<float>(gen_graph.original_graph, scale_name);
      if (scale_data) {
        const auto scale_size = get_initializer_size(gen_graph.original_graph, scale_name);
        if (!scale_size) {
          // Rank-0 (scalar) initializer: size helper returns 0, so go through
          // the float accessor helpers instead of the raw pointer.
          auto it = gen_graph.original_graph.GetConstantInitializer(scale_name, true);
          auto cur_scale = get_float_initializer_data(it);
          cur_scale /= node.scale_factor;
          set_float_initializer_data(it, cur_scale);
        } else {
          for (std::size_t i = 0; i < scale_size; ++i) {
            scale_data[i] /= node.scale_factor;
          }
        }
      }

      node.scale_factor = 1.0f;
    }
  }
  return needs_second_run;
}

// Deep-copies the source graph into a fresh Model, inserting an Identity node
// behind every graph input and in front of every graph output so later passes
// can rewrite interior args without disturbing the graph's I/O signature.
Status copy_model(const GraphViewer& src_graph_viewer,
                  const logging::Logger& logger, std::unique_ptr<Model>& model) {
  model = src_graph_viewer.CreateModel(logger);
  const auto& src_graph = src_graph_viewer.GetGraph();
  auto& dst_graph = model->MainGraph();

  const auto& inputs = src_graph.GetInputs();
  const auto& outputs = src_graph.GetOutputs();

  // graph input -> the Identity output every consumer is rewired to
  struct InputReplacement {
    NodeArg* graph_input;
    NodeArg* identity_output;
  };
  std::unordered_map<std::string, InputReplacement> input_replacement_map;

  // intermediate arg (producer side) -> original graph output (Identity side)
  struct OutputReplacement {
    NodeArg* intermediate_arg;
    NodeArg* original_output;
  };
  std::unordered_map<std::string, OutputReplacement> output_replacement_map;

  InlinedVector<const NodeArg*> dst_graph_inputs;
  dst_graph_inputs.reserve(inputs.size());
  for (auto& input : inputs) {
    const auto& input_name = input->Name();
    auto input_arg = src_graph.GetNodeArg(input_name);

    auto& dst_input_arg = dst_graph.GetOrCreateNodeArg(input_name, input_arg->TypeAsProto());
dst_graph_inputs.push_back(&dst_input_arg); + + auto output_name = input_name + "_identity_output"; + auto& identity_output_arg = dst_graph.GetOrCreateNodeArg(output_name, input_arg->TypeAsProto()); + + input_replacement_map[input_name] = {&dst_input_arg, &identity_output_arg}; + } + + InlinedVector dst_graph_outputs; + for (auto& output : outputs) { + const auto& output_name = output->Name(); + auto output_arg = src_graph.GetNodeArg(output_name); + + std::string intermediate_name = "tmp_" + output_name; + auto& intermediate_out = dst_graph.GetOrCreateNodeArg(intermediate_name, output_arg->TypeAsProto()); + + auto& original_out = dst_graph.GetOrCreateNodeArg(output_name, output_arg->TypeAsProto()); + + output_replacement_map[output_name] = {&intermediate_out, &original_out}; + dst_graph_outputs.push_back(&original_out); + } + + dst_graph.SetInputs(dst_graph_inputs); + dst_graph.SetOutputs(dst_graph_outputs); + dst_graph.SetName(src_graph.Name()); + + for (const auto& name : src_graph_viewer.GetOuterScopeNodeArgNames()) { + auto node_arg = src_graph.GetNodeArg(name); + ORT_RETURN_IF_NOT(node_arg != nullptr, "Outer scope node arg name '" + name + "'was added but does not exist. 
"); + dst_graph.AddOuterScopeNodeArg(name); + } + + for (auto& input : inputs) { + const auto& input_name = input->Name(); + auto it = input_replacement_map.find(input_name); + ORT_RETURN_IF_NOT(it != input_replacement_map.end(), "Missing replacement for input: " + input_name); + + InputReplacement& repl = it->second; + InlinedVector input_args = {repl.graph_input}; + InlinedVector output_args = {repl.identity_output}; + + std::string node_name = "IdentityInsertion_" + input_name; + dst_graph.AddNode(node_name, "Identity", "Inserted identity node", + input_args, output_args, + nullptr, ""); + } + + for (auto pnode : src_graph.Nodes()) { + if (pnode->NodeType() == Node::Type::Fused) continue; + + InlinedVector new_input_args; + for (auto input_arg : pnode->InputDefs()) { + if (!input_arg) { + new_input_args.push_back(nullptr); + continue; + } + + auto it = input_replacement_map.find(input_arg->Name()); + if (it != input_replacement_map.end()) { + new_input_args.push_back(it->second.identity_output); + } else { + auto& new_arg = dst_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + new_input_args.push_back(&new_arg); + } + } + InlinedVector new_output_args; + for (auto output_arg : pnode->OutputDefs()) { + if (output_arg == nullptr) { + new_output_args.push_back(nullptr); + continue; + } + + auto it_output = output_replacement_map.find(output_arg->Name()); + if (it_output != output_replacement_map.end()) { + new_output_args.push_back(it_output->second.intermediate_arg); + } else { + auto& new_arg = dst_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + new_output_args.push_back(&new_arg); + } + } + + dst_graph.AddNode(pnode->Name(), pnode->OpType(), pnode->Description(), + new_input_args, new_output_args, + &pnode->GetAttributes(), pnode->Domain()); + } + + for (auto& output : outputs) { + const std::string& output_name = output->Name(); + auto it = output_replacement_map.find(output_name); + if (it == 
output_replacement_map.end()) continue; + + OutputReplacement& repl = it->second; + InlinedVector input_args = {repl.intermediate_arg}; + InlinedVector output_args = {repl.original_output}; + + std::string node_name = "IdentityInsertion_" + output_name; + dst_graph.AddNode(node_name, "Identity", "Inserted identitynode", + input_args, output_args, nullptr, ""); + } + + for (auto& [name, tensor_proto] : src_graph.GetAllInitializedTensors()) { + dst_graph.AddInitializedTensor(*tensor_proto); + } + + for (auto node_arg : src_graph.GetInputsIncludingInitializers()) { + auto check_inputs = [node_arg](auto input_node_arg) { + return input_node_arg->Name() == node_arg->Name(); + }; + if (std::find_if(dst_graph_inputs.begin(), dst_graph_inputs.end(), check_inputs) != dst_graph_inputs.end()) + continue; + + auto src_tensor_proto = src_graph.GetConstantInitializer(node_arg->Name(), true); + if (src_tensor_proto) { + auto dst_tensor_proto = onnx::TensorProto::Create(); + dst_tensor_proto->copy_from(src_tensor_proto); + dst_graph.AddInitializedTensor(*dst_tensor_proto); + } + } + + ORT_RETURN_IF_ERROR(dst_graph.Resolve()); + return Status::OK(); +} + +Status Transform(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model) { + auto status = copy_model(src_graph_viewer, logger, model); + auto g = generate_graph_from_onnx(model->MainGraph()); + + float threshold{1.f}; + float ratio{10.f}; + bool scale_output{false}; + auto needs_second_run = scale_graph(g, threshold, ratio, scale_output); + if (needs_second_run) + scale_graph(g, threshold * 100, ratio, scale_output); + return status; +} +} // namespace qdq_scales_fix +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h new file mode 100644 index 0000000000000..c54c531e1bd40 --- /dev/null +++ 
b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -0,0 +1,19 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +class GraphViewer; + +namespace openvino_ep { + +namespace qdq_scales_fix { +Status Transform(const GraphViewer& src_graph, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model); +} +} // namespace openvino_ep +} // namespace onnxruntime