Skip to content

Commit 5f51d19

Browse files
committed
[BugFix][MetaSchedule] MultiLevelTilingTensorCore generates inconsistent thread-binding sketch for batched matmul
1 parent 3b97658 commit 5f51d19

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoo
190190
return {factors, splits};
191191
}
192192

193-
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
193+
std::vector<State> MultiLevelTilingNode::TileLoopNest(State state,
194+
int tile_inner_most_space_loop_num) const {
194195
Schedule& sch = state->sch;
195196
const BlockRV& block_rv = state->block_rv;
196197
// Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
199200
ICHECK_EQ(loops.size(), iter_types.size());
200201
// Step 2. For each loop axis, tile it
201202
int64_t spatial_loop_product = 1;
203+
204+
int total_spatial_loop_num = 0;
205+
std::for_each(iter_types.begin(), iter_types.end(), [&](const auto& iter_type) {
206+
if (iter_type == IterVarType::kDataPar) total_spatial_loop_num++;
207+
});
208+
CHECK_GE(total_spatial_loop_num, tile_inner_most_space_loop_num);
209+
if (tile_inner_most_space_loop_num < 0) tile_inner_most_space_loop_num = total_spatial_loop_num;
210+
int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;
211+
212+
Array<LoopRV> skipped_outer_spatial_loops;
202213
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
203214
state->tile_factors.resize(tiles.size());
204215
std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
208219
const std::vector<int>* idx = nullptr;
209220

210221
if (iter_types[i] == IterVarType::kDataPar) {
222+
if (outer_most_spatial_loop_skipped_num > 0) {
223+
skipped_outer_spatial_loops.push_back(loop);
224+
outer_most_spatial_loop_skipped_num--;
225+
continue;
226+
}
211227
idx = &s_indices_;
212228
if (spatial_loop_product != -1) {
213229
if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) {
@@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
241257
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
242258
// Step 4. Bind the tiles to threads
243259
int n_binds = std::min(tile_binds.size(), tiles.size());
260+
if (skipped_outer_spatial_loops.size() && n_binds) {
261+
auto& the_first_tile = tiles[0];
262+
the_first_tile.insert(the_first_tile.begin(), skipped_outer_spatial_loops.begin(),
263+
skipped_outer_spatial_loops.end());
264+
}
244265
for (int i = 0; i < n_binds; ++i) {
245266
LoopRV fused = sch->Fuse(tiles[i]);
246267
sch->Bind(fused, tile_binds[i]);

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
162162
// SubRule 1. add write cache
163163
std::vector<State> AddWriteReuse(State state) const;
164164
// SubRule 2. tile the loop nest
165-
std::vector<State> TileLoopNest(State state) const;
165+
std::vector<State> TileLoopNest(State state, int tile_inner_most_space_loop_num = -1) const;
166166
// SubRule 3. add read cache
167167
std::vector<State> AddReadReuse(State state) const;
168168
// SubRule 4. add async pipeline

src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -251,7 +251,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
251251
});
252252
states = SubRule(std::move(states), [&](State state) {
253253
TensorCoreState tc_state = Downcast<TensorCoreState>(state);
254-
return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state);
254+
return tc_state->is_mma ? MMATileLoopNest(tc_state) : TileLoopNest(state, 2);
255255
});
256256
states = SubRule(std::move(states), [&](State state) {
257257
return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));

0 commit comments

Comments (0)