@@ -112,6 +112,50 @@ bool SetFillOffset::run(const LinearIR& linear_ir, LinearIR::constExprIt begin,
112
112
return true ;
113
113
}
114
114
115
+ TransformInnerSplitLoop::TransformInnerSplitLoop (size_t tail_size) : SubgraphPass(), m_tail_size(tail_size) {}
116
+
117
+ bool TransformInnerSplitLoop::run (const LinearIR& linear_ir, LinearIR::constExprIt begin, LinearIR::constExprIt end) {
118
+ const auto & expr = *end;
119
+ const auto node = expr->get_node ();
120
+ const auto loop_end = ov::as_type_ptr<op::LoopEnd>(node);
121
+ const auto & loop_manager = linear_ir.get_loop_manager ();
122
+ const auto & loop_info = loop_manager->get_loop_info (loop_end->get_id ());
123
+ const auto current_dim_idx = loop_info->get_dim_idx ();
124
+ OPENVINO_ASSERT (current_dim_idx != LinearIR::LoopManager::LoopInfo::UNDEFINED_DIM_IDX,
125
+ " Outer splitted loop unexpectedly iterates by several dimension indices" );
126
+
127
+ bool modified = false ;
128
+ for (auto it = std::next (begin); it != end; ++it) {
129
+ const auto & expr = *it;
130
+ const auto inner_loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node ());
131
+ if (!inner_loop_end)
132
+ continue ;
133
+ const auto inner_loop_info = loop_manager->get_loop_info (inner_loop_end->get_id ());
134
+ const auto inner_dim_idx = inner_loop_info->get_dim_idx ();
135
+ if (inner_dim_idx != current_dim_idx)
136
+ continue ;
137
+ const auto inner_loop_begin = inner_loop_end->get_loop_begin ();
138
+ const auto inner_tail_work_amount = static_cast <int64_t >(inner_loop_end->get_work_amount ());
139
+ const auto inner_tail_increment = inner_loop_end->get_increment ();
140
+ auto inner_finalization_offsets = inner_loop_end->get_finalization_offsets ();
141
+ for (auto & offset : inner_finalization_offsets) {
142
+ offset = offset / inner_tail_work_amount * static_cast <int64_t >(m_tail_size);
143
+ }
144
+ inner_loop_end->set_work_amount (m_tail_size);
145
+ // TODO: if the new m_tail_size increment is set, all last iter handlers must be updated with new tail value
146
+ // We can also don't split loops in case if inner loop has increment not equal to 1
147
+ inner_loop_end->set_increment (std::min (inner_tail_increment, m_tail_size));
148
+ inner_loop_end->set_finalization_offsets (inner_finalization_offsets);
149
+ const auto inner_loop_begin_it = std::find (begin, it, linear_ir.get_expr_by_node (inner_loop_begin));
150
+ const auto inner_loop_end_it = std::next (end);
151
+ OPENVINO_ASSERT (inner_loop_begin_it != it, " LoopBegin has not been found!" );
152
+ const auto & last_iter_handlers = inner_loop_info->handlers [LinearIR::LoopManager::LoopInfo::LAST_ITER];
153
+ last_iter_handlers.run (linear_ir, inner_loop_begin_it, inner_loop_end_it);
154
+ modified = true ;
155
+ }
156
+ return modified;
157
+ }
158
+
115
159
} // namespace pass
116
160
} // namespace lowered
117
161
} // namespace snippets
0 commit comments