Skip to content

Commit

Permalink
SerializationNode: added SerializationMode
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Dec 1, 2023
1 parent 6be3e88 commit 63ef228
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ namespace op {
*/
class SerializationNode : public ov::op::Op {
public:
enum SerializationMode { DATA_FLOW, CONTROL_FLOW };
SerializationNode() = default;
SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr);
SerializationNode(const ov::OutputVector& args,
const std::shared_ptr<lowered::Expression>& expr,
SerializationMode mode = SerializationMode::CONTROL_FLOW);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
Expand All @@ -37,6 +40,7 @@ class SerializationNode : public ov::op::Op {

private:
std::shared_ptr<lowered::Expression> m_expr;
SerializationMode m_mode;
};

} // namespace op
Expand Down
5 changes: 3 additions & 2 deletions src/common/snippets/src/lowered/pass/serialize_data_flow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ bool SerializeDataFlow::run(LinearIR& linear_ir) {
ov::ResultVector results;
ov::ParameterVector parameters;
std::map<ExpressionPtr, std::shared_ptr<Node>> ops_map;
const auto serialization_mode = op::SerializationNode::SerializationMode::DATA_FLOW;
for (const auto& expr : linear_ir) {
const auto node = expr->get_node();
ov::OutputVector inputs(expr->get_input_count());
for (size_t i = 0; i < expr->get_input_count(); ++i) {
const auto& input_expr = expr->get_input_port_connector(i)->get_source().get_expr();
OPENVINO_ASSERT(ops_map.count(input_expr), "input node wasn't found during serialization");
inputs[i] = ops_map[input_expr]->output(0);
inputs[i] = ops_map[input_expr]->output(expr->get_input_port_connector(i)->get_source().get_index());
}
if (auto ioexpr = std::dynamic_pointer_cast<IOExpression>(expr)) {
if (ioexpr->get_type() == IOExpression::io_type::INPUT) {
Expand All @@ -42,7 +43,7 @@ bool SerializeDataFlow::run(LinearIR& linear_ir) {
results.push_back(result);
}
} else {
const auto serialization_node = std::make_shared<op::SerializationNode>(inputs, expr);
const auto serialization_node = std::make_shared<op::SerializationNode>(inputs, expr, serialization_mode);
ops_map[expr] = serialization_node;
}
}
Expand Down
19 changes: 15 additions & 4 deletions src/common/snippets/src/op/serialization_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@ namespace ov {
namespace snippets {
namespace op {

SerializationNode::SerializationNode(const ov::OutputVector& args, const std::shared_ptr<lowered::Expression>& expr)
: Op(args), m_expr(expr) {
SerializationNode::SerializationNode(const ov::OutputVector& args,
const std::shared_ptr<lowered::Expression>& expr,
SerializationMode mode)
: Op(args),
m_expr(expr),
m_mode(mode) {
OPENVINO_ASSERT(m_expr && m_expr->get_node(), "SerializationNode requires a valid expression with non-null node pointer");
const auto& node = expr->get_node();
set_friendly_name(node->get_friendly_name());
Expand All @@ -20,12 +24,19 @@ SerializationNode::SerializationNode(const ov::OutputVector& args, const std::sh
}

void SerializationNode::validate_and_infer_types() {
set_output_type(0, element::f32, {});
// If SerializationNode is used for control flow serialization, it always has one output
// (since it represents a linear execution order)
if (m_mode == SerializationMode::CONTROL_FLOW) {
set_output_type(0, element::f32, {});
} else if (m_mode == SerializationMode::DATA_FLOW) {
for (size_t i = 0; i < m_expr->get_output_count(); ++i)
set_output_type(i, element::f32, {});
}
}

std::shared_ptr<Node> SerializationNode::clone_with_new_inputs(const OutputVector &new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<SerializationNode>(new_args, m_expr);
return std::make_shared<SerializationNode>(new_args, m_expr, m_mode);
}

bool SerializationNode::visit_attributes(AttributeVisitor &visitor) {
Expand Down

0 comments on commit 63ef228

Please sign in to comment.