Skip to content

Commit

Permalink
Refactor Relu6 and Placeholder
Browse files Browse the repository at this point in the history
Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants committed Nov 12, 2023
1 parent 761fe96 commit d095f36
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
16 changes: 10 additions & 6 deletions src/frontends/tensorflow_common/src/op/placeholder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,39 @@
//

#include "common_op_table.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/op/parameter.hpp"

using namespace std;
using namespace ov::opset8;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {

OutputVector translate_placeholder_op(const NodeContext& node) {
auto dtype = node.get_attribute<ov::element::Type>("dtype");
auto shape = node.get_attribute<ov::PartialShape>("shape", ov::PartialShape::dynamic());
default_op_checks(node, 0, {});

auto dtype = node.get_attribute<element::Type>("dtype");
auto shape = node.get_attribute<PartialShape>("shape", PartialShape::dynamic());
if (shape.rank().is_static() && shape.rank().get_length() == 0 && node.has_attribute("_output_shapes")) {
// we know some cases when Placeholder operation has empty scalar `shape` attribute value
// and non-empty `_output_shapes` attribute value.
// `_output_shapes` attribute value turns to be correct in this case
auto output_shapes = node.get_attribute<std::vector<ov::PartialShape>>("_output_shapes");
auto output_shapes = node.get_attribute<vector<PartialShape>>("_output_shapes");
if (output_shapes.size() == 1 && output_shapes[0].rank().is_static()) {
shape = output_shapes[0];
}
}

auto res = std::make_shared<Parameter>(dtype, shape);
auto res = make_shared<v0::Parameter>(dtype, shape);
set_node_name(node.get_name(), res);
return res->outputs();
}

OutputVector translate_placeholder_with_default_op(const NodeContext& node) {
default_op_checks(node, 0, {});

// For parity with legacy frontend, it creates a constant node with the default value
// As a rule, PlaceholderWithDefault is mainly used for is_training variables in the model
TENSORFLOW_OP_VALIDATION(node,
Expand Down
10 changes: 6 additions & 4 deletions src/frontends/tensorflow_common/src/op/relu_6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@
//

#include "common_op_table.hpp"
#include "openvino/opsets/opset8.hpp"
#include "openvino/op/clamp.hpp"

using namespace std;
using namespace ov::opset8;
using namespace ov::op;

namespace ov {
namespace frontend {
namespace tensorflow {
namespace op {
ov::OutputVector translate_relu_6_op(const NodeContext& node) {
OutputVector translate_relu_6_op(const NodeContext& node) {
default_op_checks(node, 1, {"Relu6", "RELU6"});

auto data = node.get_input(0);
auto res = std::make_shared<Clamp>(data, 0.0, 6.0f);
auto res = make_shared<v0::Clamp>(data, 0.0, 6.0f);
set_node_name(node.get_name(), res);
return res->outputs();
}
Expand Down

0 comments on commit d095f36

Please sign in to comment.