Skip to content

Commit ba50c83

Browse files
committed
Reduce decomposition
1 parent ef4178e commit ba50c83

8 files changed

+289
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "pass.hpp"
8+
9+
namespace ov {
10+
namespace snippets {
11+
namespace lowered {
12+
namespace pass {
13+
14+
/**
15+
* @interface ReduceMaxDecomposition
16+
* @brief Decomposes ReduceMax to a range of low-level operations on linear IR
17+
* @ingroup snippets
18+
*/
19+
class ReduceMaxDecomposition : public Pass {
20+
public:
21+
// RTTI declaration: type name "ReduceMaxDecomposition", parent type name "Pass"
OPENVINO_RTTI("ReduceMaxDecomposition", "Pass")
22+
// vector_size is stored as-is in m_vector_size
explicit ReduceMaxDecomposition(size_t vector_size);
23+
// Replaces ReduceMax expressions found in linear_ir with their decomposed form;
// returns true if at least one expression was replaced
bool run(LinearIR& linear_ir) override;
24+
25+
private:
26+
// Forwarded by run() to the in-loop Fill op and to the loop manager's mark_loop()
size_t m_vector_size;
27+
};
28+
29+
} // namespace pass
30+
} // namespace lowered
31+
} // namespace snippets
32+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "pass.hpp"
8+
9+
namespace ov {
10+
namespace snippets {
11+
namespace lowered {
12+
namespace pass {
13+
14+
/**
15+
* @interface ReduceSumDecomposition
16+
* @brief Decomposes ReduceSum to a range of low-level operations on linear IR
17+
* @ingroup snippets
18+
*/
19+
class ReduceSumDecomposition : public Pass {
20+
public:
21+
// RTTI declaration: type name "ReduceSumDecomposition", parent type name "Pass"
OPENVINO_RTTI("ReduceSumDecomposition", "Pass")
22+
// vector_size is stored as-is in m_vector_size
explicit ReduceSumDecomposition(size_t vector_size);
23+
// Replaces ReduceSum expressions found in linear_ir with their decomposed form;
// returns true if at least one expression was replaced
bool run(LinearIR& linear_ir) override;
24+
25+
private:
26+
// Forwarded by run() to the in-loop Fill op and to the loop manager's mark_loop()
size_t m_vector_size;
27+
};
28+
29+
} // namespace pass
30+
} // namespace lowered
31+
} // namespace snippets
32+
} // namespace ov

src/common/snippets/src/lowered/loop_manager.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ void LinearIR::LoopManager::insert_loop_id(const ExpressionPtr& expr, size_t new
598598
OPENVINO_ASSERT(m_map.count(new_id) == 1, "Failed marking expression by Loop ID: the Loop with this ID hasn't registered");
599599
auto& loop_ids = expr->m_loop_ids;
600600
OPENVINO_ASSERT(std::find(loop_ids.cbegin(), loop_ids.cend(), new_id) == loop_ids.cend(),
601-
"Expression cannot have several the same Loop IDs");
601+
"Expression cannot have several identical Loop IDs");
602602
auto insert_it = before ? loop_ids.cbegin() : loop_ids.cend();
603603
if (target_id != SIZE_MAX) {
604604
insert_it = std::find(loop_ids.cbegin(), loop_ids.cend(), target_id);

src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ bool InsertBroadcastMove::run(LinearIR& linear_ir) {
5656
OPENVINO_ASSERT(last_dims[i] == 1,
5757
"Attempt to broadcast non-1 dimension. Target dim: ", broadcasted_dim,
5858
" This dim: ", last_dims[i]);
59-
const auto bcast_dim = ov::Dimension(last_dims[i]);
60-
const auto broadcast = std::make_shared<op::BroadcastMove>(node->get_input_source_output(i), bcast_dim);
59+
const auto broadcast = std::make_shared<op::BroadcastMove>(node->get_input_source_output(i), broadcasted_dim);
6160

6261
PortDescriptorUtils::set_port_descriptor_ptr(broadcast->output(0), connectors[i]->get_source().get_descriptor_ptr()->clone());
6362
const auto broadcast_expr = linear_ir.create_expression(broadcast, {connectors[i]});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "snippets/lowered/pass/reduce_max_decomposition.hpp"
6+
7+
#include "snippets/lowered/linear_ir.hpp"
8+
#include "snippets/lowered/loop_manager.hpp"
9+
#include "snippets/lowered/pass/mark_loops.hpp"
10+
#include "snippets/lowered/pass/iter_handler.hpp"
11+
#include "snippets/snippets_isa.hpp"
12+
#include "snippets/itt.hpp"
13+
14+
#include "openvino/pass/pattern/op/wrap_type.hpp"
15+
#include "openvino/pass/pattern/matcher.hpp"
16+
17+
18+
namespace ov {
19+
namespace snippets {
20+
namespace lowered {
21+
namespace pass {
22+
23+
using LoopInfo = LinearIR::LoopManager::LoopInfo;
24+
25+
// Keeps the vector size that run() later hands to the in-loop Fill op and to mark_loop().
ReduceMaxDecomposition::ReduceMaxDecomposition(size_t vector_size)
    : m_vector_size(vector_size) {}
26+
27+
/**
 * @brief Replaces every ReduceMax expression of the linear IR with an explicit sequence:
 *        VectorBuffer -> Fill (accumulator init) -> [Fill -> Maximum] (marked as a loop) -> HorizonMax.
 * @param linear_ir linear IR that is modified in place
 * @return true if at least one ReduceMax expression was decomposed, false otherwise
 */
bool ReduceMaxDecomposition::run(LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ReduceMaxDecompositionLowered")
    const auto& loop_manager = linear_ir.get_loop_manager();

    bool modified = false;
    // Note: there is intentionally no increment clause. The iterator returned by erase() at the end of
    // the body is the one to continue from (the expression after the removed one, as with std::list::erase),
    // so an unconditional increment would skip that expression and could step past end().
    for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end();) {
        const auto& op = (*expr_it)->get_node();
        if (!ov::is_type<ov::snippets::op::ReduceMax>(op)) {
            ++expr_it;
            continue;
        }

        const auto reduce = op;
        const auto reduce_expr = *expr_it;
        const auto& input_shape = reduce_expr->get_input_port_descriptor(0)->get_shape();
        // The work amount of the reduction loop is the last dimension of the input shape
        const auto work_amount = *(input_shape.rbegin());
        const bool is_dynamic = reduce->is_dynamic();

        // We need an iterator to the inserted element.
        // Every new node is inserted right before the ReduceMax expression (expr_it still points to it).
        auto push_node = [&](const std::shared_ptr<Node>& n) {
            const auto expr = linear_ir.insert(expr_it, n);
            if (is_dynamic)
                expr->get()->updateShapes();
            return std::make_pair(expr, n);
        };
        // Float constant value in byte representation: 0xff7fffff is the bit pattern of the lowest
        // finite single-precision float (-FLT_MAX)
        const auto fill_value = uint32_t(0xff7fffff);
        // Note: VectorBuffer is a special case, since it should go before the initial Load.
        // The buffer must be initialized with fill_value before reduction
        const auto vector_buffer = push_node(std::make_shared<op::VectorBuffer>());
        const auto initial_fill = push_node(std::make_shared<op::Fill>(vector_buffer.second, 0, fill_value));

        // Reduce loop
        const auto fill = push_node(std::make_shared<op::Fill>(reduce->get_input_source_output(0), m_vector_size, fill_value));
        const auto max = push_node(std::make_shared<ov::op::v1::Maximum>(fill.second, initial_fill.second));

        // The loop is marked before HorizonMax is inserted: its bounds are [fill, reduce),
        // i.e. it contains only the in-loop Fill and Maximum at this point
        const auto reduce_loop_id = loop_manager->mark_loop(
            fill.first,
            expr_it,
            work_amount,
            m_vector_size,
            0,
            std::vector<ExpressionPort>{(*fill.first)->get_input_port(0), (*max.first)->get_input_port(1)},
            std::vector<ExpressionPort>{(*max.first)->get_output_port(0)});
        const auto reduce_loop_info = loop_manager->get_loop_info(reduce_loop_id);
        const auto tail_size = work_amount % m_vector_size;
        if (tail_size != 0) {
            // The work amount is not a multiple of the vector size: register tail handlers for the last iteration
            reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
            reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
            if (work_amount > m_vector_size) {
                // There is at least one full vector iteration before the tail
                reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
                reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
            }
        }

        const auto horizon = push_node(std::make_shared<op::HorizonMax>(max.second));

        // Transfer original ExpressionPorts
        linear_ir.replace_input((*fill.first)->get_input_port(0), reduce_expr->get_input_port_connector(0));
        linear_ir.replace_input(reduce_expr->get_output_port_connector(0)->get_consumers(), (*horizon.first)->get_output_port_connector(0));

        // Update Loop info for outer loops
        const std::vector<ExpressionPort> entry_points{(*fill.first)->get_input_port(0)};
        const std::vector<ExpressionPort> exit_points{(*horizon.first)->get_output_port(0)};
        for (auto loop_id : reduce_expr->get_loop_ids()) {
            loop_manager->expression_replacement(vector_buffer.first,
                                                 expr_it,
                                                 reduce_expr,
                                                 loop_id,
                                                 entry_points,
                                                 exit_points);
        }

        // Remove the original ReduceMax and continue from the expression that followed it
        expr_it = linear_ir.erase(expr_it);
        modified = true;
    }
    return modified;
}
103+
104+
} // namespace pass
105+
} // namespace lowered
106+
} // namespace snippets
107+
} // namespace ov
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
// Copyright (C) 2023 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "snippets/lowered/pass/reduce_sum_decomposition.hpp"
6+
7+
#include "snippets/lowered/linear_ir.hpp"
8+
#include "snippets/lowered/loop_manager.hpp"
9+
#include "snippets/lowered/pass/mark_loops.hpp"
10+
#include "snippets/lowered/pass/iter_handler.hpp"
11+
#include "snippets/snippets_isa.hpp"
12+
#include "snippets/itt.hpp"
13+
14+
#include "openvino/pass/pattern/op/wrap_type.hpp"
15+
#include "openvino/pass/pattern/matcher.hpp"
16+
17+
18+
namespace ov {
19+
namespace snippets {
20+
namespace lowered {
21+
namespace pass {
22+
23+
using LoopInfo = LinearIR::LoopManager::LoopInfo;
24+
25+
// Keeps the vector size that run() later hands to the in-loop Fill op and to mark_loop().
ReduceSumDecomposition::ReduceSumDecomposition(size_t vector_size)
    : m_vector_size(vector_size) {}
26+
27+
/**
 * @brief Replaces every ReduceSum expression of the linear IR with an explicit sequence:
 *        VectorBuffer -> Fill (accumulator init) -> [Fill -> Add] (marked as a loop) -> HorizonSum.
 * @param linear_ir linear IR that is modified in place
 * @return true if at least one ReduceSum expression was decomposed, false otherwise
 */
bool ReduceSumDecomposition::run(LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ReduceSumDecompositionLowered")
    bool modified = false;
    const auto& loop_manager = linear_ir.get_loop_manager();

    // Note: there is intentionally no increment clause. The iterator returned by erase() at the end of
    // the body is the one to continue from (the expression after the removed one, as with std::list::erase),
    // so an unconditional increment would skip that expression and could step past end().
    for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end();) {
        const auto& op = (*expr_it)->get_node();
        if (!ov::is_type<ov::snippets::op::ReduceSum>(op)) {
            ++expr_it;
            continue;
        }

        const auto reduce = op;
        const auto reduce_expr = *expr_it;
        const auto& input_shape = reduce_expr->get_input_port_descriptor(0)->get_shape();
        // The work amount of the reduction loop is the last dimension of the input shape
        const auto work_amount = *(input_shape.rbegin());
        const bool is_dynamic = reduce->is_dynamic();

        // We need an iterator to the inserted element.
        // Every new node is inserted right before the ReduceSum expression (expr_it still points to it).
        auto push_node = [&](const std::shared_ptr<Node>& n) {
            const auto expr = linear_ir.insert(expr_it, n);
            if (is_dynamic)
                expr->get()->updateShapes();
            return std::make_pair(expr, n);
        };
        // Float constant value in byte representation: all-zero bits are +0.0f
        const auto fill_value = uint32_t(0x00000000);
        // Note: VectorBuffer is a special case, since it should go before the initial Load.
        // The buffer must be initialized with fill_value before reduction
        const auto vector_buffer = push_node(std::make_shared<op::VectorBuffer>());
        const auto initial_fill = push_node(std::make_shared<op::Fill>(vector_buffer.second, 0, fill_value));

        // Reduce loop
        const auto fill = push_node(std::make_shared<op::Fill>(reduce->get_input_source_output(0), m_vector_size, fill_value));
        const auto add = push_node(std::make_shared<ov::op::v1::Add>(fill.second, initial_fill.second));

        // The loop is marked before HorizonSum is inserted: its bounds are [fill, reduce),
        // i.e. it contains only the in-loop Fill and Add at this point
        const auto reduce_loop_id = loop_manager->mark_loop(
            fill.first,
            expr_it,
            work_amount,
            m_vector_size,
            0,
            std::vector<ExpressionPort>{(*fill.first)->get_input_port(0), (*add.first)->get_input_port(1)},
            std::vector<ExpressionPort>{(*add.first)->get_output_port(0)});
        const auto reduce_loop_info = loop_manager->get_loop_info(reduce_loop_id);
        const auto tail_size = work_amount % m_vector_size;
        if (tail_size != 0) {
            // The work amount is not a multiple of the vector size: register tail handlers for the last iteration
            reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<DefaultTailLoopHandler>(tail_size);
            reduce_loop_info->handlers[LoopInfo::LAST_ITER].register_pass<SetFillOffset>(tail_size);
            if (work_amount > m_vector_size) {
                // There is at least one full vector iteration before the tail
                reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ReduceWorkAmount>(tail_size);
                reduce_loop_info->handlers[LoopInfo::MAIN_BODY].register_pass<ZeroFinalizationOffsets>();
            }
        }

        const auto horizon = push_node(std::make_shared<op::HorizonSum>(add.second));

        // Transfer original ExpressionPorts
        linear_ir.replace_input((*fill.first)->get_input_port(0), reduce_expr->get_input_port_connector(0));
        linear_ir.replace_input(reduce_expr->get_output_port_connector(0)->get_consumers(), (*horizon.first)->get_output_port_connector(0));

        // Update Loop info for outer loops
        const std::vector<ExpressionPort> entry_points{(*fill.first)->get_input_port(0)};
        const std::vector<ExpressionPort> exit_points{(*horizon.first)->get_output_port(0)};
        for (auto loop_id : reduce_expr->get_loop_ids()) {
            loop_manager->expression_replacement(vector_buffer.first,
                                                 expr_it,
                                                 reduce_expr,
                                                 loop_id,
                                                 entry_points,
                                                 exit_points);
        }

        // Remove the original ReduceSum and continue from the expression that followed it
        expr_it = linear_ir.erase(expr_it);
        modified = true;
    }

    return modified;
}
104+
105+
} // namespace pass
106+
} // namespace lowered
107+
} // namespace snippets
108+
} // namespace ov

src/common/snippets/src/op/subgraph.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
#include "snippets/lowered/pass/validate_loops.hpp"
4545
#include "snippets/lowered/pass/insert_loops.hpp"
4646
#include "snippets/lowered/pass/optimize_domain.hpp"
47+
#include "snippets/lowered/pass/reduce_max_decomposition.hpp"
48+
#include "snippets/lowered/pass/reduce_sum_decomposition.hpp"
4749

4850
#include "transformations/utils/utils.hpp"
4951

@@ -436,7 +438,10 @@ void Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
436438

437439
PassPipeline pipeline;
438440
pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
441+
// TODO: remove SoftmaxDecomposition pass
439442
pipeline.register_pass<lowered::pass::SoftmaxDecomposition>(vector_size);
443+
pipeline.register_pass<lowered::pass::ReduceMaxDecomposition>(vector_size);
444+
pipeline.register_pass<lowered::pass::ReduceSumDecomposition>(vector_size);
440445
pipeline.register_pass<lowered::pass::FuseLoops>();
441446
pipeline.register_pass<lowered::pass::SplitLoops>();
442447
pipeline.register_pass<lowered::pass::MoveResultOutOfLoop>();

src/common/snippets/src/pass/softmax_decomposition.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ SoftmaxDecomposition::SoftmaxDecomposition() {
5757
subtensor[i] = PortDescriptor::ServiceDimensions::FULL_DIM;
5858

5959
PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->input(0), std::make_shared<PortDescriptor>(reduce_max->input(0), subtensor));
60-
PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared<PortDescriptor>(reduce_sum->input(0), subtensor));
6160
PortDescriptorUtils::set_port_descriptor_ptr(reduce_max->output(0), std::make_shared<PortDescriptor>(reduce_max->output(0), subtensor));
61+
PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->input(0), std::make_shared<PortDescriptor>(reduce_sum->input(0), subtensor));
6262
PortDescriptorUtils::set_port_descriptor_ptr(reduce_sum->output(0), std::make_shared<PortDescriptor>(reduce_sum->output(0), subtensor));
63+
PortDescriptorUtils::set_port_descriptor_ptr(power->input(0), std::make_shared<PortDescriptor>(power->input(0), subtensor));
64+
PortDescriptorUtils::set_port_descriptor_ptr(power->output(0), std::make_shared<PortDescriptor>(power->output(0), subtensor));
6365

6466
return ov::replace_node_update_name(softmax, multiply);
6567
};

0 commit comments

Comments
 (0)