Skip to content

Commit

Permalink
Added WA for loops with broadcastable work amount fusing
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 29, 2023
1 parent 23f3367 commit d0b71a3
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions src/common/snippets/src/lowered/loop_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,13 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target,
loop_info->set_entry_points(new_entries);
loop_info->set_exit_points(new_exits);

loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers);
// WA: if one of the fused loops is broadcastable (wa = 1), its handlers have less priority.
// Need to fix it by avoiding handlers creation for the loops whose work amount less than increment
if (loop_info_upper->get_work_amount() > loop_info_lower->get_work_amount()) {
loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers);
} else {
loop_info->handlers = fuse_loop_handlers(loop_info_lower->handlers, loop_info_upper->handlers);
}

const auto& from = fuse_into_upper ? loop_id_lower : loop_id_upper;
const auto& to = fuse_into_upper ? loop_id_upper : loop_id_lower;
Expand All @@ -434,30 +440,24 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target,
}

std::vector<lowered::pass::SubgraphPassPipeline> LinearIR::LoopManager::fuse_loop_handlers(
std::vector<lowered::pass::SubgraphPassPipeline>& lhs,
std::vector<lowered::pass::SubgraphPassPipeline>& rhs) {
auto merge_pass_pipeline = [](const lowered::pass::SubgraphPassPipeline& lhs_pipeline,
const lowered::pass::SubgraphPassPipeline& rhs_pipeline) {
lowered::pass::SubgraphPassPipeline merged_pipeline = lhs_pipeline;
const auto& res_passes = merged_pipeline.get_passes();
for (const auto& pass : rhs_pipeline.get_passes()) {
std::vector<lowered::pass::SubgraphPassPipeline>& from,
std::vector<lowered::pass::SubgraphPassPipeline>& to) {
const auto min_size = std::min(from.size(), to.size());
std::vector<lowered::pass::SubgraphPassPipeline> merged_handlers;
merged_handlers.resize(min_size);
for (size_t i = 0; i < min_size; ++i) {
merged_handlers[i] = from[i];
const auto& res_passes = merged_handlers[i].get_passes();
for (const auto& pass : to[i].get_passes()) {
auto pred = [&pass](const std::shared_ptr<lowered::pass::SubgraphPass>& p) {
return p->get_type_info() == pass->get_type_info();
};
if (std::find_if(res_passes.begin(), res_passes.end(), pred) == res_passes.end()) {
merged_pipeline.register_pass(pass);
merged_handlers[i].register_pass(pass);
}
}
return merged_pipeline;
};

const auto min_size = std::min(lhs.size(), rhs.size());
std::vector<lowered::pass::SubgraphPassPipeline> merged_handlers;
merged_handlers.resize(min_size);
for (size_t i = 0; i < min_size; ++i) {
merged_handlers[i] = merge_pass_pipeline(lhs[i], rhs[i]);
}
auto& handlers_with_larger_size = lhs.size() > rhs.size() ? lhs : rhs;
auto& handlers_with_larger_size = from.size() > to.size() ? from : to;
for (size_t i = min_size; i < handlers_with_larger_size.size(); ++i) {
merged_handlers.emplace_back(std::move(handlers_with_larger_size[i]));
}
Expand Down

0 comments on commit d0b71a3

Please sign in to comment.