Skip to content

Commit e1fc633

Browse files
committed
AMX scratchpad buffer is moved inside blocking loops
1 parent 7ba840d commit e1fc633

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.cpp

+33-9
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;
2424

2525
BrgemmBlocking::BrgemmBlocking() : Pass() {}
2626

27+
void BrgemmBlocking::move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) {
28+
const auto& brgemm_expr = brgemm_it->get();
29+
const auto wsp_expr = brgemm_expr->get_input_port_connector(2)->get_source().get_expr();
30+
const auto wsp_buffer = ov::as_type_ptr<ov::snippets::op::Buffer>(wsp_expr->get_node());
31+
OPENVINO_ASSERT(wsp_buffer && wsp_buffer->is_new_memory(), "Incorrect Scratchpad buffer for Brgemm AMX");
32+
// If scratchpad with temp memory is not explicitly before Brgemm, need to move to there.
33+
if (wsp_expr != *std::prev(brgemm_it)) {
34+
const auto wsp_it = linear_ir.find(wsp_expr);
35+
linear_ir.move(wsp_it, brgemm_it);
36+
}
37+
}
38+
2739
bool BrgemmBlocking::run(LinearIR& linear_ir) {
2840
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking")
2941
if (linear_ir.empty())
@@ -75,12 +87,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
7587
*(in_0_subtensor.rbegin() + 1) = block_size_m;
7688
*(out_subtensor.rbegin() + 1) = block_size_m;
7789

90+
auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
7891
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true),
7992
LoopPort(brgemm_expr->get_input_port(1), false)};
80-
if (brgemm->is_with_scratchpad())
93+
if (brgemm->is_with_compensations()) {
8194
entries.emplace_back(brgemm_expr->get_input_port(2), false);
95+
} else if (brgemm->is_amx()) {
96+
move_amx_scratchpad_buffer(linear_ir, expr_it);
97+
loop_begin_it = std::prev(expr_it);
98+
}
8299
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
83-
loop_manager->mark_loop(expr_it, std::next(expr_it), m, block_size_m, 1, entries, exits);
100+
loop_manager->mark_loop(loop_begin_it, loop_end_it, m, block_size_m, 1, entries, exits);
84101
}
85102
};
86103

@@ -94,15 +111,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
94111
*in_1_subtensor.rbegin() = block_size_n;
95112
*out_subtensor.rbegin() = block_size_n;
96113

114+
auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
97115
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), false),
98116
LoopPort(brgemm_expr->get_input_port(1), true)};
99-
if (brgemm->is_with_scratchpad()) {
100-
// The second input of Brgemm for AMX case is scratch buffer so it mustn't be incremented
101-
const bool is_incremented = brgemm->is_with_compensations() ? true : false;
102-
entries.emplace_back(brgemm_expr->get_input_port(2), is_incremented);
117+
if (brgemm->is_with_compensations()) {
118+
entries.emplace_back(brgemm_expr->get_input_port(2), true);
119+
} else if (brgemm->is_amx()) {
120+
move_amx_scratchpad_buffer(linear_ir, expr_it);
121+
loop_begin_it = std::prev(expr_it);
103122
}
104123
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), true)};
105-
loop_manager->mark_loop(expr_it, std::next(expr_it), n, block_size_n, 0, entries, exits);
124+
loop_manager->mark_loop(loop_begin_it, loop_end_it, n, block_size_n, 0, entries, exits);
106125
}
107126
};
108127

@@ -117,12 +136,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
117136
*in_0_subtensor.rbegin() = block_size_k;
118137
*(in_1_subtensor.rbegin() + 1) = block_size_k;
119138

139+
auto loop_begin_it = expr_it, loop_end_it = std::next(expr_it);
120140
std::vector<LoopPort> entries{LoopPort(brgemm_expr->get_input_port(0), true, 0),
121141
LoopPort(brgemm_expr->get_input_port(1), true, 1)};
122-
if (brgemm->is_with_scratchpad())
142+
if (brgemm->is_with_compensations()) {
123143
entries.emplace_back(brgemm_expr->get_input_port(2), false, 1);
144+
} else if (brgemm->is_amx()) {
145+
move_amx_scratchpad_buffer(linear_ir, expr_it);
146+
loop_begin_it = std::prev(expr_it);
147+
}
124148
std::vector<LoopPort> exits{LoopPort(brgemm_expr->get_output_port(0), false)};
125-
auto loop_id = loop_manager->mark_loop(expr_it, std::next(expr_it), k, block_size_k, entries, exits);
149+
auto loop_id = loop_manager->mark_loop(loop_begin_it, loop_end_it, k, block_size_k, entries, exits);
126150
const auto loop_info = loop_manager->get_loop_info(loop_id);
127151

128152
auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt loop_end_it) {

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class BrgemmBlocking : public snippets::lowered::pass::Pass {
2121
OPENVINO_RTTI("BrgemmBlocking", "Pass")
2222
BrgemmBlocking();
2323
bool run(snippets::lowered::LinearIR& linear_ir) override;
24+
25+
private:
26+
static void move_amx_scratchpad_buffer(snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it);
2427
};
2528

2629
} // namespace pass

0 commit comments

Comments
 (0)