@@ -61,20 +61,22 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
61
61
// Init value of vector buffer for ReduceMax is -FLOAT_MIN.
62
62
const auto fill_max = push_node (std::make_shared<op::Fill>(vector_buffer_max.second , 0 , float_min_constant));
63
63
// ReduceMax loop
64
- const auto & max = push_node (std::make_shared<ov::op::v1::Maximum>(softmax->get_input_source_output (0 ), fill_max.second ));
64
+ const auto fill_max_tail = push_node (std::make_shared<op::Fill>(softmax->get_input_source_output (0 ), m_vector_size, float_min_constant));
65
+
66
+ const auto & max = push_node (std::make_shared<ov::op::v1::Maximum>(fill_max_tail.second , fill_max.second ));
65
67
66
68
const auto horizon_max = push_node (std::make_shared<op::HorizonMax>(max.second ));
67
69
68
70
// Markup of ReduceMax Loop
69
- const auto reduce_max_loop_id = loop_manager->mark_loop (max .first , horizon_max.first , inner_work_amount, m_vector_size, 0 ,
70
- std::vector<ExpressionPort>{(*max .first )->get_input_port (0 ),
71
+ const auto reduce_max_loop_id = loop_manager->mark_loop (fill_max_tail .first , horizon_max.first , inner_work_amount, m_vector_size, 0 ,
72
+ std::vector<ExpressionPort>{(*fill_max_tail .first )->get_input_port (0 ),
71
73
(*max.first )->get_input_port (1 )},
72
74
std::vector<ExpressionPort>{(*max.first )->get_output_port (0 )});
73
75
const auto & reduce_max_loop_info = loop_manager->get_loop_info (reduce_max_loop_id);
74
76
const auto tail_size = inner_work_amount % m_vector_size;
75
77
if (tail_size != 0 ) {
76
78
reduce_max_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <DefaultTailLoopHandler>(tail_size);
77
- reduce_max_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <InsertFill >(tail_size);
79
+ reduce_max_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <SetFillOffset >(tail_size);
78
80
if (inner_work_amount > m_vector_size) {
79
81
reduce_max_loop_info->handlers [LoopInfo::MAIN_BODY].register_pass <ReduceWorkAmount>(tail_size);
80
82
reduce_max_loop_info->handlers [LoopInfo::MAIN_BODY].register_pass <ZeroFinalizationOffsets>();
@@ -88,7 +90,8 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
88
90
// Sub + Exp + ReduceSum Loop
89
91
const auto sub = push_node (std::make_shared<ov::op::v1::Subtract>(softmax->get_input_source_output (0 ), broadcast_horizon_max.second ));
90
92
const auto exp = push_node (std::make_shared<ov::op::v0::Exp>(sub.second ));
91
- const auto sum = push_node (std::make_shared<ov::op::v1::Add>(exp .second , fill_sum.second ));
93
+ const auto fill_sum_tail = push_node (std::make_shared<op::Fill>(exp .second , m_vector_size, zero_constant));
94
+ const auto sum = push_node (std::make_shared<ov::op::v1::Add>(fill_sum_tail.second , fill_sum.second ));
92
95
93
96
const auto horizon_sum = push_node (std::make_shared<op::HorizonSum>(sum.second ));
94
97
@@ -97,12 +100,12 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
97
100
std::vector<ExpressionPort>{(*sub.first )->get_input_port (0 ),
98
101
(*sub.first )->get_input_port (1 ),
99
102
(*sum.first )->get_input_port (1 )},
100
- std::vector<ExpressionPort>{(*exp .first )->get_output_port (0 ),
103
+ std::vector<ExpressionPort>{(*fill_sum_tail .first )->get_output_port (0 ),
101
104
(*sum.first )->get_output_port (0 )});
102
105
const auto & reduce_sum_loop_info = loop_manager->get_loop_info (reduce_sum_loop_id);
103
106
if (tail_size != 0 ) {
104
- reduce_max_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <InsertFill>(tail_size);
105
107
reduce_sum_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <DefaultTailLoopHandler>(tail_size);
108
+ reduce_sum_loop_info->handlers [LoopInfo::LAST_ITER].register_pass <SetFillOffset>(tail_size);
106
109
if (inner_work_amount > m_vector_size) {
107
110
reduce_sum_loop_info->handlers [LoopInfo::MAIN_BODY].register_pass <ReduceWorkAmount>(tail_size);
108
111
reduce_sum_loop_info->handlers [LoopInfo::MAIN_BODY].register_pass <ZeroFinalizationOffsets>();
@@ -114,10 +117,10 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
114
117
const auto broadcast_pow = push_node (std::make_shared<op::BroadcastMove>(pow .second , broadcasted_dim));
115
118
116
119
// Mul (pseudo-Divide loop)
117
- const auto mul = push_node (std::make_shared<ov::op::v1::Multiply>(exp .second , broadcast_pow.second ));
120
+ const auto mul = push_node (std::make_shared<ov::op::v1::Multiply>(fill_sum_tail .second , broadcast_pow.second ));
118
121
119
122
// Transfer original ExpressionPorts
120
- linear_ir.replace_input ((*max .first )->get_input_port (0 ), input_connector);
123
+ linear_ir.replace_input ((*fill_max_tail .first )->get_input_port (0 ), input_connector);
121
124
linear_ir.replace_input ((*sub.first )->get_input_port (0 ), input_connector);
122
125
linear_ir.replace_input (output_connector->get_consumers (), (*mul.first )->get_output_port_connector (0 ));
123
126
@@ -136,24 +139,14 @@ bool SoftmaxDecomposition::run(LinearIR& linear_ir) {
136
139
}
137
140
138
141
// Update Loop info for outer loops
139
- const auto entry_points = std::vector<ExpressionPort>{(*max .first )->get_input_port (0 ),
142
+ const auto entry_points = std::vector<ExpressionPort>{(*fill_max_tail .first )->get_input_port (0 ),
140
143
(*sub.first )->get_input_port (0 )};
141
144
const auto exit_points = std::vector<ExpressionPort>{(*mul.first )->get_output_port (0 )};
142
145
for (auto loop_id : softmax_loop_ids) {
143
146
loop_manager->expression_replacement (vector_buffer_max.first , expr_it, softmax_expr, loop_id, entry_points, exit_points);
144
147
}
145
148
146
149
expr_it = linear_ir.erase (expr_it); // Remove Softmax
147
-
148
- /* =========================================== */
149
-
150
- /* ============= Runtime Info ================ */
151
-
152
- // For tail loop we should fill input of Max by float min and
153
- // input of Sum by zero to avoid math incorrect calculations
154
- // TODO [111383]: It should be covered via general pipeline (for example, via analyze in InsertTailLoop?)
155
- max.second ->input (0 ).get_rt_info ()[" set_fill" ] = float_min_constant;
156
- sum.second ->input (0 ).get_rt_info ()[" set_fill" ] = zero_constant;
157
150
modified = true ;
158
151
}
159
152
}
0 commit comments