Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ class ReverseComputeInliner : public BaseInliner {
/*indices=*/buffer_load_indices_,
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*check_level=*/arith::IterMapLevel::NoCheck,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
buffer_load_iter_map_ = res->indices;
Expand Down Expand Up @@ -651,6 +651,7 @@ class ReverseComputeInliner : public BaseInliner {
// Substitute the producer block iters with their bindings, since the predicate in BlockRealize
// should not contain the block iters
predicate = Substitute(predicate, subst_map);
predicate = analyzer_.Simplify(predicate);
return predicate;
}

Expand Down Expand Up @@ -865,6 +866,13 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
// Step 8. Update the cached flags
arith::Analyzer analyzer;
BlockInfo& block_info = self->block_info[producer_block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/GetBlockRealize(self, producer_block_sref),
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(producer_block_sref->parent)),
/*analyzer=*/&analyzer);
}

bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
Expand Down
74 changes: 74 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,60 @@ def elementwise_overcomputed_producer_reverse_inlined(
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Input fixture for test_reverse_compute_inline_overcomputed_producer_simplify_predicate.
# Block "B" fills the whole 128x128 intermediate buffer from a single fused loop of
# 16384 iterations, while block "C" only reads a 127x127 region of it, so the
# producer computes more elements than the consumer uses.
@T.prim_func
def elementwise_overcomputed_producer_simplify_predicate(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
B = T.alloc_buffer((128, 128))
for i in T.grid(16384):
with T.block("B"):
# The fused index is split back into a (row, column) pair.
vi = T.axis.spatial(128, i // 128)
vj = T.axis.spatial(128, i % 128)
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
C[cvi, cvj] = B[cvi, cvj] + 1.0


# Expected result of reverse-inlining block "C" into block "B" of
# elementwise_overcomputed_producer_simplify_predicate: the intermediate buffer is
# gone and block "B" writes C directly, guarded by a predicate on the loop variable.
@T.prim_func
def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
for i in T.grid(16384):
with T.block("B"):
vi = T.axis.spatial(128, i // 128)
vj = T.axis.spatial(128, i % 128)
# The predicate is written on the loop variable i, not on the block iters vi/vj.
# NOTE(review): presumably the simplified form of "vi < 127 and vj < 127" - confirm.
T.where(i < 16255 and i % 128 < 127)
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Input fixture for test_reverse_compute_inline_overcomputed_producer_injective_load.
# Block "B" writes a 4-D (8, 8, 16, 16) intermediate buffer covering all 128x128
# elements of A, while block "C" reads it through floordiv/floormod indices over a
# 127x127 domain only.
@T.prim_func
def elementwise_overcomputed_producer_injective_load(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
B = T.alloc_buffer((8, 8, 16, 16))
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
B[vi, vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
# The 2-D consumer index is decomposed into the 4-D tiled index of B.
C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0


# Expected result of reverse-inlining block "C" into block "B" of
# elementwise_overcomputed_producer_injective_load: block "B" writes C directly at the
# recombined 2-D index, guarded by a predicate that keeps both coordinates below 127.
@T.prim_func
def elementwise_overcomputed_producer_injective_load_reverse_inlined(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
) -> None:
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
with T.block("B"):
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
# The predicate is written on the loop variables, not on the block iters.
T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127)
C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0


@T.prim_func
def elementwise_producer_not_cover_consumer(
A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32")
Expand Down Expand Up @@ -1025,6 +1079,26 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
)


def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name):
    """Reverse-inline block "C" into an over-computing producer and check that the
    resulting module matches the fixture with the simplified predicate.

    ``use_block_name`` selects whether the block is referenced by its name string
    or by the handle returned from ``get_block``.
    """
    schedule = tir.Schedule(
        elementwise_overcomputed_producer_simplify_predicate, debug_mask="all"
    )
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_overcomputed_producer_simplify_predicate_reverse_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name):
    """Reverse-inline block "C" into an over-computing producer whose buffer is read
    through floordiv/floormod indices, and check the module against the expected fixture.

    ``use_block_name`` selects whether the block is referenced by its name string
    or by the handle returned from ``get_block``.
    """
    schedule = tir.Schedule(
        elementwise_overcomputed_producer_injective_load, debug_mask="all"
    )
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_overcomputed_producer_injective_load_reverse_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
"""Test reverse compute inline failure when the inlined block iter domains are not covered by
its producer
Expand Down