Skip to content

Commit 48a16f1

Browse files
rebase upstream
1 parent 8d46bb5 commit 48a16f1

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

src/tir/schedule/primitive/layout_transformation.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
392392

393393
auto iter_map = arith::DetectIterMap(
394394
/*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true),
395-
/*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true);
396-
if (iter_map.empty()) {
395+
/*check_level=*/arith::IterMapLevel::Bijective, &analyzer,
396+
/*simplify_trivial_iterators=*/true);
397+
if (iter_map->indices.empty()) {
397398
throw NotBijectiveAffineIndexMapError(self->mod, index_map);
398399
}
399400

@@ -417,7 +418,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
417418
// Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters
418419
// in the body.
419420

420-
auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars);
421+
auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars);
421422
// Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant
422423
// zero.
423424
for (const auto& iter_var : block_ptr->iter_vars) {

tests/python/unittest/test_tir_schedule_compute_at.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,13 +1281,10 @@ def grouped_channel_bias_non_perfect_tiled(
12811281
cc = T.axis.spatial(720, c_o * 360 + c_i)
12821282
Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16]
12831283

1284-
def check_sched(debug_mask):
1285-
sch = tir.Schedule(grouped_channel_bias, debug_mask=debug_mask)
1286-
loop = sch.get_loops(sch.get_block("compute"))[0]
1287-
sch.compute_at(sch.get_block("init"), loop)
1288-
tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled)
1289-
1290-
check_sched("all")
1284+
sch = tir.Schedule(grouped_channel_bias, debug_mask=debug_mask)
1285+
loop = sch.get_loops(sch.get_block("compute"))[0]
1286+
sch.compute_at(sch.get_block("init"), loop)
1287+
tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled)
12911288

12921289

12931290
def test_fail_subtree_complete_block():

0 commit comments

Comments
 (0)