diff --git a/src/common/snippets/src/lowered/loop_manager.cpp b/src/common/snippets/src/lowered/loop_manager.cpp index c6390b56c08da0..a63ecdd1de34e5 100644 --- a/src/common/snippets/src/lowered/loop_manager.cpp +++ b/src/common/snippets/src/lowered/loop_manager.cpp @@ -421,7 +421,13 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, loop_info->set_entry_points(new_entries); loop_info->set_exit_points(new_exits); - loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers); + // WA: if one of the fused loops is broadcastable (wa = 1), its handlers have less priority. + // Need to fix it by avoiding handlers creation for the loops whose work amount less than increment + if (loop_info_upper->get_work_amount() > loop_info_lower->get_work_amount()) { + loop_info->handlers = fuse_loop_handlers(loop_info_upper->handlers, loop_info_lower->handlers); + } else { + loop_info->handlers = fuse_loop_handlers(loop_info_lower->handlers, loop_info_upper->handlers); + } const auto& from = fuse_into_upper ? loop_id_lower : loop_id_upper; const auto& to = fuse_into_upper ? loop_id_upper : loop_id_lower; @@ -434,30 +440,24 @@ void LinearIR::LoopManager::fuse_loops(LinearIR::constExprIt loop_begin_target, } std::vector LinearIR::LoopManager::fuse_loop_handlers( - std::vector& lhs, - std::vector& rhs) { - auto merge_pass_pipeline = [](const lowered::pass::SubgraphPassPipeline& lhs_pipeline, - const lowered::pass::SubgraphPassPipeline& rhs_pipeline) { - lowered::pass::SubgraphPassPipeline merged_pipeline = lhs_pipeline; - const auto& res_passes = merged_pipeline.get_passes(); - for (const auto& pass : rhs_pipeline.get_passes()) { + std::vector& from, + std::vector& to) { + const auto min_size = std::min(from.size(), to.size()); + std::vector merged_handlers; + merged_handlers.resize(min_size); + for (size_t i = 0; i < min_size; ++i) { + merged_handlers[i] = from[i]; + const auto& res_passes = merged_handlers[i].get_passes(); + for (const auto& pass : to[i].get_passes()) { auto pred = [&pass](const std::shared_ptr& p) { return p->get_type_info() == pass->get_type_info(); }; if (std::find_if(res_passes.begin(), res_passes.end(), pred) == res_passes.end()) { - merged_pipeline.register_pass(pass); + merged_handlers[i].register_pass(pass); } } - return merged_pipeline; - }; - - const auto min_size = std::min(lhs.size(), rhs.size()); - std::vector merged_handlers; - merged_handlers.resize(min_size); - for (size_t i = 0; i < min_size; ++i) { - merged_handlers[i] = merge_pass_pipeline(lhs[i], rhs[i]); } - auto& handlers_with_larger_size = lhs.size() > rhs.size() ? lhs : rhs; + auto& handlers_with_larger_size = from.size() > to.size() ? from : to; for (size_t i = min_size; i < handlers_with_larger_size.size(); ++i) { merged_handlers.emplace_back(std::move(handlers_with_larger_size[i])); }