
Commit d6835cb
Softmax support
1 parent afc5f4c

File tree: 8 files changed, +71 -86 lines


src/common/snippets/include/snippets/lowered/linear_ir.hpp (+1)

@@ -110,6 +110,7 @@ class LinearIR {
 
     void init_emitters(const std::shared_ptr<TargetMachine>& target);
     void serialize(const std::string& xml, const std::string& bin) const;
+    void serialize2(const std::string& xml, const std::string& bin) const;
 
     class LoopManager;
     using LoopManagerPtr = std::shared_ptr<LoopManager>;

src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp (+5 -4)

@@ -59,15 +59,16 @@ class ZeroFinalizationOffsets : public pass::SubgraphPass {
     bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
 };
 
-class InsertFill : public pass::SubgraphPass {
+class SetFillOffset : public pass::SubgraphPass {
 public:
-    InsertFill(size_t tail_size);
-    OPENVINO_RTTI("InsertFill", "Pass")
+    SetFillOffset(size_t offset);
+    OPENVINO_RTTI("SetFillOffset", "Pass")
     bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
 
 private:
-    size_t m_tail_size;
+    size_t m_offset;
 };
+
 } // namespace pass
 } // namespace lowered
 } // namespace snippets

src/common/snippets/src/generator.cpp (+2 -10)

@@ -26,15 +26,6 @@ void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, c
     std::function<opRegType(const std::shared_ptr<Node>& op)> reg_type_mapper = [&](const std::shared_ptr<Node>& op) -> opRegType {
         return get_op_reg_type(op);
     };
-    lowered::pass::PassPipeline pre_pipeline;
-    pre_pipeline.register_pass<lowered::pass::AssignRegisters>(reg_type_mapper);
-    pre_pipeline.run(linear_ir);
-
-    // auto clone = *linear_ir.clone();
-    // lowered::pass::PassPipeline reference_pipeline;
-    // reference_pipeline.register_pass<lowered::pass::InsertTailLoop>();
-    // reference_pipeline.run(clone);
-    // clone.serialize("/home/vgolubev/models/specific_iteration_reference.xml", "");
 
     lowered::pass::PassPipeline lowered_pipeline;
     // Note: the order of all passes in this pipeline must not be changed since they have hard dependencies
@@ -46,6 +37,7 @@ void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, c
     // since CleanupLoopOffsets can't handle loops with evaluate_once = true
     lowered_pipeline.register_pass<lowered::pass::AssignRegisters>(reg_type_mapper);
     lowered_pipeline.run(linear_ir);
+    linear_ir.serialize2("/home/vgolubev/models/test.xml", "/dev/null");
 
     // lowered::pass::PassPipeline reference_pipeline;
     // reference_pipeline.register_pass<lowered::pass::InsertTailLoop>();
@@ -60,7 +52,7 @@ void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, c
     target_pipeline.register_pass<lowered::pass::CleanupLoopOffsets>();
     target_pipeline.register_pass<lowered::pass::OptimizeLoopSingleEvaluation>();
     target_pipeline.run(linear_ir);
-    linear_ir.serialize("/home/vgolubev/models/specific_iteration.xml", "");
+    linear_ir.serialize("/home/vgolubev/models/specific_iteration.xml", "/dev/null");
     linear_ir.init_emitters(target);
 
     OV_ITT_TASK_NEXT(GENERATE, "::EmitCode")

src/common/snippets/src/lowered/linear_ir.cpp (+42)

@@ -118,6 +118,48 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
     ov::pass::Serialize(xml, bin).run_on_model(tmp_model);
 }
 
+void LinearIR::serialize2(const std::string& xml, const std::string& bin) const {
+    ov::ParameterVector parameters;
+    std::map<ExpressionPtr, std::shared_ptr<Node>> ops_map;
+    for (const auto& ioexpr : m_io_expressions) {
+        if (ioexpr->get_type() == IOExpression::io_type::INPUT) {
+            const auto parameter = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
+            ops_map[ioexpr] = parameter;
+            parameters.push_back(parameter);
+        }
+    }
+
+    for (const auto& expr : m_expressions) {
+        if (std::dynamic_pointer_cast<IOExpression>(expr))
+            continue;
+
+        const auto node = expr->get_node();
+        ov::OutputVector inputs(expr->get_input_count());
+        for (size_t i = 0; i < expr->get_input_count(); ++i) {
+            const auto& input_expr = expr->get_input_port_connector(i);
+            inputs[i] = ops_map[input_expr->get_source().get_expr()]->output(0);
+        }
+        const auto serialization_node = std::make_shared<op::SerializationNode>(inputs, expr);
+        ops_map[expr] = serialization_node;
+    }
+
+    ov::ResultVector results;
+    for (const auto& ioexpr : m_io_expressions) {
+        if (ioexpr->get_type() == IOExpression::io_type::OUTPUT) {
+            ov::OutputVector inputs(ioexpr->get_input_count());
+            for (size_t i = 0; i < ioexpr->get_input_count(); ++i) {
+                const auto& input_expr = ioexpr->get_input_port_connector(i);
+                inputs[i] = ops_map[input_expr->get_source().get_expr()]->output(0);
+            }
+            const auto result = std::make_shared<ov::op::v0::Result>(inputs[0]);
+            ops_map[ioexpr] = result;
+            results.push_back(result);
+        }
+    }
+    const auto tmp_model = std::make_shared<ov::Model>(results, parameters, "Lowered_IR_Serialization");
+    ov::pass::Serialize(xml, bin).run_on_model(tmp_model);
+}
+
 LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin,
                                               LinearIR::container::const_iterator end,
                                               ExressionMap& expression_map) {
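
The new serialize2() follows the same recipe as the existing serialize() above it: wrap each lowered Expression in an op::SerializationNode, stitch the wrappers into a throwaway ov::Model, and hand that model to ov::pass::Serialize. The standalone sketch below shows only that underlying pattern with public OpenVINO API; the Relu model and file names are placeholders, not anything this commit produces.

#include <memory>

#include <openvino/core/model.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/relu.hpp>
#include <openvino/op/result.hpp>
#include <openvino/pass/serialize.hpp>

int main() {
    // Build a minimal stand-in for the temporary model that serialize2() assembles.
    const auto parameter = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 16});
    const auto relu = std::make_shared<ov::op::v0::Relu>(parameter);
    const auto result = std::make_shared<ov::op::v0::Result>(relu);
    const auto tmp_model = std::make_shared<ov::Model>(ov::ResultVector{result},
                                                       ov::ParameterVector{parameter},
                                                       "Debug_Serialization");
    // Same call as in serialize()/serialize2(): dump the model to IR files on disk.
    ov::pass::Serialize("debug_model.xml", "debug_model.bin").run_on_model(tmp_model);
    return 0;
}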

src/common/snippets/src/lowered/pass/assign_registers.cpp (+2 -2)

@@ -80,10 +80,10 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
         for (const auto& tensor : input_expr_input_tensors) {
             const auto parent_expr = tensor->get_source().get_expr();
             if (ov::is_type<op::Fill>(parent_expr->get_node())) {
-                manually_assigned_vecs[tensor] = static_cast<Reg>(accumulator_reg);
                 if (ov::is_type<op::VectorBuffer>(parent_expr->get_input_port_connector(0)->get_source().get_expr()->get_node())) {
+                    manually_assigned_vecs[tensor] = static_cast<Reg>(accumulator_reg);
                     manually_assigned_vecs[parent_expr->get_input_port_connector(0)] = static_cast<Reg>(accumulator_reg);
-                }
+                }
             }
         }
         const auto& output_tensor = expr->get_output_port_connector(0);

src/common/snippets/src/lowered/pass/iter_handler.cpp (+5 -50)

@@ -100,58 +100,13 @@ bool ZeroFinalizationOffsets::run(const LinearIR& linear_ir, LinearIR::constExpr
     return true;
 }
 
-InsertFill::InsertFill(size_t tail_size) : m_tail_size(tail_size) {}
-
-bool InsertFill::run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
-    const auto& config = linear_ir.get_config();
-    if (!config.m_need_fill_tail_register)
-        return false;
-
-    auto insertFill = [&](const ov::Input<ov::Node>& input) -> std::shared_ptr<ov::Node> {
-        std::shared_ptr<ov::Node> fill = nullptr;
-        auto& rt = input.get_rt_info();
-        auto fill_rt = rt.find("set_fill");
-        if (fill_rt != rt.end()) {
-            const auto fill_value = fill_rt->second.as<uint32_t>();
-            fill = std::make_shared<ov::snippets::op::Fill>(input.get_source_output(), m_tail_size, fill_value);
-            input.get_node()->set_argument(input.get_index(), fill);
-        }
-        return fill;
-    };
+SetFillOffset::SetFillOffset(size_t offset) : SubgraphPass(), m_offset(offset) {}
 
+bool SetFillOffset::run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
     for (auto expr_it = std::next(begin); expr_it != end; expr_it++) {
-        const auto expr = expr_it->get();
-        const auto op = expr->get_node();
-        // Skip inner Loops
-        const auto loop_begin = ov::as_type_ptr<op::LoopBegin>(op);
-        if (loop_begin) {
-            expr_it = linear_ir.find(expr_it, end, linear_ir.get_expr_by_node(loop_begin->get_loop_end()));
-            continue;
-        }
-
-        auto casted_linear_ir = const_cast<LinearIR&>(linear_ir);
-        if (ov::is_type<ov::op::v1::Maximum>(op) || ov::is_type<ov::op::v1::Add>(op)) {
-            for (size_t i = 0; i < op->inputs().size(); ++i) {
-                if (auto fill = insertFill(op->input(i))) {
-                    const auto& input = expr->get_input_port_connector(i);
-                    const auto consumers = input->get_consumers();
-                    // If there are several consumers, fill expression must be inserted before first of them
-                    auto fst_consumer = std::min_element(consumers.cbegin(), consumers.cend(), [&](ExpressionPort lhs, ExpressionPort rhs) {
-                        auto lhs_it = casted_linear_ir.find(lhs.get_expr());
-                        auto rhs_it = casted_linear_ir.find(rhs.get_expr());
-                        return std::distance(casted_linear_ir.cbegin(), lhs_it) < std::distance(casted_linear_ir.cbegin(), rhs_it);
-                    });
-                    const auto insert_pos = casted_linear_ir.find(fst_consumer->get_expr());
-                    auto fill_expr = casted_linear_ir.create_expression(fill, {input});
-                    casted_linear_ir.insert(insert_pos, fill_expr);
-                    casted_linear_ir.replace_input(consumers, fill_expr->get_output_port_connector(0));
-                    // in_reg == out_reg since we want to modify vector reg inplace
-                    const auto reg = expr->get_input_port_descriptor(0)->get_reg();
-                    fill_expr->get_input_port_descriptor(0)->set_reg(reg);
-                    fill_expr->get_output_port_descriptor(0)->set_reg(reg);
-                    fill_expr->set_loop_ids(expr->get_loop_ids());
-                }
-            }
+        const auto& node = expr_it->get()->get_node();
+        if (const auto fill = ov::as_type_ptr<ov::snippets::op::Fill>(node)) {
+            fill->set_offset(m_offset);
         }
     }
     return true;
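
With this rewrite the last-iteration handler no longer builds Fill expressions from "set_fill" runtime info; SoftmaxDecomposition (see the diff below) now creates the op::Fill nodes up front, and SetFillOffset only patches their offset once the last loop iteration is specialized. A minimal registration sketch mirroring those call sites follows; loop_info, inner_work_amount and m_vector_size are assumed to be provided by the surrounding loop-markup code, as in softmax_decomposition.cpp.

// Hedged sketch: enable tail handling for a loop whose work amount is not a
// multiple of the vector length. On the last iteration, every op::Fill inside
// the loop body starts filling at lane `tail_size`, overwriting the lanes
// that carry no valid data.
const auto tail_size = inner_work_amount % m_vector_size;
if (tail_size != 0) {
    loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
    loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
}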

src/common/snippets/src/lowered/pass/softmax_decomposition.cpp (+13 -20)

@@ -61,20 +61,22 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
         // Init value of vector buffer for ReduceMax is -FLOAT_MIN.
         const auto fill_max = push_node(std::make_shared<op::Fill>(vector_buffer_max.second, 0, float_min_constant));
         // ReduceMax loop
-        const auto& max = push_node(std::make_shared<ov::op::v1::Maximum>(softmax->get_input_source_output(0), fill_max.second));
+        const auto fill_max_tail = push_node(std::make_shared<op::Fill>(softmax->get_input_source_output(0), m_vector_size, float_min_constant));
+
+        const auto& max = push_node(std::make_shared<ov::op::v1::Maximum>(fill_max_tail.second, fill_max.second));
 
         const auto horizon_max = push_node(std::make_shared<op::HorizonMax>(max.second));
 
         // Markup of ReduceMax Loop
-        const auto reduce_max_loop_id = loop_manager->mark_loop(max.first, horizon_max.first, inner_work_amount, m_vector_size, 0,
-                                                                std::vector<ExpressionPort>{(*max.first)->get_input_port(0),
+        const auto reduce_max_loop_id = loop_manager->mark_loop(fill_max_tail.first, horizon_max.first, inner_work_amount, m_vector_size, 0,
+                                                                std::vector<ExpressionPort>{(*fill_max_tail.first)->get_input_port(0),
                                                                                             (*max.first)->get_input_port(1)},
                                                                 std::vector<ExpressionPort>{(*max.first)->get_output_port(0)});
         const auto& reduce_max_loop_info = loop_manager->get_loop_info(reduce_max_loop_id);
         const auto tail_size = inner_work_amount % m_vector_size;
         if (tail_size != 0) {
             reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
-            reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<InsertFill>(tail_size);
+            reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
             if (inner_work_amount > m_vector_size) {
                 reduce_max_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
                 reduce_max_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
@@ -88,7 +90,8 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
         // Sub + Exp + ReduceSum Loop
         const auto sub = push_node(std::make_shared<ov::op::v1::Subtract>(softmax->get_input_source_output(0), broadcast_horizon_max.second));
         const auto exp = push_node(std::make_shared<ov::op::v0::Exp>(sub.second));
-        const auto sum = push_node(std::make_shared<ov::op::v1::Add>(exp.second, fill_sum.second));
+        const auto fill_sum_tail = push_node(std::make_shared<op::Fill>(exp.second, m_vector_size, zero_constant));
+        const auto sum = push_node(std::make_shared<ov::op::v1::Add>(fill_sum_tail.second, fill_sum.second));
 
         const auto horizon_sum = push_node(std::make_shared<op::HorizonSum>(sum.second));
 
@@ -97,12 +100,12 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
                                                                 std::vector<ExpressionPort>{(*sub.first)->get_input_port(0),
                                                                                             (*sub.first)->get_input_port(1),
                                                                                             (*sum.first)->get_input_port(1)},
-                                                                std::vector<ExpressionPort>{(*exp.first)->get_output_port(0),
+                                                                std::vector<ExpressionPort>{(*fill_sum_tail.first)->get_output_port(0),
                                                                                             (*sum.first)->get_output_port(0)});
         const auto& reduce_sum_loop_info = loop_manager->get_loop_info(reduce_sum_loop_id);
         if (tail_size != 0) {
-            reduce_max_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<InsertFill>(tail_size);
             reduce_sum_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
+            reduce_sum_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
             if (inner_work_amount > m_vector_size) {
                 reduce_sum_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
                 reduce_sum_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
@@ -114,10 +117,10 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
         const auto broadcast_pow = push_node(std::make_shared<op::BroadcastMove>(pow.second, broadcasted_dim));
 
         // Mul (pseudo-Divide loop)
-        const auto mul = push_node(std::make_shared<ov::op::v1::Multiply>(exp.second, broadcast_pow.second));
+        const auto mul = push_node(std::make_shared<ov::op::v1::Multiply>(fill_sum_tail.second, broadcast_pow.second));
 
         // Transfer original ExpressionPorts
-        linear_ir.replace_input((*max.first)->get_input_port(0), input_connector);
+        linear_ir.replace_input((*fill_max_tail.first)->get_input_port(0), input_connector);
         linear_ir.replace_input((*sub.first)->get_input_port(0), input_connector);
         linear_ir.replace_input(output_connector->get_consumers(), (*mul.first)->get_output_port_connector(0));
 
@@ -136,24 +139,14 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
         }
 
         // Update Loop info for outer loops
-        const auto entry_points = std::vector<ExpressionPort>{(*max.first)->get_input_port(0),
+        const auto entry_points = std::vector<ExpressionPort>{(*fill_max_tail.first)->get_input_port(0),
                                                               (*sub.first)->get_input_port(0)};
         const auto exit_points = std::vector<ExpressionPort>{(*mul.first)->get_output_port(0)};
         for (auto loop_id : softmax_loop_ids) {
            loop_manager->expression_replacement(vector_buffer_max.first, expr_it, softmax_expr, loop_id, entry_points, exit_points);
         }
 
         expr_it = linear_ir.erase(expr_it); // Remove Softmax
-
-        /* =========================================== */
-
-        /* ============= Runtime Info ================ */
-
-        // For tail loop we should fill input of Max by float min and
-        // input of Sum by zero to avoid math incorrect calculations
-        // TODO [111383]: It should be covered via general pipeline (for example, via analyze in InsertTailLoop?)
-        max.second->input(0).get_rt_info()["set_fill"] = float_min_constant;
-        sum.second->input(0).get_rt_info()["set_fill"] = zero_constant;
         modified = true;
     }
 }
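
The new fill_max_tail and fill_sum_tail ops are created with offset m_vector_size, so they touch nothing on full-vector iterations; on the tail iteration SetFillOffset lowers the offset to tail_size, and the lanes past the valid data are overwritten with the float-min constant (the identity for max) or zero (the identity for sum), which is what the removed "set_fill" runtime info used to request. A standalone scalar sketch, independent of the snippets API, of why those neutral values matter:

#include <algorithm>
#include <array>
#include <cfloat>
#include <cstdio>

int main() {
    // Pretend the vector width is 8 but only 5 tail elements are valid;
    // lanes 5..7 hold whatever was left in the register.
    constexpr size_t vector_size = 8, tail_size = 5;
    std::array<float, vector_size> lanes = {1.f, 4.f, 2.f, 3.f, 0.5f, 999.f, -7.f, 42.f};

    // Fill with the reduction-neutral value starting at the tail offset,
    // which is the effect of Fill once SetFillOffset(tail_size) has run.
    auto max_lanes = lanes;
    std::fill(max_lanes.begin() + tail_size, max_lanes.end(), -FLT_MAX);  // neutral for max
    auto sum_lanes = lanes;
    std::fill(sum_lanes.begin() + tail_size, sum_lanes.end(), 0.f);       // neutral for sum

    float max = -FLT_MAX, sum = 0.f;
    for (size_t i = 0; i < vector_size; ++i) {
        max = std::max(max, max_lanes[i]);
        sum += sum_lanes[i];
    }
    std::printf("max = %f (expected 4), sum = %f (expected 10.5)\n", max, sum);
    return 0;
}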

src/common/snippets/src/op/subgraph.cpp (+1)

@@ -410,6 +410,7 @@ void Subgraph::data_flow_transformations(const BlockedShapeVector& blocked_input
     manager.register_pass<snippets::pass::PropagatePrecision>(m_generator->get_target_machine());
     manager.register_pass<ov::pass::ConstantFolding>();
     manager.register_pass<snippets::pass::ConvertConstantsToScalars>();
+    manager.register_pass<ov::pass::Serialize>("/home/vgolubev/models/data_flow.xml", "");
 
     manager.register_positioned_passes(backend_passes);
     manager.run_passes(body_ptr());
