@@ -24,6 +24,18 @@ using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;
24
24
25
25
BrgemmBlocking::BrgemmBlocking () : Pass() {}
26
26
27
+ void BrgemmBlocking::move_amx_scratchpad_buffer (snippets::lowered::LinearIR& linear_ir, const snippets::lowered::LinearIR::constExprIt& brgemm_it) {
28
+ const auto & brgemm_expr = brgemm_it->get ();
29
+ const auto wsp_expr = brgemm_expr->get_input_port_connector (2 )->get_source ().get_expr ();
30
+ const auto wsp_buffer = ov::as_type_ptr<ov::snippets::op::Buffer>(wsp_expr->get_node ());
31
+ OPENVINO_ASSERT (wsp_buffer && wsp_buffer->is_new_memory (), " Incorrect Scratchpad buffer for Brgemm AMX" );
32
+ // If scratchpad with temp memory is not explicitly before Brgemm, need to move to there.
33
+ if (wsp_expr != *std::prev (brgemm_it)) {
34
+ const auto wsp_it = linear_ir.find (wsp_expr);
35
+ linear_ir.move (wsp_it, brgemm_it);
36
+ }
37
+ }
38
+
27
39
bool BrgemmBlocking::run (LinearIR& linear_ir) {
28
40
OV_ITT_SCOPED_TASK (ov::pass::itt::domains::SnippetsTransform, " Snippets::BrgemmBlocking" )
29
41
if (linear_ir.empty ())
@@ -75,12 +87,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
75
87
*(in_0_subtensor.rbegin () + 1 ) = block_size_m;
76
88
*(out_subtensor.rbegin () + 1 ) = block_size_m;
77
89
90
+ auto loop_begin_it = expr_it, loop_end_it = std::next (expr_it);
78
91
std::vector<LoopPort> entries{LoopPort (brgemm_expr->get_input_port (0 ), true ),
79
92
LoopPort (brgemm_expr->get_input_port (1 ), false )};
80
- if (brgemm->is_with_scratchpad ())
93
+ if (brgemm->is_with_compensations ()) {
81
94
entries.emplace_back (brgemm_expr->get_input_port (2 ), false );
95
+ } else if (brgemm->is_amx ()) {
96
+ move_amx_scratchpad_buffer (linear_ir, expr_it);
97
+ loop_begin_it = std::prev (expr_it);
98
+ }
82
99
std::vector<LoopPort> exits{LoopPort (brgemm_expr->get_output_port (0 ), true )};
83
- loop_manager->mark_loop (expr_it, std::next (expr_it) , m, block_size_m, 1 , entries, exits);
100
+ loop_manager->mark_loop (loop_begin_it, loop_end_it , m, block_size_m, 1 , entries, exits);
84
101
}
85
102
};
86
103
@@ -94,15 +111,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
94
111
*in_1_subtensor.rbegin () = block_size_n;
95
112
*out_subtensor.rbegin () = block_size_n;
96
113
114
+ auto loop_begin_it = expr_it, loop_end_it = std::next (expr_it);
97
115
std::vector<LoopPort> entries{LoopPort (brgemm_expr->get_input_port (0 ), false ),
98
116
LoopPort (brgemm_expr->get_input_port (1 ), true )};
99
- if (brgemm->is_with_scratchpad ()) {
100
- // The second input of Brgemm for AMX case is scratch buffer so it mustn't be incremented
101
- const bool is_incremented = brgemm->is_with_compensations () ? true : false ;
102
- entries.emplace_back (brgemm_expr->get_input_port (2 ), is_incremented);
117
+ if (brgemm->is_with_compensations ()) {
118
+ entries.emplace_back (brgemm_expr->get_input_port (2 ), true );
119
+ } else if (brgemm->is_amx ()) {
120
+ move_amx_scratchpad_buffer (linear_ir, expr_it);
121
+ loop_begin_it = std::prev (expr_it);
103
122
}
104
123
std::vector<LoopPort> exits{LoopPort (brgemm_expr->get_output_port (0 ), true )};
105
- loop_manager->mark_loop (expr_it, std::next (expr_it) , n, block_size_n, 0 , entries, exits);
124
+ loop_manager->mark_loop (loop_begin_it, loop_end_it , n, block_size_n, 0 , entries, exits);
106
125
}
107
126
};
108
127
@@ -117,12 +136,17 @@ bool BrgemmBlocking::run(LinearIR& linear_ir) {
117
136
*in_0_subtensor.rbegin () = block_size_k;
118
137
*(in_1_subtensor.rbegin () + 1 ) = block_size_k;
119
138
139
+ auto loop_begin_it = expr_it, loop_end_it = std::next (expr_it);
120
140
std::vector<LoopPort> entries{LoopPort (brgemm_expr->get_input_port (0 ), true , 0 ),
121
141
LoopPort (brgemm_expr->get_input_port (1 ), true , 1 )};
122
- if (brgemm->is_with_scratchpad ())
142
+ if (brgemm->is_with_compensations ()) {
123
143
entries.emplace_back (brgemm_expr->get_input_port (2 ), false , 1 );
144
+ } else if (brgemm->is_amx ()) {
145
+ move_amx_scratchpad_buffer (linear_ir, expr_it);
146
+ loop_begin_it = std::prev (expr_it);
147
+ }
124
148
std::vector<LoopPort> exits{LoopPort (brgemm_expr->get_output_port (0 ), false )};
125
- auto loop_id = loop_manager->mark_loop (expr_it, std::next (expr_it) , k, block_size_k, entries, exits);
149
+ auto loop_id = loop_manager->mark_loop (loop_begin_it, loop_end_it , k, block_size_k, entries, exits);
126
150
const auto loop_info = loop_manager->get_loop_info (loop_id);
127
151
128
152
auto first_iter_handler = [](LinearIR& linear_ir, LinearIR::constExprIt loop_end_it) {
0 commit comments