From 11ffccdfd601371d25d48097391fd23c6015e29c Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 30 Sep 2022 14:18:39 +0200
Subject: [PATCH 1/2] Support shufflenet and add layer test for relu

---
 src/core/src/op/scatter_elements_update.cpp   |   3 +-
 .../openvino/frontend/pytorch/frontend.hpp    |  40 +++++-
 .../frontend/pytorch/node_context.hpp         |   2 +-
 src/frontends/pytorch/src/frontend.cpp        |  58 ++++++--
 src/frontends/pytorch/src/node_context.cpp    |   4 +-
 src/frontends/pytorch/src/op/transpose.cpp    |  64 +++++++++
 src/frontends/pytorch/src/op_table.cpp        |  71 +++-------
 .../pytorch/src/pt_framework_node.hpp         |   4 +-
 .../src/transforms/aten_cat_replacer.cpp      |  77 +++++++++++
 .../src/transforms/aten_cat_replacer.hpp      |  26 ++++
 .../transforms/prim_list_unpack_replacer.cpp  |  85 ++++++++++++
 .../transforms/prim_list_unpack_replacer.hpp  |  25 ++++
 src/frontends/pytorch/src/utils.cpp           | 125 +++---------------
 src/frontends/pytorch/src/utils.hpp           |   8 ++
 tests/layer_tests/pytorch_tests/conftest.py   |  12 ++
 .../pytorch_tests/pytorch_layer_test_class.py | 111 ++++++++++++++++
 tests/layer_tests/pytorch_tests/test_relu.py  |  31 +++++
 17 files changed, 563 insertions(+), 183 deletions(-)
 create mode 100644 src/frontends/pytorch/src/op/transpose.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/aten_cat_replacer.hpp
 create mode 100644 src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
 create mode 100644 src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.hpp
 create mode 100644 tests/layer_tests/pytorch_tests/conftest.py
 create mode 100644 tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
 create mode 100644 tests/layer_tests/pytorch_tests/test_relu.py

diff --git a/src/core/src/op/scatter_elements_update.cpp b/src/core/src/op/scatter_elements_update.cpp
index 1a9de336f64e0e..dc531f25ab0f25 100644
--- a/src/core/src/op/scatter_elements_update.cpp
+++ b/src/core/src/op/scatter_elements_update.cpp
@@ -44,8 +44,9 @@ void op::v3::ScatterElementsUpdate::validate_and_infer_types() {
     NODE_VALIDATION_CHECK(this,
                           axis_et.is_integral(),
                           "Axis element type must be integral_number, but is: ",
                           axis_et);
+    element::Type merged_type;
     NODE_VALIDATION_CHECK(this,
-                          data_et == updates_et,
+                          element::Type::merge(merged_type, data_et, updates_et),
                           "Data type and updates type are required to be the same. ",
                           "Got: ",
                           data_et,
diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/frontend.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/frontend.hpp
index 64599c90855b52..bc1893b4e300ae 100644
--- a/src/frontends/pytorch/include/openvino/frontend/pytorch/frontend.hpp
+++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/frontend.hpp
@@ -19,19 +19,51 @@ class PYTORCH_API FrontEnd : public ov::frontend::FrontEnd {
 public:
     using Ptr = std::shared_ptr<FrontEnd>;
 
+    /// \brief Completely convert and normalize entire Model, throws if it is not possible
+    /// \param model Input model
+    /// \return fully converted OV Model
     std::shared_ptr<Model> convert(const ov::frontend::InputModel::Ptr& model) const override;
 
+    /// \brief Completely convert the remaining, not converted part of a Model.
+    /// \param partiallyConverted partially converted OV Model
+    void convert(const std::shared_ptr<Model>& partiallyConverted) const override;
+
+    /// \brief Convert only those parts of the model that can be converted leaving others
+    /// as-is. Converted parts are not normalized by additional transformations; normalize
+    /// function or another form of convert function should be called to finalize the
+    /// conversion process.
+    /// \param model Input model
+    /// \return partially converted OV Model
+    std::shared_ptr<Model> convert_partially(const InputModel::Ptr& model) const override;
+
+    /// \brief Convert operations with one-to-one mapping with decoding nodes.
+    /// Each decoding node is an OV node representing a single FW operation node with
+    /// all attributes represented in FW-independent way.
+    /// \param model Input model
+    /// \return OV Model after decoding
+    std::shared_ptr<Model> decode(const InputModel::Ptr& model) const override;
+
+    /// \brief Runs normalization passes on Model that was loaded with partial conversion
+    /// \param Model partially converted OV Model
+    void normalize(const std::shared_ptr<ov::Model>& model) const override;
+
+    /// \brief Gets name of this FrontEnd. Can be used by clients
+    /// if frontend is selected automatically by FrontEndManager::load_by_model
+    /// \return PyTorch frontend name.
     std::string get_name() const override {
         return "pytorch";
     }
 
+    /// \brief Register base extension in the FrontEnd
+    /// \param extension base extension
+    void add_extension(const std::shared_ptr<ov::Extension>& extension) override;
+
 protected:
     bool supported_impl(const std::vector<ov::Any>& variants) const override;
 
     ov::frontend::InputModel::Ptr load_impl(const std::vector<ov::Any>& variants) const override;
 };
-}  // namespace pytorch
-}  // namespace frontend
-}  // namespace ov
-
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp
index 91d052d45d3a8b..fdbece66ce5482 100644
--- a/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp
+++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/node_context.hpp
@@ -145,7 +145,7 @@ class NodeContext : public frontend::NodeContext {
             parameter->get_output_tensor(0).add_names({std::to_string(index)});
             (*m_tensor_map)[index] = parameter;
             m_external_parameters->push_back(parameter);
-            std::cout << "Nested case, created: " << parameter << std::endl;
+            //std::cout << "Nested case, created: " << parameter << std::endl;
             return parameter;
         }
     }
diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp
index f22b3a9305ae60..2ac6efd4404d2c 100644
--- a/src/frontends/pytorch/src/frontend.cpp
+++ b/src/frontends/pytorch/src/frontend.cpp
@@ -9,10 +9,12 @@
 #include "exception.hpp"
 #include "input_model.hpp"
 #include "transforms.hpp"
-#include "openvino/frontend/pytorch/node_context.hpp"
 #include "op_table.hpp"
 #include "openvino/frontend/exception.hpp"
+#include "openvino/frontend/pytorch/node_context.hpp"
 #include "pt_framework_node.hpp"
+#include "transforms/aten_cat_replacer.hpp"
+#include "transforms/prim_list_unpack_replacer.hpp"
 #include "utils.hpp"
 
 namespace ov {
 namespace frontend {
 namespace pytorch {
 
 std::shared_ptr<Model> FrontEnd::convert(const ov::frontend::InputModel::Ptr& model) const {
+    auto converted_model = convert_partially(model);
+    normalize(converted_model);
+    std::set<std::string> unconverted_ops_types;
+    for (const auto& node : converted_model->get_ordered_ops()) {
+        if (const auto& fw_node = ov::as_type_ptr<PtFrameworkNode>(node)) {
+            auto op_type = fw_node->get_decoder()->get_op_type();
+            unconverted_ops_types.insert(op_type);
+        }
+    }
+    std::stringstream ops_str;
+    for (auto&& op_type : unconverted_ops_types) {
+        ops_str << op_type << "\n";
+    }
+    FRONT_END_OP_CONVERSION_CHECK(unconverted_ops_types.size() == 0,
+                                  "Model wasn't fully converted. Unconverted operation types:\n" + ops_str.str());
+    return converted_model;
+}
+
+void FrontEnd::convert(const std::shared_ptr<Model>& partiallyConverted) const {
+    FRONT_END_NOT_IMPLEMENTED(convert);
+}
+
+std::shared_ptr<Model> FrontEnd::convert_partially(const ov::frontend::InputModel::Ptr& model) const {
     try {
-        // std::cerr << "[ HERE ]\n";
         auto pytorch_model = std::dynamic_pointer_cast<pytorch::InputModel>(model);
-        // TODO: Remove this super-hack, tensor_map should be local for each conversion activity, see more info where
-        // tensor_map is defined now
         auto model = convert_pytorch_model(pytorch_model->m_model);
 
         // TODO: Propose better solution for the next code block
@@ -36,8 +58,8 @@
             auto self = model->get_parameters()[0];
             if (self->output(0).get_target_inputs().empty()) {
                 // There are no consumers: safe to remove
-                std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
-                             "used and treated as `self`\n";
+                // std::cout << "[ WARNING ] Removing parameter[0] in converted Pytorch model, because it is never "
+                //              "used and treated as `self`\n";
                 model->remove_parameter(self);
             } else {
                 std::cout << "[ WARNING ] Couldn't remove parameter[0] in converted Pytorch model\n";
@@ -46,25 +68,39 @@
         apply_pytorch_conversion_transforms(model);
         return model;
     } catch (const std::runtime_error& e) {
-        std::cerr << "[ ERROR ] Error while converting pytorch model: " << e.what() << "\n";
+        std::cerr << "[ ERROR ] Unexpected error while converting pytorch model: " << e.what() << "\n";
        std::cerr << "Rethrowing. Misleading error message from pybind11 may come next. TODO.";
TODO."; throw; } } +std::shared_ptr FrontEnd::decode(const InputModel::Ptr& model) const { + FRONT_END_NOT_IMPLEMENTED(decode); +} + +void FrontEnd::normalize(const std::shared_ptr& model) const { + ov::pass::Manager manager; + + manager.register_pass(); + manager.register_pass(); + + manager.run_passes(model); +} + +void FrontEnd::add_extension(const std::shared_ptr& extension) { + FRONT_END_NOT_IMPLEMENTED(add_extension); +} + bool FrontEnd::supported_impl(const std::vector& variants) const { - // std::cout << "[ ----- DEBUG ------ ] supported_impl with " << variants.size() << " arguments\n"; return false; } ov::frontend::InputModel::Ptr FrontEnd::load_impl(const std::vector& variants) const { - // std::cout << "[ ----- DEBUG ----- ] load_impl with " << variants.size() << " parameters\n"; if (variants.size() != 1) { throw std::runtime_error("Pytorch frontend supports exactly one parameter in model representation, got " + - std::to_string(variants.size()) + "instead."); + std::to_string(variants.size()) + " instead."); } auto decoder = variants[0].as>(); - // std::cout << "Recognized decoder: " << decoder << "\n"; return std::make_shared(decoder); } diff --git a/src/frontends/pytorch/src/node_context.cpp b/src/frontends/pytorch/src/node_context.cpp index bc49bbd7b82b59..c8edd96ed836ca 100644 --- a/src/frontends/pytorch/src/node_context.cpp +++ b/src/frontends/pytorch/src/node_context.cpp @@ -49,8 +49,8 @@ std::shared_ptr NodeContext::convert_subgraph(size_t index) { auto parameter = model->get_parameters()[i]; if (parameter->output(0).get_target_inputs().empty()) { // There is no consumers: safe to remove - std::cout << "[ WARNING ] Removing parameter " << parameter - << " in converted Pytorch model, because it is never used" << std::endl; + //std::cout << "[ WARNING ] Removing parameter " << parameter + // << " in converted Pytorch model, because it is never used" << std::endl; model->remove_parameter(parameter); } } diff --git a/src/frontends/pytorch/src/op/transpose.cpp b/src/frontends/pytorch/src/op/transpose.cpp new file mode 100644 index 00000000000000..4c8b2a935ea553 --- /dev/null +++ b/src/frontends/pytorch/src/op/transpose.cpp @@ -0,0 +1,64 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "openvino/frontend/pytorch/node_context.hpp" +#include "openvino/opsets/opset8.hpp" +#include "utils.hpp" + +namespace ov { +namespace frontend { +namespace pytorch { +namespace op { + +OutputVector translate_transpose(NodeContext& context) { + auto dim0 = context.const_input(1); + auto dim1 = context.const_input(2); + auto shape = std::make_shared(context.get_input(0), element::i32); + auto rank_ = std::make_shared(shape, element::i32); + auto rank = std::make_shared(rank_); + // Use opset::If for dim normalization + auto dim0_node = context.get_input(1); + auto dim1_node = context.get_input(2); + if (dim0 < 0) { + dim0_node = std::make_shared(rank, dim0_node); + } + if (dim1 < 0) { + dim1_node = std::make_shared(rank, dim1_node); + } + auto start = opset8::Constant::create(element::i32, {}, {0}); + auto step = opset8::Constant::create(element::i32, {}, {1}); + auto range = std::make_shared(start, rank, step, element::i32); + + auto axis_0 = opset8::Constant::create(element::i64, Shape{}, {0}); + dim0_node = std::make_shared(dim0_node, axis_0); + dim1_node = std::make_shared(dim1_node, axis_0); + auto indices = std::make_shared(OutputVector{dim0_node, dim1_node}, 0); + auto updates = std::make_shared(OutputVector{dim1_node, 
diff --git a/src/frontends/pytorch/src/op/transpose.cpp b/src/frontends/pytorch/src/op/transpose.cpp
new file mode 100644
index 00000000000000..4c8b2a935ea553
--- /dev/null
+++ b/src/frontends/pytorch/src/op/transpose.cpp
@@ -0,0 +1,64 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/opsets/opset8.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+OutputVector translate_transpose(NodeContext& context) {
+    auto dim0 = context.const_input<int64_t>(1);
+    auto dim1 = context.const_input<int64_t>(2);
+    auto shape = std::make_shared<opset8::ShapeOf>(context.get_input(0), element::i32);
+    auto rank_ = std::make_shared<opset8::ShapeOf>(shape, element::i32);
+    auto rank = std::make_shared<opset8::Squeeze>(rank_);
+    // Use opset::If for dim normalization
+    auto dim0_node = context.get_input(1);
+    auto dim1_node = context.get_input(2);
+    if (dim0 < 0) {
+        dim0_node = std::make_shared<opset8::Add>(rank, dim0_node);
+    }
+    if (dim1 < 0) {
+        dim1_node = std::make_shared<opset8::Add>(rank, dim1_node);
+    }
+    auto start = opset8::Constant::create(element::i32, {}, {0});
+    auto step = opset8::Constant::create(element::i32, {}, {1});
+    auto range = std::make_shared<opset8::Range>(start, rank, step, element::i32);
+
+    auto axis_0 = opset8::Constant::create(element::i64, Shape{}, {0});
+    dim0_node = std::make_shared<opset8::Unsqueeze>(dim0_node, axis_0);
+    dim1_node = std::make_shared<opset8::Unsqueeze>(dim1_node, axis_0);
+    auto indices = std::make_shared<opset8::Concat>(OutputVector{dim0_node, dim1_node}, 0);
+    auto updates = std::make_shared<opset8::Concat>(OutputVector{dim1_node, dim0_node}, 0);
+    auto scatter = std::make_shared<opset8::ScatterElementsUpdate>(range, indices, updates, axis_0);
+
+    /*auto data_pshape = context.get_input(0).get_partial_shape();
+    auto rank = data_pshape.rank();
+    OV_FRONTEND_REQUIRE(rank.is_static());
+    auto _rank = rank.get_length();
+    if (dim0 < 0) {
+        dim0 = _rank + dim0;
+    }
+    if (dim1 < 0) {
+        dim1 = _rank + dim1;
+    }
+    OV_FRONTEND_REQUIRE(dim0 > 0 && dim1 > 0);
+    OV_FRONTEND_REQUIRE(dim0 < _rank && dim1 < _rank);
+    std::vector<int64_t> order(_rank, 0);
+    std::iota(order.begin(), order.end(), 0);
+    std::swap(order[dim0], order[dim1]);
+    auto order_const = context.mark_node(opset8::Constant::create(element::i64, {order.size()}, order));*/
+    return {context.mark_node(std::make_shared<opset8::Transpose>(context.get_input(0), scatter))};
+};
+
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 1c08a6cdef05f2..66460b0e6a9193 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -18,6 +18,7 @@ namespace op {
 OP_CONVERTER(translate_if);
 OP_CONVERTER(translate_loop);
 OP_CONVERTER(translate_slice);
+OP_CONVERTER(translate_transpose);
 
 OutputVector relu(NodeContext& context) {
     return {context.mark_node(std::make_shared<opset8::Relu>(context.get_input(0)))};
@@ -207,7 +208,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
          ))};
      }},
 
-    {"aten::layer_norm",
+    /*{"aten::layer_norm",
      [](NodeContext& context) -> OutputVector {
          auto normalized_shape = context.const_input<std::vector<int64_t>>(1);
          auto in_pshape_last_dim = *context.get_input(0).get_partial_shape().rbegin();
          OV_FRONTEND_REQUIRE(normalized_shape.size() == 1 && in_pshape_last_dim.is_static() &&
                              static_cast<uint64_t>(in_pshape_last_dim.get_length()) == normalized_shape.back());
          auto eps = context.const_input<float>(4);
          auto axes = context.mark_node(
              opset8::Constant::create(element::i64, Shape{1}, {-1}));  // TODO: support any dimension
          auto mvn = context.mark_node(
              std::make_shared<opset8::MVN>(context.get_input(0), axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
          std::shared_ptr<ov::Node> out_node = std::dynamic_pointer_cast<ov::Node>(mvn);
          if (!context.input_is_none(2)) {
              auto mul = std::make_shared<opset8::Multiply>(out_node, context.get_input(2));
              out_node = std::dynamic_pointer_cast<ov::Node>(mul);
          }
          if (!context.input_is_none(3)) {
              auto add = std::make_shared<opset8::Add>(out_node, context.get_input(3));
              out_node = std::dynamic_pointer_cast<ov::Node>(add);
          }
          return {context.mark_node(out_node)};
-     }},
+     }},*/
 
     {"aten::add", op::add},
     {"aten::add_", inplace_op<op::add>},
@@ -252,6 +253,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
              std::make_shared<opset8::Divide>(context.get_input(0), context.get_input(1), pythondiv))};
      }},
 
+    {"aten::floordiv",
+     [](NodeContext& context) -> OutputVector {
+         return {
+             context.mark_node(std::make_shared<opset8::Divide>(context.get_input(0), context.get_input(1), true))};
+     }},
+
     {"aten::tanh",
      [](NodeContext& context) -> OutputVector {
         return {context.mark_node(std::make_shared<opset8::Tanh>(context.get_input(0)))};
@@ -325,32 +332,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
              context.mark_node(std::make_shared(context.get_input(0), static_cast(axis)))};
      }},
 
-    {"aten::cat",
-     [](NodeContext& context) -> OutputVector {
-         // aten::cat needs a special handling since it takes a Tensor[] as
-         // input. We set the inputs of ListConstruct as the inputs of cat.
-         //
-         // Pytorch IR:                              LLGA sees:
-         //     %a    %b     %c       %dim            %a    %b    %c
-         //        \   |    /           |                \   |    /
-         //   prim::ListConstruct   prim::Constant     llga::Concat[axis=%dim]
-         //                    \      /
-         //                    aten::cat
-         auto listConstruct = context.get_input(0).get_node();
-         auto listConstruct_fw_node = dynamic_cast<PtFrameworkNode*>(listConstruct);
-         OV_FRONTEND_REQUIRE(listConstruct_fw_node);
-         OV_FRONTEND_REQUIRE(listConstruct_fw_node->get_op_type() == "prim::ListConstruct");
-         auto axis = context.const_input<int64_t>(1);
-         OutputVector inputs;
-         for (auto& input : listConstruct->inputs()) {
-             inputs.push_back(input.get_source_output());
-         }
-         auto result = context.mark_node(std::make_shared<opset8::Concat>(inputs, axis));
-         // TODO: do we really need to do that?
-         // auto list_set = listConstruct_fw_node->get_rt_info()["pt_node"].as<std::set<const Node*>>();
-         // result->get_rt_info()["pt_node"].as<std::set<const Node*>>().insert(list_set.begin(), list_set.end());
-         return {result};
-     }},
+    //{"aten::cat", done as transformation},
 
     {"aten::matmul",
      [](NodeContext& context) -> OutputVector {
@@ -494,32 +476,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
              std::make_shared(context.get_input(0), context.get_input(1), axis_0))};
      }},
 
-    {"aten::transpose",
-     [](NodeContext& context) -> OutputVector {
-         auto dim0 = context.const_input<int64_t>(1);
-         auto dim1 = context.const_input<int64_t>(2);
-         auto data_pshape = context.get_input(0).get_partial_shape();
-         auto rank = data_pshape.rank();
-         OV_FRONTEND_REQUIRE(rank.is_static());
-         auto _rank = rank.get_length();
-         if (dim0 < 0) {
-             dim0 = _rank + dim0;
-         }
-         if (dim1 < 0) {
-             dim1 = _rank + dim1;
-         }
-         OV_FRONTEND_REQUIRE(dim0 > 0 && dim1 > 0);
-         OV_FRONTEND_REQUIRE(dim0 < _rank && dim1 < _rank);
-         std::vector<int64_t> order(_rank, 0);
-         std::iota(order.begin(), order.end(), 0);
-         std::swap(order[dim0], order[dim1]);
-         auto order_const = context.mark_node(opset8::Constant::create(element::i64, {order.size()}, order));
-         return {context.mark_node(std::make_shared<opset8::Transpose>(context.get_input(0), order_const))};
-     }},
+    {"aten::transpose", op::translate_transpose},
 
     {"aten::size",
      [](NodeContext& context) -> OutputVector {
-         auto shape = context.mark_node(std::make_shared<opset8::ShapeOf>(context.get_input(0)));
+         auto shape = context.mark_node(std::make_shared<opset8::ShapeOf>(context.get_input(0), element::i32));
          if (context.input_is_none(1)) {
              return shape->outputs();
          } else {
@@ -533,8 +494,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
          auto shape_node = context.get_input(1).get_node();
          auto shape_node_fw_node = dynamic_cast<PtFrameworkNode*>(shape_node);
          std::shared_ptr<ov::Node> reshape;
+         // TODO: move this to transform stage
          if (shape_node_fw_node && shape_node_fw_node->get_decoder()->get_op_type() == "prim::ListConstruct") {
-             // TODO: maybe use pt shape instead of whole shape subgraph, because it may be more efficient
              OutputVector inputs;
              auto axis_0 = context.mark_node(opset8::Constant::create(element::i64, Shape{}, {0}));
              for (auto& input : shape_node->inputs()) {
@@ -546,8 +507,10 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
              }
              auto concat = context.mark_node(std::make_shared<opset8::Concat>(inputs, 0));
              reshape = context.mark_node(std::make_shared<opset8::Reshape>(context.get_input(0), concat, false));
-             auto list_set = shape_node_fw_node->get_rt_info()["pt_node"].as<std::set<const Node*>>();
-             reshape->get_rt_info()["pt_node"].as<std::set<const Node*>>().insert(list_set.begin(), list_set.end());
+             // TODO: fix rt_info
+             // auto list_set = shape_node_fw_node->get_rt_info()["pt_node"].as<std::set<const Node*>>();
+             // reshape->get_rt_info()["pt_node"].as<std::set<const Node*>>().insert(list_set.begin(),
+             // list_set.end());
          } else {
              reshape = context.mark_node(
                  std::make_shared<opset8::Reshape>(context.get_input(0), context.get_input(1), false));
diff --git a/src/frontends/pytorch/src/pt_framework_node.hpp b/src/frontends/pytorch/src/pt_framework_node.hpp
index 0fe52fe21a412b..bdbc937fb60492 100644
--- a/src/frontends/pytorch/src/pt_framework_node.hpp
+++ b/src/frontends/pytorch/src/pt_framework_node.hpp
@@ -42,8 +42,8 @@ class PtFrameworkNode : public ov::op::util::FrameworkNode {
                 std::cerr << "[ ERROR ] Cannot retrieve type\n" << e.what() << std::endl;
             }
         } else {
-            std::cerr << "[ WARNING ] Cannot retrieve type for output not existent in pt node: "
-                      << m_decoder->get_op_type() << " with 0 input: " << m_decoder->input(0) << std::endl;
+            //std::cerr << "[ WARNING ] Cannot retrieve type for output not existent in pt node: "
+            //          << m_decoder->get_op_type() << " with 0 input: " << m_decoder->input(0) << std::endl;
         }
         // Let's see what type we have
         // std::cout << "Can be represented as element::Type: " << type.is<element::Type>() << std::endl;
diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
new file mode 100644
index 00000000000000..92da806cdffd2c
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.cpp
@@ -0,0 +1,77 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "aten_cat_replacer.hpp"
+
+#include
+#include
+
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+// aten::cat needs a special handling since it takes a Tensor[] as input. We set the inputs of ListConstruct as the
+// inputs of cat.
+//
+// Pytorch IR:                              OV model:
+//     %a    %b     %c       %dim            %a    %b    %c
+//        \   |    /           |                \   |    /
+//   prim::ListConstruct   prim::Constant     Concat[axis=%dim]
+//                    \      /
+//                    aten::cat
+AtenCatToConcat::AtenCatToConcat() {
+    auto aten_cat = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
+
+    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
+        auto cat = cast_fw_node(m.get_match_root(), "aten::cat");
+        if (!cat)
+            return false;
+
+        auto axis_node = cat->input(1).get_source_output().get_node_shared_ptr();
+        auto axis_const = std::dynamic_pointer_cast<opset8::Constant>(axis_node);
+        if (!axis_const)
+            return false;
+        auto axis = axis_const->cast_vector<int64_t>();
+        if (axis.size() != 1)
+            return false;
+
+        OutputVector tmp_inputs;
+        NodeVector rt_copy_from{cat};
+        std::shared_ptr<Node> input_node = cat->input(0).get_source_output().get_node_shared_ptr();
+        while (const auto& input_fw_node = cast_fw_node(input_node, "aten::append")) {
+            rt_copy_from.push_back(input_fw_node);
+            tmp_inputs.push_back(input_fw_node->input(1).get_source_output());
+            input_node = input_fw_node->input(0).get_source_output().get_node_shared_ptr();
+        }
+        auto list_construct = cast_fw_node(input_node, "prim::ListConstruct");
+        if (!list_construct)
+            return false;
+        rt_copy_from.push_back(list_construct);
+        OutputVector inputs;
+        for (auto& input : list_construct->inputs()) {
+            inputs.push_back(input.get_source_output());
+        }
+        inputs.insert(inputs.end(), tmp_inputs.rbegin(), tmp_inputs.rend());
+        auto result = std::make_shared<opset8::Concat>(inputs, axis[0]);
+        copy_runtime_info(rt_copy_from, result);
+        replace_node(cat, result);
+
+        return true;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(aten_cat, "ov::frontend::pytorch::pass::AtenCatToConcat");
+    this->register_matcher(m, callback);
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
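For reference, not part of the patch: the TorchScript pattern the pass above matches comes from ordinary torch.cat calls on a Python list, optionally with appends between list construction and concatenation. A minimal reproducer (exact graph layout depends on the torch version):

    import torch

    def f(a, b, c):
        xs = [a, b]                   # becomes prim::ListConstruct
        xs.append(c)                  # becomes aten::append
        return torch.cat(xs, dim=1)   # becomes aten::cat

    print(torch.jit.script(f).graph)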
diff --git a/src/frontends/pytorch/src/transforms/aten_cat_replacer.hpp b/src/frontends/pytorch/src/transforms/aten_cat_replacer.hpp
new file mode 100644
index 00000000000000..2f6ef0646d67ce
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/aten_cat_replacer.hpp
@@ -0,0 +1,26 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+// This transformation replaces pattern prim::ListConstruct->aten::append{none or many}->aten::cat
+class PYTORCH_API AtenCatToConcat : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenCatToConcat");
+    AtenCatToConcat();
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
new file mode 100644
index 00000000000000..ce2e04b5935d4f
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -0,0 +1,85 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "prim_list_unpack_replacer.hpp"
+
+#include
+#include
+
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/op/util/framework_node.hpp"
+#include "openvino/pass/pattern/matcher.hpp"
+#include "openvino/pass/pattern/op/wrap_type.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+PrimListUnpackReplacer::PrimListUnpackReplacer() {
+    auto list_unpack = ov::pass::pattern::wrap_type<ov::op::util::FrameworkNode>();
+
+    ov::matcher_pass_callback callback = [](ov::pass::pattern::Matcher& m) {
+        auto list_unpack = cast_fw_node(m.get_match_root(), "prim::ListUnpack");
+        if (!list_unpack)
+            return false;
+
+        auto input_node = list_unpack->input(0).get_source_output().get_node_shared_ptr();
+        if (auto split_with_sizes = cast_fw_node(input_node, "aten::split_with_sizes")) {
+            auto split = std::make_shared<opset8::VariadicSplit>(split_with_sizes->get_input_source_output(0),
+                                                                 split_with_sizes->get_input_source_output(2),
+                                                                 split_with_sizes->get_input_source_output(1));
+
+            copy_runtime_info({list_unpack, input_node}, split);
+            replace_node(list_unpack, split);
+
+            return true;
+        }
+
+        if (auto chunk = cast_fw_node(input_node, "aten::chunk")) {
+            // Using number of ListUnpack outputs instead of 1st input to chunk.
+            // TODO: confirm it works for all cases
+            auto split = std::make_shared<opset8::Split>(chunk->get_input_source_output(0),
+                                                         chunk->get_input_source_output(2),
+                                                         list_unpack->get_output_size());
+
+            copy_runtime_info({list_unpack, input_node}, split);
+            replace_node(list_unpack, split);
+
+            return true;
+        }
+
+        if (auto shape_of = std::dynamic_pointer_cast<opset8::ShapeOf>(input_node)) {
+            // case aten::size as input
+            // Number of ListUnpack outputs should be equal to rank of input shape.
+            auto axis_0 = opset8::Constant::create(element::i64, Shape{}, {0});
+            auto split = std::make_shared<opset8::Split>(shape_of, axis_0, list_unpack->get_output_size());
+
+            NodeVector to_copy_rt{axis_0, split};
+            OutputVector res;
+            for (auto output : split->outputs()) {
+                auto squeeze = std::make_shared<opset8::Squeeze>(output, axis_0);
+                to_copy_rt.push_back(squeeze);
+                res.push_back(squeeze);
+            }
+
+            copy_runtime_info({list_unpack, input_node}, to_copy_rt);
+            replace_node(list_unpack, res);
+
+            return true;
+        }
+
+        return false;
+    };
+
+    auto m = std::make_shared<ov::pass::pattern::Matcher>(list_unpack,
+                                                          "ov::frontend::pytorch::pass::PrimListUnpackReplacer");
+    this->register_matcher(m, callback);
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
diff --git a/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.hpp b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.hpp
new file mode 100644
index 00000000000000..7302c006ca5697
--- /dev/null
+++ b/src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.hpp
@@ -0,0 +1,25 @@
+// Copyright (C) 2018-2022 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "openvino/frontend/pytorch/visibility.hpp"
+#include "openvino/pass/graph_rewrite.hpp"
+#include "openvino/pass/pass.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace pass {
+
+class PYTORCH_API PrimListUnpackReplacer : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("ov::frontend::pytorch::pass::PrimListUnpackReplacer");
+    PrimListUnpackReplacer();
+};
+
+}  // namespace pass
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
\ No newline at end of file
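A reference illustration, not part of the patch: prim::ListUnpack appears when scripted code unpacks the list results of ops such as aten::chunk or a tensor's shape, which is exactly what shufflenet's channel_shuffle does. A sketch (the exact graph varies with the torch version):

    import torch

    @torch.jit.script
    def channel_shuffle_like(x):
        n, c, h, w = x.size()            # aten::size feeding prim::ListUnpack
        a, b = torch.chunk(x, 2, dim=1)  # aten::chunk feeding prim::ListUnpack
        return a, b, n * c, h * w

    print(channel_shuffle_like.graph)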
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index 07f3b4f50e4304..16e2a9d9be6231 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -10,8 +10,6 @@ namespace ov {
 namespace frontend {
 namespace pytorch {
 
-int LEVEL = 0;
-int NUMBER = 0;
 int COUNTER = 0;
 
 Output<Node> make_optional_bias(Output<Node> base_op,
@@ -89,40 +87,14 @@ Output<Node> reshape_kernel_for_group(const NodeContext& context,
 }
 
 OutputVector convert_node(NodeContext* context) {
-    // std::cout << "[ ---- DEBUG ---- ] convert_node\n";
-
-    // std::cerr << "---\nAttempting to convert " << node->kind().toQualString() << "\n";
-    // node->dump();
-
-    // std::cerr << "[ DEBUG ] Attempting to convert " << context.get_op_type() << "\n";
-
     try {
         auto CONVERTERS_MAP = get_supported_ops();
        auto it = CONVERTERS_MAP.find(context->get_op_type());
        if (it != CONVERTERS_MAP.end()) {
-            // std::cout << "FOUND converter for " << context.get_op_type() << "\n";
             return it->second(*context);
-        } else {
+        } /*else {
             const std::set<std::string> known_skips{"prim::RaiseException",
-                                                    "aten::warn",
-                                                    // remove all that above
-                                                    "prim::TupleConstruct",
-                                                    "prim::ListConstruct",
-                                                    "aten::format",
-                                                    "aten::append",
-                                                    "aten::update",
-                                                    "aten::dict",
-                                                    "aten::list",
-                                                    "aten::_set_item",
-                                                    "aten::__getitem__",
-                                                    "aten::__isnot__",
-                                                    "aten::__contains__",
-                                                    "prim::unchecked_cast",
-                                                    "prim::Uninitialized",
-                                                    "prim::SetAttr",
-                                                    "prim::GetAttr",
-                                                    "prim::ListUnpack",
-                                                    "aten::__not__"};
+                                                    "aten::warn"};
             if (!known_skips.count(context->get_op_type())) {
                 std::cout << "DIDN'T FIND converter for " << context->get_op_type() << " with inputs:";
                 if (context->inputs().size() == 0) {
                     std::cout << " None";
                 }
                 for (auto input : context->inputs()) {
                     std::cout << " " << input;
                 }
                 std::cout << " with schema: " << context->get_schema() << std::endl;
             }
-        }
+        }*/
     }
-    // catch(pybind11::error_already_set& e) {
-    //     std::cout << "Python exception: " << e << "\n";
-    // }
     catch (std::runtime_error& e) {
         std::cout << "Exception happened during conversion of op: " << context->get_op_type()
                   << " with schema: " << context->get_schema() << ": " << e.what() << '\n';
-        // throw;
     } catch (...) {
         std::cout << "Some exception happened during conversion of node of type: " << context->get_op_type()
                   << std::endl;
-        // throw;
     }
-    // if (node->kind() != prim::ListConstruct) {
-    //     std::cout << "Making unsupported " << node->kind().toQualString() << std::endl;
-    //     node->dump();
-    // }
-
     // Create PtFrameworkNode for everything that wasn't able to be converted normally
     // Pay attention to subgraphs that may appear in the node
-    // std::cerr << "[ DEBUG ] Before PtFramewokNode creation\n";
 
     auto schema = context->get_schema();
     if (schema.find('!') != std::string::npos) {
@@ -167,7 +128,7 @@ OutputVector convert_node(NodeContext* context) {
         auto outputs = fw_node->outputs();
         // update writes to input 0, so we need to replace this input with output from update
         context->mutate_input(0, outputs.back());
-        std::cerr << "[ WARNING ] Created node with mutated 0 input. Schema: " << schema << std::endl;
+        //std::cerr << "[ WARNING ] Created node with mutated 0 input. Schema: " << schema << std::endl;
         context->get_decoder()->mark_node(fw_node);
         return outputs;
     }
@@ -217,8 +178,6 @@ std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
                                              const TensorMap& external_tensor_map) {
-    LEVEL++;
-    // std::cout << "=====Convert model:" << LEVEL << " start=====" << std::endl;
     std::shared_ptr<Model> resulting_model;  // define here to make a conversion in a nested scope
     {
         ParameterVector parameters;
 
         // Go over all pytorch_model inputs and register them in the tensor map:
         auto inputs = pytorch_model->inputs();
-        // std::cout << "[ --- DEBUG --- ] convert_pytorch_model: number of inputs: " << inputs.size() << '\n';
         for (int i = 0; i < inputs.size(); ++i) {
-            // std::cout << "Input: " << i << ": " << inputs[i] << "\n";
             PartialShape ps = pytorch_model->get_input_shape(i);
-            // std::cout << "PartialShape = " << ps << "\n";
             auto parameter =
                 std::make_shared<opset8::Parameter>(ov::element::custom, pytorch_model->get_input_type(i), ps);
             parameter->get_output_tensor(0).add_names({std::to_string(pytorch_model->input(i))});
-            // std::cout << "Parameter: " << parameter << "\n";
             parameters.push_back(parameter);
             auto order = pytorch_model->get_input_transpose_order(i);
             if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
@@ -253,23 +208,18 @@ std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
             } else {
                 tensor_map[pytorch_model->input(i)] = parameter;
             }
-            // std::cout << "Level:" << LEVEL << " Added model input: " << tensor_map[pytorch_model->input(i)] <<
-            // std::endl;
         }
 
         auto node_visitor = [&](std::shared_ptr<Decoder> node) {
-            // std::cerr << "Node convert start" << std::endl;
-
             // Explore all inputs of node. Node may refer to input value that hasn't been created in the current scope.
             // But this value can be found in the outer scope, for this purpose we need to search node in
             // external_tensor_map as well
+            //std::cout << "Node visitor start: " << node->get_op_type() << ", schema: " << node->get_schema() << std::endl;
             auto raw_inputs = node->inputs();
             for (size_t i = 0; i < raw_inputs.size(); ++i) {
                 auto input = node->input(i);
                 if (tensor_map.find(input) == tensor_map.end()) {
-                    // std::cout << "Level:" << LEVEL << " Trampoline for input index " << i << " with value " << input
-                    // << "\n";
                     // input refers value in the outer scope, need to create a new Parameter in the current scope
                     // TODO: Connect outer scope and inner scope properly -- should be handled at the level of that
                     // operation that introduced this nest of scopes (e.g. loop or if)
                     auto parameter = std::make_shared<opset8::Parameter>(node->get_input_type(i), ps);
                     // TODO: Missing get_input_transpose_order handling for not trivial layouts
                     tensor_map[input] = parameter;
-                    // std::cout << "Parameter created\n";
                     // set name of parameter to the index of node in the model
                     parameter->get_output_tensor(0).add_names({std::to_string(input)});
                     parameters.push_back(parameter);
-                    // std::cout << "External tensor: " << input << " node: " << external_tensor_map.at(input) <<
-                    // std::endl;
                 }
             }
-            // std::cerr << "Node convert before translator: " << node->get_op_type() << ", schema: " <<
-            // node->get_schema() << std::endl;
 
             auto context = NodeContext(node, &tensor_map, &parameters, external_tensor_map);
             auto converted_outputs = convert_node(&context);
-            // std::cerr << "Node convert before outputs" << std::endl;
 
             auto mutated_t = context.get_mutated_tensors();
             mutated_tensors.insert(mutated_t.begin(), mutated_t.end());
 
             auto fw_outputs = node->outputs();
             for (size_t i = 0; i < fw_outputs.size(); ++i) {
                 size_t fw_tensor_id = node->output(i);
                 if (tensor_map.find(fw_tensor_id) != tensor_map.end()) {
-                    // std::cerr << "Duplicated producer for tensor with id = " << fw_tensor_id << " discovered at
-                    // output "
-                    //           << "port " << i << " of node " << node->kind().toQualString() << "\n";
                     throw std::runtime_error("Duplicated producer for PT value with unique ID: " +
                                              std::to_string(fw_tensor_id));
                 }
 
                 // Output shape of converted node should match the original output shape
-                // std::cerr << "[ DEBUG ] PT output shape = " << get_ov_shape(fw_outputs[i]) << '\n';
-                // std::cerr << "[ DEBUG ] OV output shape = " << converted_outputs[i].get_partial_shape() << '\n';
                 // OV_FRONTEND_REQUIRE(get_ov_shape(fw_outputs[i]) == converted_outputs[i].get_partial_shape());
 
                 tensor_map[fw_tensor_id] = converted_outputs[i];
                 converted_outputs[i].get_tensor().add_names({std::to_string(fw_tensor_id)});
-                // std::cout << "Level:" << LEVEL << " Added node: " << converted_outputs[i] << std::endl;
-                // std::cout << "Converted node output " << fw_tensor_id << ": " << converted_outputs[i] << std::endl;
             }
-            // std::cout << "Node convert end" << std::endl;
         };
 
         OV_FRONTEND_REQUIRE(pytorch_model->get_subgraph_size() == 1);
         pytorch_model->visit_subgraph(0, node_visitor);
-        // std::cout << "All nodes convert end" << std::endl;
 
         ResultVector results;
-        // std::cerr << "Outputs:" << pytorch_model->num_of_outputs() << "\n";
         for (size_t i = 0; i < pytorch_model->num_of_outputs(); ++i) {
             size_t id = pytorch_model->output(i);
-            // std::cerr << "Output:" << i << ": " << id << "\n";
-            // std::cout << "value = " << id << '\n';
-            // std::cout << "X\n";
             if (tensor_map.find(id) == tensor_map.end()) {
                 // Not found in this scope, searching in the outer scope
                 // TODO: do real search here, skipped for now
                 auto parameter = std::make_shared<opset8::Parameter>(element::dynamic, PartialShape::dynamic());
                 parameter->get_output_tensor(0).add_names({std::to_string(id)});
                 parameters.push_back(parameter);
                 tensor_map[id] = parameter;
-                // std::cout << "Level:" << LEVEL << "Added new parameter based on external value " << id << "\n";
             }
             auto ov_output = tensor_map[id];
-            // std::cout << "X\n";
             auto order = pytorch_model->get_output_transpose_order(i);
-            // std::cout << "X\n";
             if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
                 throw "Output strides have wrong order.";
             }
             // TODO: remove when all nodes have ids
             ov_output.add_names({std::to_string(id)});
-            // std::cout << "X\n";
-            // std::cout << ov_output << '\n';
             auto result = std::make_shared<opset8::Result>(ov_output);
-            // std::cout << "X\n";
             results.push_back(result);
-            // std::cerr << "Model result " << result << "\n";
         }
 
         // Since parameters can be added we need to list all current parameters
@@ -382,39 +305,25 @@ std::shared_ptr<Model> convert_pytorch_model(std::shared_ptr<Decoder> pytorch_model,
                 results.push_back(std::make_shared<opset8::Result>(tensor_map.at(tensor)));
             }
         }
-        // std::cout << "Y\n";
-
-        /*for (size_t i = 0; i < parameters.size(); ++i) {
-            auto parameter = parameters[i];
-            // std::cerr << "parameter[" << i << "].shape = "
-            //           << parameter->get_output_shape(0) << ", consumers: " <<
-            //           parameter->output(0).get_target_inputs().size() << "\n";
-        }*/
-        // std::cout << "Convert end" << std::endl;
-        // std::cout << "Number of values collected: " << tensor_map.size() << "\n";
-
-        // std::cout << "=====Construct model start=====" << std::endl;
-        /*std::cout << "=====Tensor map start=====" << std::endl;
-        for (auto node : tensor_map) {
-            std::cout << node.first << ": " << node.second.get_node_shared_ptr() << std::endl;
-        }*/
 
         resulting_model = std::make_shared<Model>(results, parameters);
-        /*std::string m_name = "model_" + std::to_string(LEVEL) + "_" + std::to_string(NUMBER++);
-        try {
-            ov::serialize(resulting_model, m_name + ".xml", m_name + ".bin");
-        } catch (...) {
{ - std::cout << "Exception happened during model serialization: " + m_name << std::endl; - }*/ - // std::cout << "=====Construct model end=====" << std::endl; - // Did a conversion in a nested scope to automatically remove any holders of nodes except those in the graph } - // std::cout << "=====Convert model:" << LEVEL << " end=====" << std::endl; - LEVEL--; return resulting_model; } +std::shared_ptr cast_fw_node(std::shared_ptr node, const std::string& type) { + auto fw_node = std::dynamic_pointer_cast(node); + if (!fw_node) { + return nullptr; + } + const auto& attrs = fw_node->get_attrs(); + if (attrs.find("PtTypeName") == attrs.end() || attrs.at("PtTypeName") != type) { + return nullptr; + } + return fw_node; +} + } // namespace pytorch } // namespace frontend } // namespace ov diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp index 125ca93e3bc05c..b74030b150fe74 100644 --- a/src/frontends/pytorch/src/utils.hpp +++ b/src/frontends/pytorch/src/utils.hpp @@ -5,6 +5,13 @@ #include "openvino/frontend/pytorch/node_context.hpp" namespace ov { + +namespace op { +namespace util { +class FrameworkNode; +} +} // namespace op + namespace frontend { namespace pytorch { @@ -33,6 +40,7 @@ OutputVector inplace_op(NodeContext& context) { context.mutate_input(idx, translation_res[0]); return translation_res; } +std::shared_ptr cast_fw_node(std::shared_ptr node, const std::string& type); } // namespace pytorch } // namespace frontend diff --git a/tests/layer_tests/pytorch_tests/conftest.py b/tests/layer_tests/pytorch_tests/conftest.py new file mode 100644 index 00000000000000..a8595e4a802bb3 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/conftest.py @@ -0,0 +1,12 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import inspect + +from pytorch_layer_test_class import get_params + + +def pytest_generate_tests(metafunc): + test_gen_attrs_names = list(inspect.signature(get_params).parameters) + params = get_params() + metafunc.parametrize(test_gen_attrs_names, params, scope="function") diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py new file mode 100644 index 00000000000000..34446350bffa28 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py @@ -0,0 +1,111 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import warnings + +import numpy as np +from common.constants import test_device, test_precision + +from openvino.frontend import FrontEndManager +from openvino.frontend.pytorch import TorchScriptPythonDecoder +from openvino.runtime import Core, Type, PartialShape + + +class PytorchLayerTest: + _type_map = { + "float64": Type.f64, + "float32": Type.f32, + "int32": Type.i32 + } + + def _test(self, model, ref_net, ie_device, precision, ir_version, infer_timeout=60, **kwargs): + """ + :param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled. 
+ Example: "transform_1,transform_2" + """ + import torch + with torch.no_grad(): + model.eval() + model = torch.jit.freeze(torch.jit.script(model)) + + fe_manager = FrontEndManager() + fe = fe_manager.load_by_framework('pytorch') + + decoder = TorchScriptPythonDecoder(model.inlined_graph) + + im = fe.load(decoder) + om = fe.convert(im) + + if 'kwargs_to_prepare_input' in kwargs and kwargs['kwargs_to_prepare_input']: + inputs = self._prepare_input(kwargs['kwargs_to_prepare_input']) + else: + inputs = self._prepare_input() + + params = om.get_parameters() + # todo: support lists and dicts + for i in range(len(inputs)): + inp = inputs[i] + assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}." + params[i].set_element_type(self._type_map[inp.dtype.name]) + dyn_shape = [-1] * len(inp.shape) + params[i].set_partial_shape(PartialShape(dyn_shape)) + om.validate_nodes_and_infer_types() + + # OV infer: + core = Core() + compiled = core.compile_model(om, ie_device) + infer_res = compiled(inputs) + + if hasattr(self, 'skip_framework') and self.skip_framework: + warnings.warn('Framework is skipped') + return + + # Framework infer: + torch_inps = [torch.from_numpy(inp) for inp in inputs] + fw_res = model(*torch_inps) + if not isinstance(fw_res, tuple): + fw_res = (fw_res,) + + if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None: + custom_eps = kwargs['custom_eps'] + else: + custom_eps = 1e-4 + + # Compare Ie results with Framework results + fw_eps = custom_eps if precision == 'FP32' else 5e-2 + for i in range(len(infer_res)): + cur_fw_res = fw_res[i].numpy() + cur_ov_res = infer_res[compiled.output(i)] + print(f"fw_re: {cur_fw_res}; ov_res: {cur_ov_res}") + if not np.allclose(cur_ov_res, cur_fw_res, + atol=fw_eps, + rtol=fw_eps): + is_ok = False + print("Max diff is {}".format( + np.array( + abs(cur_ov_res - cur_fw_res)).max())) + else: + print("Accuracy validation successful!\n") + print("absolute eps: {}, relative eps: {}".format(fw_eps, fw_eps)) + + # Each model should specify inputs + def _prepare_input(self): + raise RuntimeError("Please provide inputs generation function") + + +def get_params(ie_device=None, precision=None): + """ + :param ie_device: list of devices + :param precision: list of precisions + """ + + ie_device_params = ie_device if ie_device else test_device + precision_params = precision if precision else test_precision + + test_args = [] + for element in itertools.product(ie_device_params, precision_params): + if element[0] == 'CPU' and element[1] == 'FP16': + continue + test_args.append(element) + return test_args diff --git a/tests/layer_tests/pytorch_tests/test_relu.py b/tests/layer_tests/pytorch_tests/test_relu.py new file mode 100644 index 00000000000000..ef2e2cca56cc35 --- /dev/null +++ b/tests/layer_tests/pytorch_tests/test_relu.py @@ -0,0 +1,31 @@ +# Copyright (C) 2018-2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from pytorch_layer_test_class import PytorchLayerTest + + +class TestRelu(PytorchLayerTest): + def _prepare_input(self): + import numpy as np + return (np.random.randn(1, 3, 224, 224).astype(np.float32),) + + def create_model(self): + + import torch + import torch.nn.functional as F + + class aten_relu(torch.nn.Module): + def __init__(self): + super(aten_relu, self).__init__() + + def forward(self, x): + return F.relu(x) + + ref_net = None + + return aten_relu(), ref_net + + @pytest.mark.nightly + def test_relu(self, ie_device, precision, ir_version): + self._test(*self.create_model(), ie_device, 
From 7b4c43fc784c7dc1e73123e7943d3a607cb82485 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Tue, 4 Oct 2022 15:34:14 +0200
Subject: [PATCH 2/2] Fix build

---
 src/frontends/pytorch/src/op_table.cpp   | 22 ++++++++++------------
 src/frontends/pytorch/src/transforms.cpp | 17 ++---------------
 2 files changed, 12 insertions(+), 27 deletions(-)

diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 66460b0e6a9193..08493b5f66efb6 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -208,28 +208,26 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
          ))};
      }},
 
-    /*{"aten::layer_norm",
+    {"aten::layer_norm",
      [](NodeContext& context) -> OutputVector {
          auto normalized_shape = context.const_input<std::vector<int64_t>>(1);
-         auto in_pshape_last_dim = *context.get_input(0).get_partial_shape().rbegin();
-         OV_FRONTEND_REQUIRE(normalized_shape.size() == 1 && in_pshape_last_dim.is_static() &&
-                             static_cast<uint64_t>(in_pshape_last_dim.get_length()) == normalized_shape.back());
+         // TODO: do we need this check?
+         //auto in_pshape_last_dim = *context.get_input(0).get_partial_shape().rbegin();
+         //OV_FRONTEND_REQUIRE(normalized_shape.size() == 1 && in_pshape_last_dim.is_static() &&
+         //                    static_cast<uint64_t>(in_pshape_last_dim.get_length()) == normalized_shape.back());
          auto eps = context.const_input<float>(4);
          auto axes = context.mark_node(
              opset8::Constant::create(element::i64, Shape{1}, {-1}));  // TODO: support any dimension
-         auto mvn = context.mark_node(
+         auto out_node = context.mark_node(
              std::make_shared<opset8::MVN>(context.get_input(0), axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
-         std::shared_ptr<ov::Node> out_node = std::dynamic_pointer_cast<ov::Node>(mvn);
          if (!context.input_is_none(2)) {
-             auto mul = std::make_shared<opset8::Multiply>(out_node, context.get_input(2));
-             out_node = std::dynamic_pointer_cast<ov::Node>(mul);
+             out_node = context.mark_node(std::make_shared<opset8::Multiply>(out_node, context.get_input(2)));
          }
          if (!context.input_is_none(3)) {
-             auto add = std::make_shared<opset8::Add>(out_node, context.get_input(3));
-             out_node = std::dynamic_pointer_cast<ov::Node>(add);
+             out_node = context.mark_node(std::make_shared<opset8::Add>(out_node, context.get_input(3)));
          }
-         return {context.mark_node(out_node)};
-     }},*/
+         return {out_node};
+     }},
 
     {"aten::add", op::add},
     {"aten::add_", inplace_op<op::add>},
diff --git a/src/frontends/pytorch/src/transforms.cpp b/src/frontends/pytorch/src/transforms.cpp
index 082ae189dcfbeb..77e05266476196 100644
--- a/src/frontends/pytorch/src/transforms.cpp
+++ b/src/frontends/pytorch/src/transforms.cpp
@@ -7,6 +7,8 @@
 #include
 
 #include "transforms.hpp"
+#include "utils.hpp"
+
 
 namespace ov {
 namespace frontend {
@@ -54,21 +56,6 @@ std::shared_ptr<Node> make_list_pack(const OutputVector& inputs, Any output_type) {
 }
 
-std::shared_ptr<ov::op::util::FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type) {
-    auto fw_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node);
-    if(!fw_node) {
-        std::cerr << "[ ERROR ] Incorrect matcher triggering\n";
-        return nullptr;
-    }
-    const auto& attrs = fw_node->get_attrs();
-    if(attrs.find("PtTypeName") == attrs.end() || attrs.at("PtTypeName") != type) {
-        return nullptr;
-    }
-    return fw_node;
-}
-
-
-
 std::shared_ptr<Node> cast_internal_node(std::shared_ptr<Node> node, const std::string& type) {
     auto fw_node = std::dynamic_pointer_cast<ov::op::util::FrameworkNode>(node);
     if(!fw_node) {