Skip to content

Commit b3bd9be

Browse files
committed
SplitLoops case works
1 parent 3c35699 commit b3bd9be

File tree

11 files changed

+114
-46
lines changed

11 files changed

+114
-46
lines changed

src/common/snippets/include/snippets/lowered/linear_ir.hpp

-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ class LinearIR {
110110

111111
void init_emitters(const std::shared_ptr<TargetMachine>& target);
112112
void serialize(const std::string& xml, const std::string& bin) const;
113-
void serialize2(const std::string& xml, const std::string& bin) const;
114113

115114
class LoopManager;
116115
using LoopManagerPtr = std::shared_ptr<LoopManager>;

src/common/snippets/include/snippets/lowered/loop_manager.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ class LinearIR::LoopManager {
196196
static void fuse_loop_ports(std::vector<LinearIR::LoopManager::LoopPort>& exit_points,
197197
std::vector<LinearIR::LoopManager::LoopPort>& entry_points,
198198
size_t loop_id);
199+
static std::vector<lowered::pass::SubgraphPassPipeline> fuse_loop_handlers(
200+
std::vector<lowered::pass::SubgraphPassPipeline>& lhs,
201+
std::vector<lowered::pass::SubgraphPassPipeline>& rhs);
199202

200203
/* ===== The methods for work with Loop IDs of Expression ===== */
201204
// Notes:

src/common/snippets/include/snippets/lowered/pass/iter_handler.hpp

+10
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@ class SetFillOffset : public pass::SubgraphPass {
6969
size_t m_offset;
7070
};
7171

72+
class TransformInnerSplitLoop : public pass::SubgraphPass {
73+
public:
74+
TransformInnerSplitLoop(size_t tail_size);
75+
OPENVINO_RTTI("TransformInnerSplitLoop", "Pass")
76+
bool run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) override;
77+
78+
private:
79+
size_t m_tail_size;
80+
};
81+
7282
} // namespace pass
7383
} // namespace lowered
7484
} // namespace snippets

src/common/snippets/include/snippets/lowered/pass/pass_pipeline.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class SubgraphPassPipeline {
5454
SubgraphPassPipeline() = default;
5555

5656
void run(const lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;
57+
const std::vector<std::shared_ptr<SubgraphPass>>& get_passes() const;
5758
void register_pass(const std::shared_ptr<SubgraphPass>& pass);
5859
bool empty() const { return m_passes.empty(); }
5960

src/common/snippets/src/generator.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ void Generator::generate(lowered::LinearIR& linear_ir, LoweringResult& result, c
3737
// since CleanupLoopOffsets can't handle loops with evaluate_once = true
3838
lowered_pipeline.register_pass<lowered::pass::AssignRegisters>(reg_type_mapper);
3939
lowered_pipeline.run(linear_ir);
40-
linear_ir.serialize2("/home/vgolubev/models/test.xml", "/dev/null");
4140

4241
// lowered::pass::PassPipeline reference_pipeline;
4342
// reference_pipeline.register_pass<lowered::pass::InsertTailLoop>();

src/common/snippets/src/lowered/linear_ir.cpp

-42
Original file line numberDiff line numberDiff line change
@@ -118,48 +118,6 @@ void LinearIR::serialize(const std::string& xml, const std::string& bin) const {
118118
ov::pass::Serialize(xml, bin).run_on_model(tmp_model);
119119
}
120120

121-
void LinearIR::serialize2(const std::string& xml, const std::string& bin) const {
122-
ov::ParameterVector parameters;
123-
std::map<ExpressionPtr, std::shared_ptr<Node>> ops_map;
124-
for (const auto& ioexpr : m_io_expressions) {
125-
if (ioexpr->get_type() == IOExpression::io_type::INPUT) {
126-
const auto parameter = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{});
127-
ops_map[ioexpr] = parameter;
128-
parameters.push_back(parameter);
129-
}
130-
}
131-
132-
for (const auto& expr : m_expressions) {
133-
if (std::dynamic_pointer_cast<IOExpression>(expr))
134-
continue;
135-
136-
const auto node = expr->get_node();
137-
ov::OutputVector inputs(expr->get_input_count());
138-
for (size_t i = 0; i < expr->get_input_count(); ++i) {
139-
const auto& input_expr = expr->get_input_port_connector(i);
140-
inputs[i] = ops_map[input_expr->get_source().get_expr()]->output(0);
141-
}
142-
const auto serialization_node = std::make_shared<op::SerializationNode>(inputs, expr);
143-
ops_map[expr] = serialization_node;
144-
}
145-
146-
ov::ResultVector results;
147-
for (const auto& ioexpr : m_io_expressions) {
148-
if (ioexpr->get_type() == IOExpression::io_type::OUTPUT) {
149-
ov::OutputVector inputs(ioexpr->get_input_count());
150-
for (size_t i = 0; i < ioexpr->get_input_count(); ++i) {
151-
const auto& input_expr = ioexpr->get_input_port_connector(i);
152-
inputs[i] = ops_map[input_expr->get_source().get_expr()]->output(0);
153-
}
154-
const auto result = std::make_shared<ov::op::v0::Result>(inputs[0]);
155-
ops_map[ioexpr] = result;
156-
results.push_back(result);
157-
}
158-
}
159-
const auto tmp_model = std::make_shared<ov::Model>(results, parameters, "Lowered_IR_Serialization");
160-
ov::pass::Serialize(xml, bin).run_on_model(tmp_model);
161-
}
162-
163121
LinearIR::container LinearIR::deep_copy_range(LinearIR::container::const_iterator begin,
164122
LinearIR::container::const_iterator end,
165123
ExressionMap& expression_map) {

src/common/snippets/src/lowered/loop_manager.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,8 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target,
421421
loop_info->set_entry_points(new_entries);
422422
loop_info->set_exit_points(new_exits);
423423

424+
loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers);
425+
424426
const auto& from = fuse_into_upper ? loop_id_lower : loop_id_upper;
425427
const auto& to = fuse_into_upper ? loop_id_upper : loop_id_lower;
426428
for (auto it = loop_begin_target; it != loop_end_target; ++it) {
@@ -431,6 +433,37 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target,
431433
remove_loop_info(from);
432434
}
433435

436+
std::vector<lowered::pass::SubgraphPassPipeline> LinearIR::LoopManager::fuse_loop_handlers(
437+
std::vector<lowered::pass::SubgraphPassPipeline>& lhs,
438+
std::vector<lowered::pass::SubgraphPassPipeline>& rhs) {
439+
auto merge_pass_pipeline = [](const lowered::pass::SubgraphPassPipeline& lhs_pipeline,
440+
const lowered::pass::SubgraphPassPipeline& rhs_pipeline) {
441+
lowered::pass::SubgraphPassPipeline merged_pipeline = lhs_pipeline;
442+
const auto& res_passes = merged_pipeline.get_passes();
443+
for (const auto& pass : rhs_pipeline.get_passes()) {
444+
auto pred = [&pass](const std::shared_ptr<lowered::pass::SubgraphPass>& p) {
445+
return p->get_type_info() == pass->get_type_info();
446+
};
447+
if (std::find_if(res_passes.begin(), res_passes.end(), pred) == res_passes.end()) {
448+
merged_pipeline.register_pass(pass);
449+
}
450+
}
451+
return merged_pipeline;
452+
};
453+
454+
const auto min_size = std::min(lhs.size(), rhs.size());
455+
std::vector<lowered::pass::SubgraphPassPipeline> merged_handlers;
456+
merged_handlers.resize(min_size);
457+
for (size_t i = 0; i < min_size; ++i) {
458+
merged_handlers[i] = merge_pass_pipeline(lhs[i], rhs[i]);
459+
}
460+
auto& handlers_with_larger_size = lhs.size() > rhs.size() ? lhs : rhs;
461+
for (size_t i = min_size; i < handlers_with_larger_size.size(); ++i) {
462+
merged_handlers.emplace_back(std::move(handlers_with_larger_size[i]));
463+
}
464+
return merged_handlers;
465+
}
466+
434467
void LinearIR::LoopManager::fuse_loop_ports(std::vector<LinearIR::LoopManager::LoopPort>& exit_points,
435468
std::vector<LinearIR::LoopManager::LoopPort>& entry_points,
436469
size_t loop_id) {

src/common/snippets/src/lowered/pass/iter_handler.cpp

+44
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,50 @@ bool SetFillOffset::run(const LinearIR& linear_ir, LinearIR::constExprIt begin,
112112
return true;
113113
}
114114

115+
TransformInnerSplitLoop::TransformInnerSplitLoop(size_t tail_size) : SubgraphPass(), m_tail_size(tail_size) {}
116+
117+
bool TransformInnerSplitLoop::run(const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
118+
const auto& expr = *end;
119+
const auto node = expr->get_node();
120+
const auto loop_end = ov::as_type_ptr<op::LoopEnd>(node);
121+
const auto& loop_manager = linear_ir.get_loop_manager();
122+
const auto& loop_info = loop_manager->get_loop_info(loop_end->get_id());
123+
const auto current_dim_idx = loop_info->get_dim_idx();
124+
OPENVINO_ASSERT(current_dim_idx != LinearIR::LoopManager::LoopInfo::UNDEFINED_DIM_IDX,
125+
"Outer splitted loop unexpectedly iterates by several dimension indices");
126+
127+
bool modified = false;
128+
for (auto it = std::next(begin); it != end; ++it) {
129+
const auto& expr = *it;
130+
const auto inner_loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node());
131+
if (!inner_loop_end)
132+
continue;
133+
const auto inner_loop_info = loop_manager->get_loop_info(inner_loop_end->get_id());
134+
const auto inner_dim_idx = inner_loop_info->get_dim_idx();
135+
if (inner_dim_idx != current_dim_idx)
136+
continue;
137+
const auto inner_loop_begin = inner_loop_end->get_loop_begin();
138+
const auto inner_tail_work_amount = static_cast<int64_t>(inner_loop_end->get_work_amount());
139+
const auto inner_tail_increment = inner_loop_end->get_increment();
140+
auto inner_finalization_offsets = inner_loop_end->get_finalization_offsets();
141+
for (auto& offset : inner_finalization_offsets) {
142+
offset = offset / inner_tail_work_amount * static_cast<int64_t>(m_tail_size);
143+
}
144+
inner_loop_end->set_work_amount(m_tail_size);
145+
// TODO: if the new m_tail_size increment is set, all last iter handlers must be updated with new tail value
146+
// We can also don't split loops in case if inner loop has increment not equal to 1
147+
inner_loop_end->set_increment(std::min(inner_tail_increment, m_tail_size));
148+
inner_loop_end->set_finalization_offsets(inner_finalization_offsets);
149+
const auto inner_loop_begin_it = std::find(begin, it, linear_ir.get_expr_by_node(inner_loop_begin));
150+
const auto inner_loop_end_it = std::next(end);
151+
OPENVINO_ASSERT(inner_loop_begin_it != it, "LoopBegin has not been found!");
152+
const auto& last_iter_handlers = inner_loop_info->handlers[LinearIR::LoopManager::LoopInfo::LAST_ITER];
153+
last_iter_handlers.run(linear_ir, inner_loop_begin_it, inner_loop_end_it);
154+
modified = true;
155+
}
156+
return modified;
157+
}
158+
115159
} // namespace pass
116160
} // namespace lowered
117161
} // namespace snippets

src/common/snippets/src/lowered/pass/pass_pipeline.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ void SubgraphPassPipeline::run(const LinearIR& linear_ir, LinearIR::constExprIt
3939
pass->run(linear_ir, begin, end);
4040
}
4141

42+
const std::vector<std::shared_ptr<SubgraphPass>>& SubgraphPassPipeline::get_passes() const {
43+
return m_passes;
44+
}
45+
4246
void SubgraphPassPipeline::register_positioned_passes(const std::vector<PositionedPass>& pos_passes) {
4347
for (const auto& pp : pos_passes)
4448
insert_pass_instance(pp.position, pp.pass);

src/common/snippets/src/lowered/pass/split_loops.cpp

+18-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "snippets/lowered/pass/split_loops.hpp"
66

77
#include "snippets/lowered/pass/fuse_loops.hpp"
8+
#include "snippets/lowered/pass/iter_handler.hpp"
89
#include "snippets/lowered/linear_ir.hpp"
910
#include "snippets/lowered/loop_manager.hpp"
1011
#include "snippets/snippets_isa.hpp"
@@ -81,7 +82,23 @@ bool SplitLoops::run(LinearIR& linear_ir) {
8182
loop_to_split->get_dim_idx(),
8283
loop_to_split->get_entry_points(),
8384
loop_to_split->get_exit_points());
84-
loop_manager->get_loop_info(split_loop_id)->set_outer_splited_loop(true);
85+
const auto& new_loop_info = loop_manager->get_loop_info(split_loop_id);
86+
new_loop_info->set_outer_splited_loop(true);
87+
new_loop_info->handlers = loop_to_split->handlers;
88+
const auto work_amount = loop_to_fuse->get_work_amount();
89+
const auto increment = loop_to_fuse->get_increment();
90+
const auto tail_size = work_amount % increment;
91+
// TODO: current logic doesn't handle the case when loop has first iteration handlers too.
92+
// Need to skip this transformation for sich cases or improve the logic
93+
if (tail_size != 0) {
94+
// TODO: should we remove previous tail loop handler?
95+
new_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
96+
new_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<TransformInnerSplitLoop>(tail_size);
97+
if (work_amount > increment) {
98+
new_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
99+
new_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
100+
}
101+
}
85102
break;
86103
}
87104
}

src/common/snippets/src/op/subgraph.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
434434
pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
435435
pipeline.register_pass<lowered::pass::SoftmaxDecomposition>(vector_size);
436436
pipeline.register_pass<lowered::pass::FuseLoops>();
437-
// pipeline.register_pass<lowered::pass::SplitLoops>();
437+
pipeline.register_pass<lowered::pass::SplitLoops>();
438438
pipeline.register_pass<lowered::pass::MoveResultOutOfLoop>();
439439
pipeline.register_pass<lowered::pass::InsertBuffers>(buffer_allocation_rank);
440440
pipeline.register_pass<lowered::pass::InsertLoadStore>(vector_size);

0 commit comments

Comments
 (0)