@@ -190,7 +190,8 @@ std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoo
190190 return {factors, splits};
191191}
192192
193- std::vector<State> MultiLevelTilingNode::TileLoopNest (State state) const {
193+ std::vector<State> MultiLevelTilingNode::TileLoopNest (State state,
194+ int tile_inner_most_space_loop_num) const {
194195 Schedule& sch = state->sch ;
195196 const BlockRV& block_rv = state->block_rv ;
196197 // Step 1. Assuming trivial binding, pair the loops and their iter-var-types
@@ -199,6 +200,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
199200 ICHECK_EQ (loops.size (), iter_types.size ());
200201 // Step 2. For each loop axis, tile it
201202 int64_t spatial_loop_product = 1 ;
203+
204+ int total_spatial_loop_num = 0 ;
205+ std::for_each (iter_types.begin (), iter_types.end (), [&](const auto & iter_type) {
206+ if (iter_type == IterVarType::kDataPar ) total_spatial_loop_num++;
207+ });
208+ CHECK_GE (total_spatial_loop_num, tile_inner_most_space_loop_num);
209+ if (tile_inner_most_space_loop_num < 0 ) tile_inner_most_space_loop_num = total_spatial_loop_num;
210+ int outer_most_spatial_loop_skipped_num = total_spatial_loop_num - tile_inner_most_space_loop_num;
211+
212+ Array<LoopRV> skipped_outer_spatial_loops;
202213 std::vector<Array<LoopRV>> tiles (s_indices_.size () + r_indices_.size ());
203214 state->tile_factors .resize (tiles.size ());
204215 std::vector<Array<tir::ExprRV>> tile_factors;
@@ -208,6 +219,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
208219 const std::vector<int >* idx = nullptr ;
209220
210221 if (iter_types[i] == IterVarType::kDataPar ) {
222+ if (outer_most_spatial_loop_skipped_num > 0 ) {
223+ skipped_outer_spatial_loops.push_back (loop);
224+ outer_most_spatial_loop_skipped_num--;
225+ continue ;
226+ }
211227 idx = &s_indices_;
212228 if (spatial_loop_product != -1 ) {
213229 if (const int64_t * extent = tir::GetLoopIntExtent (sch->Get (loop).get ())) {
@@ -241,6 +257,11 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
241257 sch->Reorder (support::ConcatArrayList<LoopRV>(tiles.begin (), tiles.end ()));
242258 // Step 4. Bind the tiles to threads
243259 int n_binds = std::min (tile_binds.size (), tiles.size ());
260+ if (skipped_outer_spatial_loops.size () && n_binds) {
261+ auto & the_first_tile = tiles[0 ];
262+ the_first_tile.insert (the_first_tile.begin (), skipped_outer_spatial_loops.begin (),
263+ skipped_outer_spatial_loops.end ());
264+ }
244265 for (int i = 0 ; i < n_binds; ++i) {
245266 LoopRV fused = sch->Fuse (tiles[i]);
246267 sch->Bind (fused, tile_binds[i]);
0 commit comments