Skip to content

Commit

Permalink
Fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
mvafin committed Oct 4, 2022
1 parent 11ffccd commit 7b4c43f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 27 deletions.
22 changes: 10 additions & 12 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,28 +208,26 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
))};
}},

/*{"aten::layer_norm",
{"aten::layer_norm",
[](NodeContext& context) -> OutputVector {
auto normalized_shape = context.const_input<Shape>(1);
auto in_pshape_last_dim = *context.get_input(0).get_partial_shape().rbegin();
OV_FRONTEND_REQUIRE(normalized_shape.size() == 1 && in_pshape_last_dim.is_static() &&
static_cast<uint64_t>(in_pshape_last_dim.get_length()) == normalized_shape.back());
// TODO: do we need this check?
//auto in_pshape_last_dim = *context.get_input(0).get_partial_shape().rbegin();
//OV_FRONTEND_REQUIRE(normalized_shape.size() == 1 && in_pshape_last_dim.is_static() &&
// static_cast<uint64_t>(in_pshape_last_dim.get_length()) == normalized_shape.back());
auto eps = context.const_input<float>(4);
auto axes = context.mark_node(
opset8::Constant::create(element::i64, Shape{1}, {-1})); // TODO: support any dimention
auto mvn = context.mark_node(
auto out_node = context.mark_node(
std::make_shared<opset8::MVN>(context.get_input(0), axes, true, eps, ov::op::MVNEpsMode::INSIDE_SQRT));
std::shared_ptr<ov::Node> out_node = std::dynamic_pointer_cast<ov::Node>(mvn);
if (!context.input_is_none(2)) {
auto mul = std::make_shared<opset8::Multiply>(out_node, context.get_input(2));
out_node = std::dynamic_pointer_cast<ov::Node>(mul);
out_node = context.mark_node(std::make_shared<opset8::Multiply>(out_node, context.get_input(2)));
}
if (!context.input_is_none(3)) {
auto add = std::make_shared<opset8::Add>(out_node, context.get_input(3));
out_node = std::dynamic_pointer_cast<ov::Node>(add);
out_node = context.mark_node(std::make_shared<opset8::Add>(out_node, context.get_input(3)));
}
return {context.mark_node(out_node)};
}},*/
return {out_node};
}},

{"aten::add", op::add},
{"aten::add_", inplace_op<op::add>},
Expand Down
17 changes: 2 additions & 15 deletions src/frontends/pytorch/src/transforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <openvino/frontend/pytorch/decoder.hpp>

#include "transforms.hpp"
#include "utils.hpp"


namespace ov {
namespace frontend {
Expand Down Expand Up @@ -54,21 +56,6 @@ std::shared_ptr<FrameworkNode> make_list_pack (const OutputVector& inputs, Any o
}


std::shared_ptr<FrameworkNode> cast_fw_node(std::shared_ptr<Node> node, const std::string& type) {
auto fw_node = std::dynamic_pointer_cast<FrameworkNode>(node);
if(!fw_node) {
std::cerr << "[ ERROR ] Incorrect matcher triggering\n";
return nullptr;
}
const auto& attrs = fw_node->get_attrs();
if(attrs.find("PtTypeName") == attrs.end() || attrs.at("PtTypeName") != type) {
return nullptr;
}
return fw_node;
}



std::shared_ptr<FrameworkNode> cast_internal_node(std::shared_ptr<Node> node, const std::string& type) {
auto fw_node = std::dynamic_pointer_cast<FrameworkNode>(node);
if(!fw_node) {
Expand Down

0 comments on commit 7b4c43f

Please sign in to comment.