diff --git a/.ci/azure/linux.yml b/.ci/azure/linux.yml index 7b1ee18d792d74..73bac3bb255f83 100644 --- a/.ci/azure/linux.yml +++ b/.ci/azure/linux.yml @@ -109,6 +109,7 @@ jobs: -DENABLE_WHEEL=ON -DENABLE_TESTS=ON -DNGRAPH_ONNX_IMPORT_ENABLE=ON + -DNGRAPH_ONNX_FRONTEND_ENABLE=ON -DENABLE_FASTER_BUILD=ON -DENABLE_STRICT_DEPENDENCIES=OFF -DIE_EXTRA_MODULES=$(OPENVINO_CONTRIB_REPO_DIR)/modules diff --git a/.ci/azure/linux_onnxruntime.yml b/.ci/azure/linux_onnxruntime.yml index fce8fdddcc4f91..a2bfee8c70ac3e 100644 --- a/.ci/azure/linux_onnxruntime.yml +++ b/.ci/azure/linux_onnxruntime.yml @@ -95,6 +95,7 @@ jobs: -DENABLE_SAMPLES=OFF -DENABLE_SPEECH_DEMO=OFF -DNGRAPH_ONNX_IMPORT_ENABLE=ON + -DNGRAPH_ONNX_FRONTEND_ENABLE=ON -DNGRAPH_DEBUG_ENABLE=OFF $(REPO_DIR) workingDirectory: $(BUILD_DIR) diff --git a/.ci/openvino-onnx/Dockerfile b/.ci/openvino-onnx/Dockerfile index 9b0f48cf66cc3e..315598225627e0 100644 --- a/.ci/openvino-onnx/Dockerfile +++ b/.ci/openvino-onnx/Dockerfile @@ -69,6 +69,7 @@ RUN cmake .. 
\ -DENABLE_PYTHON=ON \ -DPYTHON_EXECUTABLE=/usr/bin/python3 \ -DNGRAPH_ONNX_IMPORT_ENABLE=ON \ + -DNGRAPH_ONNX_FRONTEND_ENABLE=ON \ -DNGRAPH_DEBUG_ENABLE=OFF \ -DCMAKE_INSTALL_PREFIX=/openvino/dist \ -DNGRAPH_USE_PROTOBUF_LITE=${PROTOBUF_LITE} diff --git a/cmake/coverage.cmake b/cmake/coverage.cmake index 60c137337b3173..4d8976e0a80beb 100644 --- a/cmake/coverage.cmake +++ b/cmake/coverage.cmake @@ -92,9 +92,15 @@ ie_coverage_genhtml(INFO_FILE "ngraph" if(NGRAPH_ONNX_IMPORT_ENABLE) ie_coverage_extract(INPUT "openvino" OUTPUT "onnx_importer" - PATTERNS "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx_common*" - "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx_editor*" - "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx_import*") + PATTERNS "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx/onnx_common*" + "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx/onnx_import*") ie_coverage_genhtml(INFO_FILE "onnx_importer" PREFIX "${OV_COVERAGE_BASE_DIRECTORY}") endif() + +if(NGRAPH_ONNX_FRONTEND_ENABLE) + ie_coverage_extract(INPUT "openvino" OUTPUT "onnx_ngraph_frontend" + PATTERNS "${OV_COVERAGE_BASE_DIRECTORY}/ngraph/frontend/onnx/frontend*") + ie_coverage_genhtml(INFO_FILE "onnx_ngraph_frontend" + PREFIX "${OV_COVERAGE_BASE_DIRECTORY}") +endif() diff --git a/cmake/features.cmake b/cmake/features.cmake index 22ddd0a55834eb..29f7135546c858 100644 --- a/cmake/features.cmake +++ b/cmake/features.cmake @@ -125,6 +125,7 @@ else() endif() ie_dependent_option(NGRAPH_ONNX_IMPORT_ENABLE "Enable ONNX importer" ON "protoc_available" OFF) +ie_dependent_option(NGRAPH_ONNX_FRONTEND_ENABLE "Enable ONNX FrontEnd" OFF "NGRAPH_ONNX_IMPORT_ENABLE" OFF) ie_dependent_option(NGRAPH_PDPD_FRONTEND_ENABLE "Enable PaddlePaddle FrontEnd" ON "protoc_available" OFF) ie_dependent_option(NGRAPH_USE_PROTOBUF_LITE "Compiles and links with protobuf-lite" OFF "NGRAPH_ONNX_IMPORT_ENABLE OR NGRAPH_PDPD_FRONTEND_ENABLE" OFF) diff --git 
a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend.hpp index da54a1f7993a95..34456d4df7f4b9 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/frontend.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/frontend.hpp @@ -63,7 +63,7 @@ namespace ngraph /// \param partiallyConverted partially converted nGraph function /// \return fully converted nGraph function virtual std::shared_ptr - convert(std::shared_ptr partiallyConverted) const; + convert(std::shared_ptr partially_converted) const; /// \brief Convert only those parts of the model that can be converted leaving others /// as-is. Converted parts are not normalized by additional transformations; normalize diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp index 6761e1dda383f2..6ec3f6060fc4d9 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/input_model.hpp @@ -68,47 +68,47 @@ namespace ngraph /// \brief Returns a tensor place by a tensor name following framework conventions, or /// nullptr if a tensor with this name doesn't exist. - /// \param tensorName Name of tensor + /// \param tensor_name Name of tensor /// \return Tensor place corresponding to specifed tensor name - virtual Place::Ptr get_place_by_tensor_name(const std::string& tensorName) const; + virtual Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const; /// \brief Returns an operation place by an operation name following framework /// conventions, or nullptr if an operation with this name doesn't exist. 
\param - /// operationName Name of operation \return Place representing operation - virtual Place::Ptr get_place_by_operation_name(const std::string& operationName); + /// operation_name Name of operation \return Place representing operation + virtual Place::Ptr get_place_by_operation_name(const std::string& operation_name); /// \brief Returns an input port place by operation name and appropriate port index - /// \param operationName Name of operation - /// \param outputPortIndex Index of input port for this operation + /// \param operation_name Name of operation + /// \param input_port_index Index of input port for this operation /// \return Place representing input port of operation virtual Place::Ptr - get_place_by_operation_name_and_input_port(const std::string& operationName, - int inputPortIndex); + get_place_by_operation_name_and_input_port(const std::string& operation_name, + int input_port_index); /// \brief Returns an output port place by operation name and appropriate port index - /// \param operationNameNname of operation - /// \param outputPortIndex Index of output port for this operation + /// \param operation_name Name of operation + /// \param output_port_index Index of output port for this operation /// \return Place representing output port of operation virtual Place::Ptr - get_place_by_operation_name_and_output_port(const std::string& operationName, - int outputPortIndex); + get_place_by_operation_name_and_output_port(const std::string& operation_name, + int output_port_index); ///// Naming and annotation ///// /// \brief Sets name for tensor. 
Overwrites existing names of this place /// \param operation Tensor place - /// \param newName New name for this tensor - virtual void set_name_for_tensor(Place::Ptr tensor, const std::string& newName); + /// \param new_name New name for this tensor + virtual void set_name_for_tensor(Place::Ptr tensor, const std::string& new_name); /// \brief Adds new name for tensor /// \param operation Tensor place - /// \param newName New name to be added to this place - virtual void add_name_for_tensor(Place::Ptr tensor, const std::string& newName); + /// \param new_name New name to be added to this place + virtual void add_name_for_tensor(Place::Ptr tensor, const std::string& new_name); /// \brief Sets name for operation. Overwrites existing names of this place /// \param operation Operation place - /// \param newName New name for this operation - virtual void set_name_for_operation(Place::Ptr operation, const std::string& newName); + /// \param new_name New name for this operation + virtual void set_name_for_operation(Place::Ptr operation, const std::string& new_name); /// \brief Unassign specified name from tensor place(s) /// \param name Name of tensor @@ -120,27 +120,27 @@ namespace ngraph /// \brief Set name for a particular dimension of a place (e.g. batch dimension) /// \param place Model's place - /// \param shapeDimIndex Dimension index - /// \param dimName Name to assign on this dimension + /// \param shape_dim_index Dimension index + /// \param dim_name Name to assign on this dimension virtual void set_name_for_dimension(Place::Ptr place, - size_t shapeDimIndex, - const std::string& dimName); + size_t shape_dim_index, + const std::string& dim_name); ///// Topology Editing ///// /// \brief Cut immediately before this place and assign this place as new input; prune /// all nodes that don't contribute to any output. 
/// \param place New place to be assigned as input - /// \param newNameOptional Optional new name assigned to this input place + /// \param new_name_optional Optional new name assigned to this input place virtual void cut_and_add_new_input(Place::Ptr place, - const std::string& newNameOptional = ""); + const std::string& new_name_optional = ""); /// \brief Cut immediately after this place and assign this place as new output; prune /// all nodes that don't contribute to any output. /// \param place New place to be assigned as output - /// \param newNameOptional Optional new name assigned to this output place + /// \param new_name_optional Optional new name assigned to this output place virtual void cut_and_add_new_output(Place::Ptr place, - const std::string& newNameOptional = ""); + const std::string& new_name_optional = ""); /// \brief Assign this place as new output or add necessary nodes to represent a new /// output. @@ -200,13 +200,13 @@ namespace ngraph virtual void set_tensor_value(Place::Ptr place, const void* value); /// \brief Defines partial value (lower bound and upper bound) for a tensor place - /// TODO: more details for minValue and maxValue format; who defines shape? + /// TODO: more details for min_value and max_value format; who defines shape? 
/// \param place Tensor place - /// \param minValue Lower bound of partial value for tensor place - /// \param maxValue Upper bound of partial value for tensor place + /// \param min_value Lower bound of partial value for tensor place + /// \param max_value Upper bound of partial value for tensor place virtual void set_tensor_partial_value(Place::Ptr place, - const void* minValue, - const void* maxValue); + const void* min_value, + const void* max_value); }; } // namespace frontend diff --git a/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp b/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp index 5df561fa0d5e4d..045d93dc4c5d98 100644 --- a/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp +++ b/ngraph/frontend/frontend_manager/include/frontend_manager/place.hpp @@ -87,12 +87,12 @@ namespace ngraph /// \note It can be called for any kind of graph place searching for the first consuming /// operations. /// - /// \param outputPortIndex If place is an operational node it specifies which output + /// \param output_port_index If place is an operational node it specifies which output /// port should be considered. 
/// /// \return A vector with all operation node references that consumes data from this /// place - virtual std::vector get_consuming_operations(int outputPortIndex) const; + virtual std::vector get_consuming_operations(int output_port_index) const; /// \brief Returns a tensor place that gets data from this place; applicable for /// operations, output ports and output edges which have only one output port @@ -103,11 +103,11 @@ namespace ngraph /// \brief Returns a tensor place that gets data from this place; applicable for /// operations, output ports and output edges /// - /// \param outputPortIndex Output port index if the current place is an operation node + /// \param output_port_index Output port index if the current place is an operation node /// and has multiple output ports /// /// \return A tensor place which hold the resulting value for this place - virtual Ptr get_target_tensor(int outputPortIndex) const; + virtual Ptr get_target_tensor(int output_port_index) const; /// \brief Returns a tensor place that supplies data for this place; applicable for /// operations, input ports and input edges which have only one input port @@ -118,10 +118,10 @@ namespace ngraph /// \brief Returns a tensor place that supplies data for this place; applicable for /// operations, input ports and input edges /// - /// \param inputPortIndex Input port index for operational nodes. + /// \param input_port_index Input port index for operational nodes. 
/// /// \return A tensor place which supplies data for this place - virtual Ptr get_source_tensor(int inputPortIndex) const; + virtual Ptr get_source_tensor(int input_port_index) const; /// \brief Get an operation node place that immediately produces data for this place; /// applicable if place has only one input port @@ -131,11 +131,11 @@ namespace ngraph /// \brief Get an operation node place that immediately produces data for this place /// - /// \param inputPortIndex If a given place is itself an operation node, this specifies a - /// port index + /// \param input_port_index If a given place is itself an operation node, this specifies + /// a port index /// /// \return An operation place that produces data for this place - virtual Ptr get_producing_operation(int inputPortIndex) const; + virtual Ptr get_producing_operation(int input_port_index) const; /// Returns a port that produces data for this place virtual Ptr get_producing_port() const; @@ -148,28 +148,28 @@ namespace ngraph /// \brief For operation node returns reference to an input port with specified index /// - /// \param inputPortIndex Input port index + /// \param input_port_index Input port index /// /// \return Appropriate input port place - virtual Ptr get_input_port(int inputPortIndex) const; + virtual Ptr get_input_port(int input_port_index) const; /// \brief For operation node returns reference to an input port with specified name; /// applicable if port group has only one input port /// - /// \param inputName Name of port group + /// \param input_name Name of port group /// /// \return Appropriate input port place - virtual Ptr get_input_port(const std::string& inputName) const; + virtual Ptr get_input_port(const std::string& input_name) const; /// \brief For operation node returns reference to an input port with specified name and /// index /// - /// \param inputName Name of port group, each group can have multiple ports + /// \param input_name Name of port group, each group can have multiple 
ports /// - /// \param inputPortIndex Input port index in a group + /// \param input_port_index Input port index in a group /// /// \return Appropriate input port place - virtual Ptr get_input_port(const std::string& inputName, int inputPortIndex) const; + virtual Ptr get_input_port(const std::string& input_name, int input_port_index) const; /// \brief For operation node returns reference to an output port; applicable for /// operations with only one output port @@ -179,28 +179,29 @@ namespace ngraph /// \brief For operation node returns reference to an output port with specified index /// - /// \param outputPortIndex Output port index + /// \param output_port_index Output port index /// /// \return Appropriate output port place - virtual Ptr get_output_port(int outputPortIndex) const; + virtual Ptr get_output_port(int output_port_index) const; /// \brief For operation node returns reference to an output port with specified name; /// applicable if port group has only one output port /// - /// \param outputName Name of output port group + /// \param output_name Name of output port group /// /// \return Appropriate output port place - virtual Ptr get_output_port(const std::string& outputName) const; + virtual Ptr get_output_port(const std::string& output_name) const; /// \brief For operation node returns reference to an output port with specified name /// and index /// - /// \param outputName Name of output port group, each group can have multiple ports + /// \param output_name Name of output port group, each group can have multiple ports /// - /// \param outputPortIndex Output port index + /// \param output_port_index Output port index /// /// \return Appropriate output port place - virtual Ptr get_output_port(const std::string& outputName, int outputPortIndex) const; + virtual Ptr get_output_port(const std::string& output_name, + int output_port_index) const; /// \brief Returns all input ports that consume data flows through this place virtual std::vector 
get_consuming_ports() const; diff --git a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp index 95dfe1ccbdedf5..68fcccb0538198 100644 --- a/ngraph/frontend/frontend_manager/src/frontend_manager.cpp +++ b/ngraph/frontend/frontend_manager/src/frontend_manager.cpp @@ -16,22 +16,22 @@ using namespace ngraph::frontend; //----------- FrontEndManager --------------------------- class FrontEndManager::Impl { - std::vector m_loadedLibs; // must be a first class member (destroyed last) + std::vector m_loaded_libs; // must be a first class member (destroyed last) std::map m_factories; public: - Impl() { registerPlugins(); } + Impl() { register_plugins(); } ~Impl() = default; - FrontEnd::Ptr loadByFramework(const std::string& framework) + FrontEnd::Ptr load_by_framework(const std::string& framework) { FRONT_END_INITIALIZATION_CHECK( m_factories.count(framework), "FrontEnd for Framework ", framework, " is not found"); return m_factories[framework](); } - std::vector availableFrontEnds() const + std::vector available_front_ends() const { std::vector keys; @@ -43,7 +43,7 @@ class FrontEndManager::Impl return keys; } - FrontEnd::Ptr loadByModel(const std::vector>& variants) + FrontEnd::Ptr load_by_model(const std::vector>& variants) { for (const auto& factory : m_factories) { @@ -56,41 +56,41 @@ class FrontEndManager::Impl return FrontEnd::Ptr(); } - void registerFrontEnd(const std::string& name, FrontEndFactory creator) + void register_front_end(const std::string& name, FrontEndFactory creator) { m_factories.insert({name, creator}); } private: - void registerPlugins() + void register_plugins() { - auto registerFromDir = [&](const std::string& dir) { + auto register_from_dir = [&](const std::string& dir) { if (!dir.empty()) { - auto plugins = loadPlugins(dir); + auto plugins = load_plugins(dir); for (auto& plugin : plugins) { - registerFrontEnd(plugin.m_pluginInfo.m_name, plugin.m_pluginInfo.m_creator); - 
m_loadedLibs.push_back(std::move(plugin.m_libHandle)); + register_front_end(plugin.m_plugin_info.m_name, plugin.m_plugin_info.m_creator); + m_loaded_libs.push_back(std::move(plugin.m_lib_handle)); } } }; - std::string envPath = ngraph::getenv_string("OV_FRONTEND_PATH"); - if (!envPath.empty()) + std::string env_path = ngraph::getenv_string("OV_FRONTEND_PATH"); + if (!env_path.empty()) { auto start = 0u; - auto sepPos = envPath.find(PathSeparator, start); - while (sepPos != std::string::npos) + auto sep_pos = env_path.find(PathSeparator, start); + while (sep_pos != std::string::npos) { - registerFromDir(envPath.substr(start, sepPos - start)); - start = sepPos + 1; - sepPos = envPath.find(PathSeparator, start); + register_from_dir(env_path.substr(start, sep_pos - start)); + start = sep_pos + 1; + sep_pos = env_path.find(PathSeparator, start); } - registerFromDir(envPath.substr(start, sepPos)); + register_from_dir(env_path.substr(start, sep_pos)); } else { - registerFromDir(getFrontendLibraryPath()); + register_from_dir(get_frontend_library_path()); } } }; @@ -107,23 +107,23 @@ FrontEndManager::~FrontEndManager() = default; FrontEnd::Ptr FrontEndManager::load_by_framework(const std::string& framework) { - return m_impl->loadByFramework(framework); + return m_impl->load_by_framework(framework); } FrontEnd::Ptr FrontEndManager::load_by_model_impl(const std::vector>& variants) { - return m_impl->loadByModel(variants); + return m_impl->load_by_model(variants); } std::vector FrontEndManager::get_available_front_ends() const { - return m_impl->availableFrontEnds(); + return m_impl->available_front_ends(); } void FrontEndManager::register_front_end(const std::string& name, FrontEndFactory creator) { - m_impl->registerFrontEnd(name, creator); + m_impl->register_front_end(name, creator); } //----------- FrontEnd --------------------------- @@ -158,7 +158,7 @@ std::shared_ptr FrontEnd::convert_partially(InputModel::Ptr mo std::shared_ptr FrontEnd::decode(InputModel::Ptr model) 
const { - FRONT_END_NOT_IMPLEMENTED(convertDecodingOnly); + FRONT_END_NOT_IMPLEMENTED(decode); } void FrontEnd::normalize(std::shared_ptr function) const @@ -177,39 +177,40 @@ std::vector InputModel::get_outputs() const FRONT_END_NOT_IMPLEMENTED(get_outputs); } -Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensorName) const +Place::Ptr InputModel::get_place_by_tensor_name(const std::string& tensor_name) const { FRONT_END_NOT_IMPLEMENTED(get_place_by_tensor_name); } -Place::Ptr InputModel::get_place_by_operation_name(const std::string& operationName) +Place::Ptr InputModel::get_place_by_operation_name(const std::string& operation_name) { FRONT_END_NOT_IMPLEMENTED(get_place_by_operation_name); } -Place::Ptr InputModel::get_place_by_operation_name_and_input_port(const std::string& operationName, - int inputPortIndex) +Place::Ptr InputModel::get_place_by_operation_name_and_input_port(const std::string& operation_name, + int input_port_index) { FRONT_END_NOT_IMPLEMENTED(get_place_by_operation_name_and_input_port); } -Place::Ptr InputModel::get_place_by_operation_name_and_output_port(const std::string& operationName, - int outputPortIndex) +Place::Ptr + InputModel::get_place_by_operation_name_and_output_port(const std::string& operation_name, + int output_port_index) { FRONT_END_NOT_IMPLEMENTED(get_place_by_operation_name_and_output_port); } -void InputModel::set_name_for_tensor(Place::Ptr tensor, const std::string& newName) +void InputModel::set_name_for_tensor(Place::Ptr tensor, const std::string& new_name) { FRONT_END_NOT_IMPLEMENTED(set_name_for_tensor); } -void InputModel::add_name_for_tensor(Place::Ptr tensor, const std::string& newName) +void InputModel::add_name_for_tensor(Place::Ptr tensor, const std::string& new_name) { FRONT_END_NOT_IMPLEMENTED(add_name_for_tensor); } -void InputModel::set_name_for_operation(Place::Ptr operation, const std::string& newName) +void InputModel::set_name_for_operation(Place::Ptr operation, const std::string& 
new_name) { FRONT_END_NOT_IMPLEMENTED(set_name_for_operation); } @@ -225,18 +226,18 @@ void InputModel::free_name_for_operation(const std::string& name) } void InputModel::set_name_for_dimension(Place::Ptr place, - size_t shapeDimIndex, - const std::string& dimName) + size_t shape_dim_index, + const std::string& dim_name) { FRONT_END_NOT_IMPLEMENTED(set_name_for_dimension); } -void InputModel::cut_and_add_new_input(Place::Ptr place, const std::string& newNameOptional) +void InputModel::cut_and_add_new_input(Place::Ptr place, const std::string& new_name_optional) { FRONT_END_NOT_IMPLEMENTED(cut_and_add_new_input); } -void InputModel::cut_and_add_new_output(Place::Ptr place, const std::string& newNameOptional) +void InputModel::cut_and_add_new_output(Place::Ptr place, const std::string& new_name_optional) { FRONT_END_NOT_IMPLEMENTED(cut_and_add_new_output); } @@ -289,8 +290,8 @@ void InputModel::set_tensor_value(Place::Ptr place, const void* value) } void InputModel::set_tensor_partial_value(Place::Ptr place, - const void* minValue, - const void* maxValue) + const void* min_value, + const void* max_value) { FRONT_END_NOT_IMPLEMENTED(set_tensor_partial_value); } @@ -306,7 +307,7 @@ std::vector Place::get_consuming_operations() const FRONT_END_NOT_IMPLEMENTED(get_consuming_operations); } -std::vector Place::get_consuming_operations(int outputPortIndex) const +std::vector Place::get_consuming_operations(int output_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_consuming_operations); } @@ -316,7 +317,7 @@ Place::Ptr Place::get_target_tensor() const FRONT_END_NOT_IMPLEMENTED(get_target_tensor); } -Place::Ptr Place::get_target_tensor(int outputPortIndex) const +Place::Ptr Place::get_target_tensor(int output_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_target_tensor); } @@ -326,7 +327,7 @@ Place::Ptr Place::get_producing_operation() const FRONT_END_NOT_IMPLEMENTED(get_producing_operation); } -Place::Ptr Place::get_producing_operation(int inputPortIndex) const 
+Place::Ptr Place::get_producing_operation(int input_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_producing_operation); } @@ -341,17 +342,17 @@ Place::Ptr Place::get_input_port() const FRONT_END_NOT_IMPLEMENTED(get_input_port); } -Place::Ptr Place::get_input_port(int inputPortIndex) const +Place::Ptr Place::get_input_port(int input_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_input_port); } -Place::Ptr Place::get_input_port(const std::string& inputName) const +Place::Ptr Place::get_input_port(const std::string& input_name) const { FRONT_END_NOT_IMPLEMENTED(get_input_port); } -Place::Ptr Place::get_input_port(const std::string& inputName, int inputPortIndex) const +Place::Ptr Place::get_input_port(const std::string& input_name, int input_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_input_port); } @@ -361,17 +362,17 @@ Place::Ptr Place::get_output_port() const FRONT_END_NOT_IMPLEMENTED(get_output_port); } -Place::Ptr Place::get_output_port(int outputPortIndex) const +Place::Ptr Place::get_output_port(int output_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_output_port); } -Place::Ptr Place::get_output_port(const std::string& outputName) const +Place::Ptr Place::get_output_port(const std::string& output_name) const { FRONT_END_NOT_IMPLEMENTED(get_output_port); } -Place::Ptr Place::get_output_port(const std::string& outputName, int outputPortIndex) const +Place::Ptr Place::get_output_port(const std::string& output_name, int output_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_output_port); } @@ -406,7 +407,7 @@ Place::Ptr Place::get_source_tensor() const FRONT_END_NOT_IMPLEMENTED(get_source_tensor); } -Place::Ptr Place::get_source_tensor(int inputPortIndex) const +Place::Ptr Place::get_source_tensor(int input_port_index) const { FRONT_END_NOT_IMPLEMENTED(get_source_tensor); } diff --git a/ngraph/frontend/frontend_manager/src/plugin_loader.cpp b/ngraph/frontend/frontend_manager/src/plugin_loader.cpp index f42dbbad3a6c96..7ce0e956bab700 100644 --- 
a/ngraph/frontend/frontend_manager/src/plugin_loader.cpp +++ b/ngraph/frontend/frontend_manager/src/plugin_loader.cpp @@ -25,17 +25,17 @@ using namespace ngraph; using namespace ngraph::frontend; #ifdef WIN32 -#define DLOPEN(fileStr) LoadLibrary(TEXT(fileStr.c_str())) +#define DLOPEN(file_str) LoadLibrary(TEXT(file_str.c_str())) #define DLSYM(obj, func) GetProcAddress(obj, func) #define DLCLOSE(obj) FreeLibrary(obj) #else -#define DLOPEN(fileStr) dlopen(file.c_str(), RTLD_LAZY) +#define DLOPEN(file_str) dlopen(file_str.c_str(), RTLD_LAZY) #define DLSYM(obj, func) dlsym(obj, func) #define DLCLOSE(obj) dlclose(obj) #endif // TODO: change to std::filesystem for C++17 -static std::vector listFiles(const std::string& path) +static std::vector list_files(const std::string& path) { std::vector res; try @@ -68,9 +68,9 @@ static std::vector listFiles(const std::string& path) return res; } -std::vector ngraph::frontend::loadPlugins(const std::string& dirName) +std::vector ngraph::frontend::load_plugins(const std::string& dir_name) { - auto files = listFiles(dirName); + auto files = list_files(dir_name); std::vector res; for (const auto& file : files) { @@ -80,32 +80,29 @@ std::vector ngraph::frontend::loadPlugins(const std::string& dirName continue; } - PluginHandle guard([shared_object, file]() { - // std::cout << "Closing plugin library " << file << std::endl; - DLCLOSE(shared_object); - }); + PluginHandle guard([shared_object, file]() { DLCLOSE(shared_object); }); - auto infoAddr = reinterpret_cast(DLSYM(shared_object, "GetAPIVersion")); - if (!infoAddr) + auto info_addr = reinterpret_cast(DLSYM(shared_object, "GetAPIVersion")); + if (!info_addr) { continue; } - FrontEndVersion plugInfo{reinterpret_cast(infoAddr())}; + FrontEndVersion plug_info{reinterpret_cast(info_addr())}; - if (plugInfo != OV_FRONTEND_API_VERSION) + if (plug_info != OV_FRONTEND_API_VERSION) { // Plugin has incompatible API version, do not load it continue; } - auto creatorAddr = 
reinterpret_cast(DLSYM(shared_object, "GetFrontEndData")); - if (!creatorAddr) + auto creator_addr = reinterpret_cast(DLSYM(shared_object, "GetFrontEndData")); + if (!creator_addr) { continue; } std::unique_ptr fact{ - reinterpret_cast(creatorAddr())}; + reinterpret_cast(creator_addr())}; res.push_back(PluginData(std::move(guard), std::move(*fact))); } diff --git a/ngraph/frontend/frontend_manager/src/plugin_loader.hpp b/ngraph/frontend/frontend_manager/src/plugin_loader.hpp index 1ab3fc73baa227..9d01d4f3437689 100644 --- a/ngraph/frontend/frontend_manager/src/plugin_loader.hpp +++ b/ngraph/frontend/frontend_manager/src/plugin_loader.hpp @@ -7,11 +7,11 @@ #include #ifdef _WIN32 -const char FileSeparator[] = "\\"; -const char PathSeparator[] = ";"; +static const char FileSeparator[] = "\\"; +static const char PathSeparator[] = ";"; #else -const char FileSeparator[] = "/"; -const char PathSeparator[] = ":"; +static const char FileSeparator[] = "/"; +static const char PathSeparator[] = ":"; #endif // _WIN32 namespace ngraph @@ -23,8 +23,8 @@ namespace ngraph class PluginHandle { public: - PluginHandle(std::function callOnDestruct) - : m_callOnDestruct(callOnDestruct) + PluginHandle(std::function call_on_destruct) + : m_call_on_destruct(call_on_destruct) { } @@ -38,31 +38,31 @@ namespace ngraph ~PluginHandle() { - if (m_callOnDestruct) + if (m_call_on_destruct) { - m_callOnDestruct(); + m_call_on_destruct(); } } private: - std::function m_callOnDestruct; + std::function m_call_on_destruct; }; struct PluginData { PluginData(PluginHandle&& h, FrontEndPluginInfo&& info) - : m_libHandle(std::move(h)) - , m_pluginInfo(info) + : m_lib_handle(std::move(h)) + , m_plugin_info(info) { } PluginHandle - m_libHandle; // Shall be destroyed when plugin is not needed anymore to free memory - FrontEndPluginInfo m_pluginInfo; + m_lib_handle; // Shall be destroyed when plugin is not needed anymore to free memory + FrontEndPluginInfo m_plugin_info; }; // Searches for available plugins in 
a specified directory - std::vector loadPlugins(const std::string& dirName); + std::vector load_plugins(const std::string& dir_name); } // namespace frontend } // namespace ngraph diff --git a/ngraph/frontend/frontend_manager/src/utils.cpp b/ngraph/frontend/frontend_manager/src/utils.cpp index e940512e6e7872..262aa7556127e6 100644 --- a/ngraph/frontend/frontend_manager/src/utils.cpp +++ b/ngraph/frontend/frontend_manager/src/utils.cpp @@ -26,7 +26,7 @@ namespace { - std::string getPathName(const std::string& s) + std::string get_path_name(const std::string& s) { size_t i = s.rfind(FileSeparator, s.length()); if (i != std::string::npos) @@ -39,30 +39,30 @@ namespace } // namespace -static std::string _getFrontendLibraryPath() +static std::string _get_frontend_library_path() { #ifdef _WIN32 CHAR ie_library_path[MAX_PATH]; HMODULE hm = NULL; if (!GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, - reinterpret_cast(ngraph::frontend::getFrontendLibraryPath), + reinterpret_cast(ngraph::frontend::get_frontend_library_path), &hm)) { FRONT_END_INITIALIZATION_CHECK(false, "GetModuleHandle returned ", GetLastError()); } GetModuleFileNameA(hm, (LPSTR)ie_library_path, sizeof(ie_library_path)); - return getPathName(std::string(ie_library_path)); + return get_path_name(std::string(ie_library_path)); #elif defined(__APPLE__) || defined(__linux__) Dl_info info; - dladdr(reinterpret_cast(ngraph::frontend::getFrontendLibraryPath), &info); - return getPathName(std::string(info.dli_fname)).c_str(); + dladdr(reinterpret_cast(ngraph::frontend::get_frontend_library_path), &info); + return get_path_name(std::string(info.dli_fname)).c_str(); #else #error "Unsupported OS" #endif // _WIN32 } -std::string ngraph::frontend::getFrontendLibraryPath() +std::string ngraph::frontend::get_frontend_library_path() { - return _getFrontendLibraryPath(); + return _get_frontend_library_path(); } diff --git 
a/ngraph/frontend/frontend_manager/src/utils.hpp b/ngraph/frontend/frontend_manager/src/utils.hpp index 26d6f5273c30e4..1c05466bf29024 100644 --- a/ngraph/frontend/frontend_manager/src/utils.hpp +++ b/ngraph/frontend/frontend_manager/src/utils.hpp @@ -9,6 +9,6 @@ namespace ngraph { namespace frontend { - FRONTEND_API std::string getFrontendLibraryPath(); + FRONTEND_API std::string get_frontend_library_path(); } // namespace frontend } // namespace ngraph diff --git a/ngraph/frontend/onnx/CMakeLists.txt b/ngraph/frontend/onnx/CMakeLists.txt index 3327ff61b08e0e..5bf43f04931165 100644 --- a/ngraph/frontend/onnx/CMakeLists.txt +++ b/ngraph/frontend/onnx/CMakeLists.txt @@ -4,3 +4,6 @@ add_subdirectory(onnx_common) add_subdirectory(onnx_import) +if (NGRAPH_ONNX_FRONTEND_ENABLE) + add_subdirectory(frontend) +endif() diff --git a/ngraph/frontend/onnx/frontend/CMakeLists.txt b/ngraph/frontend/onnx/frontend/CMakeLists.txt new file mode 100644 index 00000000000000..aab7a150db87c9 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +file(GLOB_RECURSE LIBRARY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) +file(GLOB_RECURSE LIBRARY_PUBLIC_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/include/*.hpp) + +set(ONNX_FRONTEND_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) + +# Create named folders for the sources within the .vcproj +# Empty name lists them directly under the .vcproj + +source_group("src" FILES ${LIBRARY_SRC}) +source_group("include" FILES ${LIBRARY_HEADERS}) +source_group("public include" FILES ${LIBRARY_PUBLIC_HEADERS}) + +# Create shared library +add_library(onnx_ngraph_frontend SHARED ${LIBRARY_SRC} ${LIBRARY_HEADERS} ${LIBRARY_PUBLIC_HEADERS}) +add_library(ngraph::onnx_ngraph_frontend ALIAS onnx_ngraph_frontend) + +add_clang_format_target(onnx_ngraph_frontend_clang FOR_TARGETS onnx_ngraph_frontend) + +if(COMMAND ie_add_vs_version_file) + 
ie_add_vs_version_file(NAME onnx_ngraph_frontend + FILEDESCRIPTION "nGraph ONNX frontend library") +endif() + +target_link_libraries(onnx_ngraph_frontend PRIVATE onnx_importer frontend_manager) + +target_include_directories(onnx_ngraph_frontend PUBLIC $ + $) + +target_include_directories(onnx_ngraph_frontend PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) + +install(TARGETS onnx_ngraph_frontend EXPORT ngraphTargets + RUNTIME DESTINATION ${NGRAPH_INSTALL_LIB} COMPONENT ngraph + ARCHIVE DESTINATION ${NGRAPH_INSTALL_LIB} COMPONENT ngraph + LIBRARY DESTINATION ${NGRAPH_INSTALL_LIB} COMPONENT ngraph) + +install(DIRECTORY ${ONNX_FRONTEND_INCLUDE_DIR}/onnx_frontend + DESTINATION ${FRONTEND_INSTALL_INCLUDE} + COMPONENT ngraph_dev + FILES_MATCHING PATTERN "*.hpp") + +export(TARGETS onnx_ngraph_frontend NAMESPACE ngraph:: APPEND FILE "${NGRAPH_TARGETS_FILE}") diff --git a/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp b/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp new file mode 100644 index 00000000000000..fdc004365d6672 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/include/onnx_frontend/frontend.hpp @@ -0,0 +1,34 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#ifdef onnx_ngraph_frontend_EXPORTS +#define ONNX_FRONTEND_API NGRAPH_HELPER_DLL_EXPORT +#else +#define ONNX_FRONTEND_API NGRAPH_HELPER_DLL_IMPORT +#endif + +namespace ngraph +{ + namespace frontend + { + class ONNX_FRONTEND_API FrontEndONNX : public FrontEnd + { + public: + std::shared_ptr convert(InputModel::Ptr model) const override; + std::shared_ptr + convert(std::shared_ptr partially_converted) const override; + std::shared_ptr decode(InputModel::Ptr model) const override; + + protected: + InputModel::Ptr + load_impl(const std::vector>& params) const override; + }; + + } // namespace frontend + +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/frontend.cpp 
b/ngraph/frontend/onnx/frontend/src/frontend.cpp new file mode 100644 index 00000000000000..3caa85db68c365 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/frontend.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +using namespace ngraph; +using namespace ngraph::frontend; + +extern "C" ONNX_FRONTEND_API FrontEndVersion GetAPIVersion() +{ + return OV_FRONTEND_API_VERSION; +} + +extern "C" ONNX_FRONTEND_API void* GetFrontEndData() +{ + FrontEndPluginInfo* res = new FrontEndPluginInfo(); + res->m_name = "onnx"; + res->m_creator = []() { return std::make_shared(); }; + return res; +} + +InputModel::Ptr FrontEndONNX::load_impl(const std::vector>& variants) const +{ + NGRAPH_CHECK(variants.size() == 1, + "Only one parameter to load function is expected. Got " + + std::to_string(variants.size())); + NGRAPH_CHECK(is_type>(variants[0]), + "Parameter to load function need to be a std::string"); + auto path = as_type_ptr>(variants[0])->get(); + return std::make_shared(path); +} + +std::shared_ptr FrontEndONNX::convert(InputModel::Ptr model) const +{ + auto model_onnx = std::dynamic_pointer_cast(model); + NGRAPH_CHECK(model_onnx != nullptr, "Invalid input model"); + return model_onnx->convert(); +} + +std::shared_ptr + FrontEndONNX::convert(std::shared_ptr partially_converted) const +{ + return onnx_import::convert_decoded_function(partially_converted); +} + +std::shared_ptr FrontEndONNX::decode(InputModel::Ptr model) const +{ + auto model_onnx = std::dynamic_pointer_cast(model); + NGRAPH_CHECK(model_onnx != nullptr, "Invalid input model"); + return model_onnx->decode(); +} diff --git a/ngraph/frontend/onnx/frontend/src/input_model.cpp b/ngraph/frontend/onnx/frontend/src/input_model.cpp new file mode 100644 index 00000000000000..58ebe098f84b39 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/input_model.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2021 Intel 
Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +using namespace ngraph; +using namespace ngraph::frontend; + +InputModelONNX::InputModelONNX(const std::string& path) + : m_editor(path) +{ +} + +std::vector InputModelONNX::get_inputs() const +{ + auto inputs = m_editor.model_inputs(); + std::vector ret; + ret.reserve(inputs.size()); + for (const auto& input : inputs) + { + ret.push_back(std::make_shared(input, m_editor)); + } + return ret; +} + +Place::Ptr InputModelONNX::get_place_by_tensor_name(const std::string& tensor_name) const +{ + return std::make_shared(tensor_name, m_editor); +} + +void InputModelONNX::set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) +{ + std::map m; + m[place->get_names()[0]] = shape; + m_editor.set_input_shapes(m); +} + +void InputModelONNX::set_element_type(Place::Ptr place, const ngraph::element::Type& type) +{ + std::map m; + m[place->get_names()[0]] = type; + m_editor.set_input_types(m); +} + +std::shared_ptr InputModelONNX::decode() +{ + return m_editor.decode(); +} + +std::shared_ptr InputModelONNX::convert() +{ + return m_editor.get_function(); +} diff --git a/ngraph/frontend/onnx/frontend/src/input_model.hpp b/ngraph/frontend/onnx/frontend/src/input_model.hpp new file mode 100644 index 00000000000000..e1003e3c1bbee5 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/input_model.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +namespace ngraph +{ + namespace frontend + { + class InputModelONNX : public InputModel + { + public: + InputModelONNX(const std::string& path); + + std::vector get_inputs() const override; + Place::Ptr get_place_by_tensor_name(const std::string& tensor_name) const override; + void set_partial_shape(Place::Ptr place, const ngraph::PartialShape& shape) override; + void set_element_type(Place::Ptr place, const ngraph::element::Type& type) 
override; + + std::shared_ptr decode(); + std::shared_ptr convert(); + + private: + onnx_editor::ONNXModelEditor m_editor; + }; + + } // namespace frontend + +} // namespace ngraph diff --git a/ngraph/frontend/onnx/frontend/src/place.hpp b/ngraph/frontend/onnx/frontend/src/place.hpp new file mode 100644 index 00000000000000..28bbc558741d38 --- /dev/null +++ b/ngraph/frontend/onnx/frontend/src/place.hpp @@ -0,0 +1,78 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ngraph +{ + namespace frontend + { + class PlaceInputEdgeONNX : public Place + { + public: + PlaceInputEdgeONNX(const onnx_editor::InputEdge& edge) + : m_edge(edge) + { + } + + private: + onnx_editor::InputEdge m_edge; + }; + + class PlaceOutputEdgeONNX : public Place + { + public: + PlaceOutputEdgeONNX(const onnx_editor::OutputEdge& edge) + : m_edge(edge) + { + } + + private: + onnx_editor::OutputEdge m_edge; + }; + + class PlaceTensorONNX : public Place + { + public: + PlaceTensorONNX(const std::string& name, const onnx_editor::ONNXModelEditor& editor) + : m_name(name) + , m_editor(editor) + { + } + + std::vector get_names() const override { return {m_name}; } + + Place::Ptr get_producing_port() const override + { + return std::make_shared(m_editor.find_output_edge(m_name)); + } + + std::vector get_consuming_ports() const override + { + std::vector ret; + auto edges = m_editor.find_output_consumers(m_name); + std::transform(edges.begin(), + edges.end(), + std::back_inserter(ret), + [](const onnx_editor::InputEdge& edge) { + return std::make_shared(edge); + }); + return ret; + } + + Ptr get_input_port(int input_port_index) const override + { + return std::make_shared(m_editor.find_input_edge( + onnx_editor::EditorNode(m_name), onnx_editor::EditorInput(input_port_index))); + } + + private: + std::string m_name; + const onnx_editor::ONNXModelEditor& m_editor; + }; + } // namespace frontend + +} // namespace ngraph diff --git 
a/ngraph/frontend/onnx/onnx_import/include/onnx_editor/editor.hpp b/ngraph/frontend/onnx/onnx_import/include/onnx_editor/editor.hpp index 4f31ab2d323c35..67052da996929f 100644 --- a/ngraph/frontend/onnx/onnx_import/include/onnx_editor/editor.hpp +++ b/ngraph/frontend/onnx/onnx_import/include/onnx_editor/editor.hpp @@ -161,6 +161,11 @@ namespace ngraph /// bool is_correct_and_unambiguous_node(const EditorNode& node) const; + /// \brief Returns a nGraph function based on edited model + /// decoded to framework nodes + /// + std::shared_ptr decode(); + private: void update_mapper_if_needed() const; diff --git a/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/model.hpp b/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/model.hpp new file mode 100644 index 00000000000000..82a6e0c4f10f8f --- /dev/null +++ b/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/model.hpp @@ -0,0 +1,85 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include + +#include "onnx_import/core/operator_set.hpp" + +namespace ngraph +{ + namespace onnx_import + { + /// \brief Type of container which stores opset version and domain in ONNX format + using OpsetImports = + ::google::protobuf::RepeatedPtrField; + + std::string get_node_domain(const ONNX_NAMESPACE::NodeProto& node_proto); + + std::int64_t get_opset_version(const ONNX_NAMESPACE::ModelProto& model_proto, + const std::string& domain); + + class Model + { + public: + Model() = delete; + explicit Model(std::shared_ptr model_proto); + + Model(const Model&) = delete; + Model(Model&&) = delete; + + Model& operator=(const Model&) = delete; + Model& operator=(Model&&) = delete; + + const std::string& get_producer_name() const { return m_model_proto->producer_name(); } + const ONNX_NAMESPACE::GraphProto& get_graph() const { return m_model_proto->graph(); } + std::int64_t get_model_version() const { return 
m_model_proto->model_version(); } + const OpsetImports& get_opset_imports() const; + const std::string& get_producer_version() const + { + return m_model_proto->producer_version(); + } + + /// \brief Access an operator object by its type name and domain name + /// The function will return the operator object if it exists, or report an error + /// in case of domain or operator absence. + /// \param name type name of the operator object, + /// \param domain domain name of the operator object. + /// \return Reference to the operator object. + /// \throw error::UnknownDomain there is no operator set defined for the given + /// domain, + /// \throw error::UnknownOperator the given operator type name does not exist in + /// operator set. + const Operator& get_operator(const std::string& name, const std::string& domain) const; + + /// \brief Check availability of operator base on NodeProto. + /// \return `true` if the operator is available, otherwise it returns `false`. + bool is_operator_available(const ONNX_NAMESPACE::NodeProto& node_proto) const; + + /// \brief Enable operators from provided domain to use by this model. + /// + /// \note This function makes visible all currently registered in provided domain + /// operators for use in this model. + /// + /// \param[in] domain The domain name. 
+ /// + void enable_opset_domain(const std::string& domain); + + private: + const std::shared_ptr m_model_proto; + std::unordered_map m_opset; + }; + + inline std::ostream& operator<<(std::ostream& outs, const Model& model) + { + return (outs << ""); + } + + } // namespace onnx_import + +} // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/node.hpp b/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/node.hpp index 097c100e28f8de..1fefd73bada578 100644 --- a/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/node.hpp +++ b/ngraph/frontend/onnx/onnx_import/include/onnx_import/core/node.hpp @@ -58,7 +58,6 @@ namespace ngraph Node& operator=(const Node&) = delete; OutputVector get_ng_inputs() const; - OutputVector get_ng_nodes() const; const std::string& domain() const; const std::string& op_type() const; const std::string& get_name() const; diff --git a/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx.hpp b/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx.hpp index 04e6be0cdd527a..39b923328c123f 100644 --- a/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx.hpp +++ b/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx.hpp @@ -72,6 +72,14 @@ namespace ngraph /// \return An nGraph function that represents a single output from the created graph. ONNX_IMPORTER_API std::shared_ptr import_onnx_model(const std::string& file_path); + + /// \brief Converts a nGraph function (onnx model decoded to function with + /// ONNXFrameworkNode(s)) + /// to a complete function with actual compute operations + /// + /// \return A nGraph function. 
+ ONNX_IMPORTER_API + std::shared_ptr convert_decoded_function(std::shared_ptr function); } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_import/include/onnx_import/utils/onnx_internal.hpp b/ngraph/frontend/onnx/onnx_import/include/onnx_import/utils/onnx_internal.hpp index 92df626cd24e6b..006fcc561e6c36 100644 --- a/ngraph/frontend/onnx/onnx_import/include/onnx_import/utils/onnx_internal.hpp +++ b/ngraph/frontend/onnx/onnx_import/include/onnx_import/utils/onnx_internal.hpp @@ -36,8 +36,24 @@ namespace ngraph /// /// \return An nGraph function that represents a single output from the created /// graph. - std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, - const std::string& model_path); + std::shared_ptr + import_onnx_model(std::shared_ptr model_proto, + const std::string& model_path); + + /// \brief Decode ONNX model to nGraph function with ONNXFrameworkNode(s) + /// + /// \param[in] model_proto Reference to a GraphProto object. + /// \param[in] model_path The path to the imported onnx model. + /// It is required if the imported model uses data saved in + /// external files. 
+ /// + /// \return A nGraph function with ONNXFrameworkNodes + ONNX_IMPORTER_API + std::shared_ptr + decode_to_framework_nodes(std::shared_ptr model_proto, + const std::string& model_path); + + std::shared_ptr convert_decoded_function(std::shared_ptr function); } // namespace detail } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_import/src/core/attribute.cpp b/ngraph/frontend/onnx/onnx_import/src/core/attribute.cpp index 1fd61931de9629..204959b2fac9b1 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/attribute.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/attribute.cpp @@ -4,8 +4,8 @@ #include "core/attribute.hpp" #include "core/graph.hpp" -#include "core/model.hpp" #include "ngraph/log.hpp" +#include "onnx_import/core/model.hpp" namespace ngraph { @@ -18,15 +18,14 @@ namespace ngraph throw error::attribute::InvalidData{m_attribute_proto->type()}; } - auto model_proto = common::make_unique(); + auto model_proto = std::make_shared(); const auto& graph = m_attribute_proto->g(); model_proto->mutable_graph()->CopyFrom(graph); // set opset version and domain from the parent graph model_proto->mutable_opset_import()->CopyFrom(parent_graph.get_opset_imports()); - auto model = common::make_unique(std::move(model_proto)); - return Subgraph{std::move(model), parent_graph}; + return Subgraph{model_proto, parent_graph}; } } // namespace onnx_import diff --git a/ngraph/frontend/onnx/onnx_import/src/core/graph.cpp b/ngraph/frontend/onnx/onnx_import/src/core/graph.cpp index 569d3849774859..1b1bdcaf8e6953 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/graph.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/graph.cpp @@ -9,12 +9,14 @@ #include "core/graph.hpp" #include "core/null_node.hpp" +#include "core/value_info.hpp" +#include "default_opset.hpp" #include "exceptions.hpp" #include "ngraph/log.hpp" #include "ngraph/node.hpp" #include "ngraph/provenance.hpp" +#include "onnx_framework_node.hpp" #include 
"onnx_import/core/node.hpp" -#include "onnx_import/onnx_framework_node.hpp" #include "utils/common.hpp" #include "utils/provenance_tag.hpp" @@ -51,15 +53,62 @@ namespace ngraph std::string domain = get_node_domain(node_proto); return (domain.empty() ? "" : domain + ".") + node_proto.op_type(); } + + void add_provenance_tag_to_initializer(const Tensor& tensor, + std::shared_ptr node) + { + if (!ngraph::get_provenance_enabled()) + { + return; + } + + const std::string tag = + detail::build_input_provenance_tag(tensor.get_name(), tensor.get_shape()); + + node->add_provenance_tag(tag); + } + + void add_provenance_tag_to_input(const ValueInfo& input, + std::shared_ptr node) + { + if (!ngraph::get_provenance_enabled()) + { + return; + } + + const std::string tag = + detail::build_input_provenance_tag(input.get_name(), input.get_shape()); + + node->add_provenance_tag(tag); + } + + void add_provenance_tags(const Node& onnx_node, const OutputVector& ng_node_vector) + { + if (!ngraph::get_provenance_enabled()) + { + return; + } + + const auto tag = detail::build_op_provenance_tag(onnx_node); + const auto ng_inputs = onnx_node.get_ng_inputs(); + + ngraph::traverse_nodes( + as_node_vector(ng_node_vector), + [&tag](std::shared_ptr ng_node) { + ng_node->add_provenance_tag(tag); + }, + as_node_vector(ng_inputs)); + } } // namespace detail - Graph::Graph(std::unique_ptr&& model) - : Graph(std::move(model), common::make_unique()) + Graph::Graph(std::shared_ptr model_proto) + : Graph(model_proto, common::make_unique()) { } - Graph::Graph(std::unique_ptr&& model, std::unique_ptr&& cache) - : m_model{std::move(model)} + Graph::Graph(std::shared_ptr model_proto, + std::unique_ptr&& cache) + : m_model{common::make_unique(model_proto)} , m_cache{std::move(cache)} { std::map initializers; @@ -96,7 +145,7 @@ namespace ngraph } initializers.emplace(initializer_tensor.name(), tensor); - add_provenance_tag_to_initializer(tensor, ng_constant); + 
detail::add_provenance_tag_to_initializer(tensor, ng_constant); m_cache->emplace_node(initializer_tensor.name(), std::move(ng_constant)); } } @@ -104,26 +153,18 @@ namespace ngraph // Process all ONNX graph inputs, convert them to nGraph nodes and store in cache for (const auto& input : m_model->get_graph().input()) { - m_inputs.emplace_back(input); - // Check if a Constant node was already created from an initializer if (m_cache->contains(input.name())) { continue; } - const auto value_info = m_inputs.back(); + ValueInfo value_info{input}; auto ng_node = value_info.get_ng_node(m_parameters, initializers); - add_provenance_tag_to_input(value_info, ng_node); + detail::add_provenance_tag_to_input(value_info, ng_node); m_cache->emplace_node(input.name(), std::move(ng_node)); } - // Process all graph outputs - for (const auto& output : m_model->get_graph().output()) - { - m_outputs.emplace_back(output); - } - // Verify that ONNX graph contains only nodes of available operator types std::map> unknown_operators; @@ -163,19 +204,13 @@ namespace ngraph // Process ONNX graph nodes, convert to nGraph nodes for (const auto& node_proto : m_model->get_graph().node()) { - m_nodes.emplace_back(node_proto, *this); - const Node& node{m_nodes.back()}; + const Node node{node_proto, *this}; if (node.has_subgraph()) { auto subgraph = node.get_subgraph(); auto body_func = subgraph->convert(); } - OutputVector ng_nodes{node.get_ng_nodes()}; - set_friendly_names(node, ng_nodes); - for (std::size_t i{0}; i < node.get_outputs_size(); ++i) - { - m_cache->emplace_node(node.output(i), std::move(ng_nodes.at(i))); - } + OutputVector ng_nodes{make_ng_nodes(node)}; } } @@ -186,11 +221,14 @@ namespace ngraph if ((*param_it)->get_output_target_inputs(0).size() == 0) { const auto& name = (*param_it)->get_friendly_name(); - auto out_it = std::find_if( - m_outputs.begin(), m_outputs.end(), [&name](const ValueInfo& info) { - return info.get_name() == name; - }); - if (out_it == m_outputs.end()) + const 
auto& onnx_outputs = m_model->get_graph().output(); + auto out_it = + std::find_if(onnx_outputs.begin(), + onnx_outputs.end(), + [&name](const ONNX_NAMESPACE::ValueInfoProto& output) -> bool { + return output.name() == name; + }); + if (out_it == onnx_outputs.end()) { m_cache->remove_node(name); param_it = m_parameters.erase(param_it); @@ -213,8 +251,7 @@ namespace ngraph // Process ONNX graph nodes, convert to nGraph nodes for (const auto& node_proto : m_model->get_graph().node()) { - m_nodes.emplace_back(node_proto, *this); - const Node& node{m_nodes.back()}; + const Node node{node_proto, *this}; std::shared_ptr framework_node; if (node.has_subgraph()) { @@ -223,12 +260,13 @@ namespace ngraph auto inputs = node.get_ng_inputs(); for (const auto& input : subgraph->get_inputs_from_parent()) inputs.push_back(input); - framework_node = - std::make_shared(node, inputs); + framework_node = std::make_shared( + shared_from_this(), node, inputs); } else { - framework_node = std::make_shared(node); + framework_node = std::make_shared( + shared_from_this(), node); } OutputVector ng_nodes{framework_node->outputs()}; set_friendly_names(node, ng_nodes); @@ -245,9 +283,10 @@ namespace ngraph std::shared_ptr Graph::create_function() { auto function = std::make_shared(get_ng_outputs(), m_parameters, get_name()); + const auto& onnx_outputs = m_model->get_graph().output(); for (std::size_t i{0}; i < function->get_output_size(); ++i) { - function->get_output_op(i)->set_friendly_name(m_outputs.at(i).get_name()); + function->get_output_op(i)->set_friendly_name(onnx_outputs.Get(i).name()); } return function; } @@ -307,7 +346,7 @@ namespace ngraph std::rethrow_exception(std::current_exception()); } set_friendly_names(onnx_node, ng_node_vector); - add_provenance_tags(onnx_node, ng_node_vector); + detail::add_provenance_tags(onnx_node, ng_node_vector); for (std::size_t i{0}; i < onnx_node.get_outputs_size(); ++i) { @@ -340,58 +379,14 @@ namespace ngraph } } - void 
Graph::add_provenance_tag_to_initializer( - const Tensor& tensor, std::shared_ptr node) const - { - if (!ngraph::get_provenance_enabled()) - { - return; - } - - const std::string tag = - detail::build_input_provenance_tag(tensor.get_name(), tensor.get_shape()); - - node->add_provenance_tag(tag); - } - - void Graph::add_provenance_tag_to_input(const ValueInfo& input, - std::shared_ptr node) const - { - if (!ngraph::get_provenance_enabled()) - { - return; - } - - const std::string tag = - detail::build_input_provenance_tag(input.get_name(), input.get_shape()); - - node->add_provenance_tag(tag); - } - - void Graph::add_provenance_tags(const Node& onnx_node, - const OutputVector& ng_node_vector) const - { - if (!ngraph::get_provenance_enabled()) - { - return; - } - - const auto tag = detail::build_op_provenance_tag(onnx_node); - const auto ng_inputs = onnx_node.get_ng_inputs(); - - ngraph::traverse_nodes( - as_node_vector(ng_node_vector), - [&tag](std::shared_ptr ng_node) { ng_node->add_provenance_tag(tag); }, - as_node_vector(ng_inputs)); - } - const OpsetImports& Graph::get_opset_imports() const { return m_model->get_opset_imports(); } - Subgraph::Subgraph(std::unique_ptr&& model, const Graph& parent_graph) - : Graph(std::move(model), common::make_unique()) + Subgraph::Subgraph(std::shared_ptr model_proto, + const Graph& parent_graph) + : Graph(model_proto, common::make_unique()) , m_parent_graph_cache(&parent_graph.get_graph_cache()) { } diff --git a/ngraph/frontend/onnx/onnx_import/src/core/graph.hpp b/ngraph/frontend/onnx/onnx_import/src/core/graph.hpp index 33c2be5d4d20e8..fea67c3e146dc0 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/graph.hpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/graph.hpp @@ -10,20 +10,18 @@ #include #include "core/graph_cache.hpp" -#include "core/model.hpp" -#include "core/value_info.hpp" -#include "default_opset.hpp" #include "ngraph/op/parameter.hpp" +#include "onnx_import/core/model.hpp" #include 
"onnx_import/core/operator_set.hpp" namespace ngraph { namespace onnx_import { - class Graph + class Graph : public std::enable_shared_from_this { public: - Graph(std::unique_ptr&& model); + Graph(std::shared_ptr model_proto); Graph() = delete; Graph(const Graph&) = delete; @@ -31,35 +29,24 @@ namespace ngraph Graph& operator=(const Graph&) = delete; Graph& operator=(Graph&&) = default; - virtual std::shared_ptr convert(); std::shared_ptr decode(); - const std::vector& get_nodes() const { return m_nodes; } - const std::vector& get_inputs() const { return m_inputs; } - const std::vector& get_outputs() const { return m_outputs; } + virtual std::shared_ptr convert(); OutputVector get_ng_outputs() const; + const std::string& get_name() const { return m_model->get_graph().name(); } + const GraphCache& get_graph_cache() const; const ParameterVector& get_ng_parameters() const { return m_parameters; } virtual Output get_ng_node_from_cache(const std::string& name) const; - const std::string& get_name() const { return m_model->get_graph().name(); } OutputVector make_ng_nodes(const Node& onnx_node) const; - const GraphCache& get_graph_cache() const; const OpsetImports& get_opset_imports() const; virtual ~Graph() = default; protected: - Graph(std::unique_ptr&& model, std::unique_ptr&& cache); + Graph(std::shared_ptr model, + std::unique_ptr&& cache); void set_friendly_names(const Node& onnx_node, const OutputVector& ng_node_vector) const; - void add_provenance_tag_to_initializer( - const Tensor& initializer, std::shared_ptr node) const; - - void add_provenance_tag_to_input(const ValueInfo& input, - std::shared_ptr node) const; - - void add_provenance_tags(const Node& onnx_node, - const OutputVector& ng_node_vector) const; - protected: virtual void decode_to_framework_nodes(); void convert_to_ngraph_nodes(); @@ -72,8 +59,6 @@ namespace ngraph private: std::vector m_nodes; - std::vector m_inputs; - std::vector m_outputs; }; /// \brief Representation of ONNX subgraph. 
It is used for example by ONNX Loop op. @@ -86,7 +71,7 @@ namespace ngraph /// /// \param[in] model The ONNX model object. /// \param[in] parent_graph The reference to the parent graph. - Subgraph(std::unique_ptr&& model, const Graph& parent_graph); + Subgraph(std::shared_ptr model, const Graph& parent_graph); /// \brief Return nodes which are on the edge the subgraph and the parent graph. /// \return Vector of edge nodes from parent scope. diff --git a/ngraph/frontend/onnx/onnx_import/src/core/model.cpp b/ngraph/frontend/onnx/onnx_import/src/core/model.cpp index 2ddd3edac02e7a..c7c0993eda1ccd 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/model.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/model.cpp @@ -4,9 +4,9 @@ #include -#include "core/model.hpp" #include "ngraph/log.hpp" -#include "onnx_import/onnx_framework_node.hpp" +#include "onnx_framework_node.hpp" +#include "onnx_import/core/model.hpp" #include "ops_bridge.hpp" namespace ngraph @@ -32,8 +32,8 @@ namespace ngraph throw ngraph_error("Couldn't find operator set's version for domain: " + domain + "."); } - Model::Model(std::unique_ptr&& model_proto) - : m_model_proto{std::move(model_proto)} + Model::Model(std::shared_ptr model_proto) + : m_model_proto{model_proto} { // Walk through the elements of opset_import field and register operator sets // for each domain. 
An exception UnknownDomain() will raise if the domain is diff --git a/ngraph/frontend/onnx/onnx_import/src/core/model.hpp b/ngraph/frontend/onnx/onnx_import/src/core/model.hpp index 993dfb97e1a97a..82a6e0c4f10f8f 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/model.hpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/model.hpp @@ -28,7 +28,7 @@ namespace ngraph { public: Model() = delete; - explicit Model(std::unique_ptr&& model_proto); + explicit Model(std::shared_ptr model_proto); Model(const Model&) = delete; Model(Model&&) = delete; @@ -71,7 +71,7 @@ namespace ngraph void enable_opset_domain(const std::string& domain); private: - const std::unique_ptr m_model_proto; + const std::shared_ptr m_model_proto; std::unordered_map m_opset; }; diff --git a/ngraph/frontend/onnx/onnx_import/src/core/node.cpp b/ngraph/frontend/onnx/onnx_import/src/core/node.cpp index b6f2797263b384..fbdcd8d216407f 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/node.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/node.cpp @@ -53,7 +53,6 @@ namespace ngraph } const std::vector& attributes() const; - OutputVector get_ng_nodes(const Node& node) const; OutputVector get_ng_inputs() const; const std::string& domain() const; @@ -172,11 +171,6 @@ namespace ngraph return get_subgraph_from_attribute(name); } - OutputVector Node::Impl::get_ng_nodes(const Node& node) const - { - return m_graph->make_ng_nodes(node); - } - OutputVector Node::Impl::get_ng_inputs() const { OutputVector result; @@ -232,7 +226,6 @@ namespace ngraph } OutputVector Node::get_ng_inputs() const { return m_pimpl->get_ng_inputs(); } - OutputVector Node::get_ng_nodes() const { return m_pimpl->get_ng_nodes(*this); } const std::string& Node::domain() const { return m_pimpl->domain(); } const std::string& Node::op_type() const { return m_pimpl->op_type(); } const std::string& Node::get_description() const { return m_pimpl->description(); } diff --git a/ngraph/frontend/onnx/onnx_import/src/core/transform.cpp 
b/ngraph/frontend/onnx/onnx_import/src/core/transform.cpp index 119a602ef30bd4..30ce7d78b83503 100644 --- a/ngraph/frontend/onnx/onnx_import/src/core/transform.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/core/transform.cpp @@ -5,8 +5,8 @@ #include #include -#include "core/model.hpp" #include "core/transform.hpp" +#include "onnx_import/core/model.hpp" #include "ngraph/file_util.hpp" #include "ops_bridge.hpp" diff --git a/ngraph/frontend/onnx/onnx_import/src/editor.cpp b/ngraph/frontend/onnx/onnx_import/src/editor.cpp index c4568edbf9c5a8..ef33fad5fcda7a 100644 --- a/ngraph/frontend/onnx/onnx_import/src/editor.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/editor.cpp @@ -187,19 +187,20 @@ namespace /// \brief A helper class used to hold the ModelProto object as its field struct onnx_editor::ONNXModelEditor::Impl { - ONNX_NAMESPACE::ModelProto m_model_proto; + std::shared_ptr m_model_proto; EdgeMapper m_edge_mapper; bool m_is_mapper_updated = false; Impl() = delete; Impl(const std::string& model_path) - : m_model_proto{onnx_common::parse_from_file(model_path)} + : m_model_proto{std::make_shared( + onnx_common::parse_from_file(model_path))} { } - void infer_shapes() { ONNX_NAMESPACE::shape_inference::InferShapes(m_model_proto); } - void remove_shape_inference_info() { m_model_proto.mutable_graph()->clear_value_info(); } + void infer_shapes() { ONNX_NAMESPACE::shape_inference::InferShapes(*m_model_proto.get()); } + void remove_shape_inference_info() { m_model_proto->mutable_graph()->clear_value_info(); } }; onnx_editor::ONNXModelEditor::ONNXModelEditor(const std::string& model_path) @@ -222,7 +223,7 @@ void onnx_editor::ONNXModelEditor::serialize(const std::string& out_file_path) c throw ngraph_error("Could not open the file: " + out_file_path); }; - if (!m_pimpl->m_model_proto.SerializeToOstream(&out_file)) + if (!m_pimpl->m_model_proto->SerializeToOstream(&out_file)) { throw ngraph_error("Could not serialize the model to: " + out_file_path); } @@ -235,7 +236,7 @@ 
void onnx_editor::ONNXModelEditor::serialize(const std::string& out_file_path) c void onnx_editor::ONNXModelEditor::set_input_types( const std::map& input_types) { - auto* onnx_graph = m_pimpl->m_model_proto.mutable_graph(); + auto* onnx_graph = m_pimpl->m_model_proto->mutable_graph(); for (const auto& input_desc : input_types) { @@ -256,7 +257,7 @@ void onnx_editor::ONNXModelEditor::set_input_types( void onnx_editor::ONNXModelEditor::set_input_shapes( const std::map& input_shapes) { - auto* onnx_graph = m_pimpl->m_model_proto.mutable_graph(); + auto* onnx_graph = m_pimpl->m_model_proto->mutable_graph(); for (const auto& input_desc : input_shapes) { @@ -283,7 +284,7 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vectorinfer_shapes(); - SubgraphExtractor editor{*(m_pimpl->m_model_proto.mutable_graph())}; + SubgraphExtractor editor{*(m_pimpl->m_model_proto->mutable_graph())}; editor.add_new_inputs(inputs); editor.add_new_outputs(outputs); editor.extract_subgraph(outputs); @@ -294,7 +295,7 @@ void onnx_editor::ONNXModelEditor::cut_graph_fragment(const std::vector onnx_editor::ONNXModelEditor::model_inputs() const { - const auto& graph = m_pimpl->m_model_proto.graph(); + const auto& graph = m_pimpl->m_model_proto->graph(); std::vector inputs_and_initializers; inputs_and_initializers.reserve(graph.input_size() + graph.initializer_size()); @@ -314,7 +315,7 @@ std::vector onnx_editor::ONNXModelEditor::model_inputs() const std::string onnx_editor::ONNXModelEditor::model_string() const { - return m_pimpl->m_model_proto.SerializeAsString(); + return m_pimpl->m_model_proto->SerializeAsString(); } std::shared_ptr onnx_editor::ONNXModelEditor::get_function() const @@ -325,7 +326,7 @@ std::shared_ptr onnx_editor::ONNXModelEditor::get_function() const void onnx_editor::ONNXModelEditor::set_input_values( const std::map>& input_values) { - auto onnx_graph = m_pimpl->m_model_proto.mutable_graph(); + auto onnx_graph = m_pimpl->m_model_proto->mutable_graph(); for 
(const auto& input : input_values) { @@ -354,7 +355,7 @@ void onnx_editor::ONNXModelEditor::update_mapper_if_needed() const { if (!m_pimpl->m_is_mapper_updated) { - m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto.graph()); + m_pimpl->m_edge_mapper = EdgeMapper(m_pimpl->m_model_proto->graph()); } m_pimpl->m_is_mapper_updated = true; } @@ -391,3 +392,8 @@ bool onnx_editor::ONNXModelEditor::is_correct_and_unambiguous_node(const EditorN update_mapper_if_needed(); return m_pimpl->m_edge_mapper.is_correct_and_unambiguous_node(node); } + +std::shared_ptr onnx_editor::ONNXModelEditor::decode() +{ + return onnx_import::detail::decode_to_framework_nodes(m_pimpl->m_model_proto, m_model_path); +} diff --git a/ngraph/frontend/onnx/onnx_import/src/onnx.cpp b/ngraph/frontend/onnx/onnx_import/src/onnx.cpp index 09f6623611d4eb..39beac60108b06 100644 --- a/ngraph/frontend/onnx/onnx_import/src/onnx.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/onnx.cpp @@ -19,8 +19,8 @@ namespace ngraph std::shared_ptr import_onnx_model(std::istream& stream, const std::string& model_path) { - ONNX_NAMESPACE::ModelProto model_proto{onnx_common::parse_from_istream(stream)}; - + auto model_proto = std::make_shared( + onnx_common::parse_from_istream(stream)); return detail::import_onnx_model(model_proto, model_path); } @@ -58,6 +58,11 @@ namespace ngraph op_name, version, domain == "ai.onnx" ? "" : domain); } + std::shared_ptr convert_decoded_function(std::shared_ptr function) + { + return detail::convert_decoded_function(function); + } + } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.cpp b/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.cpp index bf52a1a2c0b8a0..c9432d84bfc82b 100644 --- a/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.cpp @@ -14,7 +14,7 @@ // limitations under the License. 
//***************************************************************************** -#include +#include namespace ngraph { @@ -25,7 +25,7 @@ namespace ngraph std::shared_ptr ONNXFrameworkNode::clone_with_new_inputs(const OutputVector& inputs) const { - return std::make_shared(m_node, inputs); + return std::make_shared(m_graph, m_node, inputs); } NGRAPH_RTTI_DEFINITION(ONNXSubgraphFrameworkNode, "ONNXSubgraphFrameworkNode", 1); diff --git a/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx_framework_node.hpp b/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.hpp similarity index 68% rename from ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx_framework_node.hpp rename to ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.hpp index bfa902a5ac449c..7a5269e65986f2 100644 --- a/ngraph/frontend/onnx/onnx_import/include/onnx_import/onnx_framework_node.hpp +++ b/ngraph/frontend/onnx/onnx_import/src/onnx_framework_node.hpp @@ -17,6 +17,8 @@ #pragma once #include +#include +#include #include #include #include @@ -41,19 +43,32 @@ namespace ngraph public: NGRAPH_RTTI_DECLARATION; - ONNXFrameworkNode(const onnx_import::Node& node) + ONNXFrameworkNode(std::shared_ptr graph, + const onnx_import::Node& node) : FrameworkNode(node.get_ng_inputs(), node.get_outputs_size()) , m_node(node) + , m_graph(graph) { } - ONNXFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs) + ONNXFrameworkNode(std::shared_ptr graph, + const onnx_import::Node& node, + const OutputVector& inputs) : FrameworkNode(inputs, node.get_outputs_size()) , m_node(node) + , m_graph(graph) { } - const onnx_import::Node& get_onnx_node() const { return m_node; } + OutputVector get_ng_nodes() const + { + OutputVector ng_nodes{m_graph->make_ng_nodes(m_node)}; + if (ng_nodes.size() > get_output_size()) + { + ng_nodes.resize(get_output_size()); + } + return ng_nodes; + } virtual std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; @@ -68,8 +83,11 @@ 
namespace ngraph return true; } - private: + protected: onnx_import::Node m_node; + + private: + std::shared_ptr m_graph; }; class ONNXSubgraphFrameworkNode : public ONNXFrameworkNode @@ -77,19 +95,18 @@ namespace ngraph public: NGRAPH_RTTI_DECLARATION; - ONNXSubgraphFrameworkNode(const onnx_import::Node& node, const OutputVector& inputs) - : ONNXFrameworkNode(node, inputs) + ONNXSubgraphFrameworkNode(std::shared_ptr graph, + const onnx_import::Node& node, + const OutputVector& inputs) + : ONNXFrameworkNode(graph, node, inputs) { } - void infer_inputs_from_parent() - { - get_onnx_node().get_subgraph()->infer_inputs_from_parent(); - } + void infer_inputs_from_parent() { m_node.get_subgraph()->infer_inputs_from_parent(); } std::shared_ptr get_subgraph_body() const { - auto subgraph = get_onnx_node().get_subgraph(); + auto subgraph = m_node.get_subgraph(); return std::make_shared(subgraph->get_ng_outputs(), subgraph->get_ng_parameters(), subgraph->get_name()); diff --git a/ngraph/frontend/onnx/onnx_import/src/utils/onnx_internal.cpp b/ngraph/frontend/onnx/onnx_import/src/utils/onnx_internal.cpp index 8e60171a198c91..689eef00cc3f35 100644 --- a/ngraph/frontend/onnx/onnx_import/src/utils/onnx_internal.cpp +++ b/ngraph/frontend/onnx/onnx_import/src/utils/onnx_internal.cpp @@ -5,10 +5,10 @@ #include #include "core/graph.hpp" -#include "core/model.hpp" #include "core/null_node.hpp" #include "core/transform.hpp" -#include "onnx_import/onnx_framework_node.hpp" +#include "onnx_framework_node.hpp" +#include "onnx_import/core/model.hpp" #include "onnx_import/utils/onnx_internal.hpp" namespace ngraph @@ -61,7 +61,7 @@ namespace ngraph } } - void convert_decoded_function(std::shared_ptr function) + std::shared_ptr convert_decoded_function(std::shared_ptr function) { for (const auto& node : function->get_ordered_ops()) { @@ -75,12 +75,7 @@ namespace ngraph subgraph_node->infer_inputs_from_parent(); convert_decoded_function(subgraph_node->get_subgraph_body()); } - const auto& 
onnx_node = raw_node->get_onnx_node(); - OutputVector ng_nodes{onnx_node.get_ng_nodes()}; - if (ng_nodes.size() > raw_node->get_output_size()) - { - ng_nodes.resize(raw_node->get_output_size()); - } + auto ng_nodes = raw_node->get_ng_nodes(); replace_node(raw_node, ng_nodes); } else @@ -90,22 +85,37 @@ namespace ngraph node->revalidate_and_infer_types(); } } - remove_dangling_parameters(function); - remove_dangling_results(function); + detail::remove_dangling_parameters(function); + detail::remove_dangling_results(function); + + return function; } - std::shared_ptr import_onnx_model(ONNX_NAMESPACE::ModelProto& model_proto, - const std::string& model_path) + void apply_transformations(ONNX_NAMESPACE::ModelProto& model_proto, + const std::string& model_path) { transform::expand_onnx_functions(model_proto); transform::fixup_legacy_operators(model_proto); transform::update_external_data_paths(model_proto, model_path); + } - auto p_model_proto = common::make_unique(model_proto); - auto model = common::make_unique(std::move(p_model_proto)); - Graph graph{std::move(model)}; + std::shared_ptr + import_onnx_model(std::shared_ptr model_proto, + const std::string& model_path) + { + apply_transformations(*model_proto, model_path); + Graph graph{model_proto}; return graph.convert(); } + + std::shared_ptr + decode_to_framework_nodes(std::shared_ptr model_proto, + const std::string& model_path) + { + apply_transformations(*model_proto, model_path); + auto graph = std::make_shared(model_proto); + return graph->decode(); + } } // namespace detail } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/python/tests/test_frontend/test_frontend_onnx.py b/ngraph/python/tests/test_frontend/test_frontend_onnx.py new file mode 100644 index 00000000000000..1dbe6a34ae637c --- /dev/null +++ b/ngraph/python/tests/test_frontend/test_frontend_onnx.py @@ -0,0 +1,97 @@ +# Copyright (C) 2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import os +import onnx +import 
numpy as np +from onnx.helper import make_graph, make_model, make_tensor_value_info +import pytest + +from ngraph.frontend import FrontEndManager +from tests.runtime import get_runtime + + +def create_onnx_model(): + add = onnx.helper.make_node("Add", inputs=["x", "y"], outputs=["z"]) + const_tensor = onnx.helper.make_tensor("const_tensor", onnx.TensorProto.FLOAT, (2, 2), [0.5, 1, 1.5, 2.0]) + const_node = onnx.helper.make_node("Constant", [], outputs=["const_node"], + value=const_tensor, name="const_node") + mul = onnx.helper.make_node("Mul", inputs=["z", "const_node"], outputs=["out"]) + input_tensors = [ + make_tensor_value_info("x", onnx.TensorProto.FLOAT, (2, 2)), + make_tensor_value_info("y", onnx.TensorProto.FLOAT, (2, 2)), + ] + output_tensors = [make_tensor_value_info("out", onnx.TensorProto.FLOAT, (2, 2))] + graph = make_graph([add, const_node, mul], "graph", input_tensors, output_tensors) + return make_model(graph, producer_name="ngraph ONNX Importer") + + +def run_function(function, *inputs, expected): + runtime = get_runtime() + computation = runtime.computation(function) + actual = computation(*inputs) + assert len(actual) == len(expected) + for i in range(len(actual)): + np.testing.assert_allclose(expected[i], actual[i], rtol=1e-3, atol=1e-6) + + +fem = FrontEndManager() +onnx_model_filename = "model.onnx" + + +def setup_module(): + onnx.save_model(create_onnx_model(), onnx_model_filename) + + +def teardown_module(): + os.remove(onnx_model_filename) + + +def skip_if_onnx_frontend_is_disabled(): + front_ends = fem.get_available_front_ends() + if "onnx" not in front_ends: + pytest.skip() + + +def test_convert(): + skip_if_onnx_frontend_is_disabled() + + fe = fem.load_by_framework(framework="onnx") + assert fe + + model = fe.load(onnx_model_filename) + assert model + + function = fe.convert(model) + assert function + + a = np.array([[1, 2], [3, 4]], dtype=np.float32) + b = np.array([[2, 3], [4, 5]], dtype=np.float32) + expected = np.array([[1.5, 5], 
[10.5, 18]], dtype=np.float32) + run_function(function, a, b, expected=[expected]) + + +def test_decode_and_convert(): + skip_if_onnx_frontend_is_disabled() + + fe = fem.load_by_framework(framework="onnx") + assert fe + + model = fe.load(onnx_model_filename) + assert model + + decoded_function = fe.decode(model) + assert decoded_function + for op in decoded_function.get_ordered_ops(): + assert op.get_type_name() in ["Parameter", "Constant", "ONNXFrameworkNode", + "ONNXSubgraphFrameworkNode", "Result"] + + function = fe.convert(decoded_function) + assert function + for op in function.get_ordered_ops(): + assert op.get_type_name() not in ["ONNXFrameworkNode", "ONNXSubgraphFrameworkNode"] + + a = np.array([[1, 2], [3, 4]], dtype=np.float32) + b = np.array([[2, 3], [4, 5]], dtype=np.float32) + expected = np.array([[1.5, 5], [10.5, 18]], dtype=np.float32) + run_function(function, a, b, expected=[expected]) diff --git a/ngraph/python/tests/test_ngraph/test_frontendmanager.py b/ngraph/python/tests/test_frontend/test_frontendmanager.py similarity index 100% rename from ngraph/python/tests/test_ngraph/test_frontendmanager.py rename to ngraph/python/tests/test_frontend/test_frontendmanager.py