Skip to content

Commit afb444c

Browse files
vinx13 authored and yongwww committed
[TIR] Update block flags and simplify predicate in Reverse-Compute-Inline (apache#14030)
* Add simplification after substitution to make the predicate simpler for arithmetic analysis.
* Update block flags, since reverse-compute-inline may introduce predicates and the resulting block may not have an affine binding.
1 parent 2c28efd commit afb444c

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -608,7 +608,7 @@ class ReverseComputeInliner : public BaseInliner {
608608
/*indices=*/buffer_load_indices_,
609609
/*input_iters=*/consumer_iter_doms,
610610
/*predicate=*/true,
611-
/*check_level=*/arith::IterMapLevel::Bijective,
611+
/*check_level=*/arith::IterMapLevel::NoCheck,
612612
/*analyzer=*/&analyzer_,
613613
/*simplify_trivial_iterators=*/false);
614614
buffer_load_iter_map_ = res->indices;
@@ -651,6 +651,7 @@ class ReverseComputeInliner : public BaseInliner {
651651
// Substitute the producer block iters with the its bindings since the predicate in BlockRealize
652652
// should not contain the block iters
653653
predicate = Substitute(predicate, subst_map);
654+
predicate = analyzer_.Simplify(predicate);
654655
return predicate;
655656
}
656657

@@ -865,6 +866,13 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
865866
return;
866867
}
867868
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
869+
// Step 8. Update the cached flags
870+
arith::Analyzer analyzer;
871+
BlockInfo& block_info = self->block_info[producer_block_sref];
872+
block_info.affine_binding = IsAffineBinding(
873+
/*realize=*/GetBlockRealize(self, producer_block_sref),
874+
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(producer_block_sref->parent)),
875+
/*analyzer=*/&analyzer);
868876
}
869877

870878
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {

tests/python/unittest/test_tir_schedule_compute_inline.py

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -611,6 +611,60 @@ def elementwise_overcomputed_producer_reverse_inlined(
611611
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
612612

613613

614+
@T.prim_func
615+
def elementwise_overcomputed_producer_simplify_predicate(
616+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
617+
) -> None:
618+
B = T.alloc_buffer((128, 128))
619+
for i in T.grid(16384):
620+
with T.block("B"):
621+
vi = T.axis.spatial(128, i // 128)
622+
vj = T.axis.spatial(128, i % 128)
623+
B[vi, vj] = A[vi, vj] * 2.0
624+
for i, j in T.grid(127, 127):
625+
with T.block("C"):
626+
cvi, cvj = T.axis.remap("SS", [i, j])
627+
C[cvi, cvj] = B[cvi, cvj] + 1.0
628+
629+
630+
@T.prim_func
631+
def elementwise_overcomputed_producer_simplify_predicate_reverse_inlined(
632+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
633+
) -> None:
634+
for i in T.grid(16384):
635+
with T.block("B"):
636+
vi = T.axis.spatial(128, i // 128)
637+
vj = T.axis.spatial(128, i % 128)
638+
T.where(i < 16255 and i % 128 < 127)
639+
C[vi, vj] = A[vi, vj] * 2.0 + 1.0
640+
641+
642+
@T.prim_func
643+
def elementwise_overcomputed_producer_injective_load(
644+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
645+
) -> None:
646+
B = T.alloc_buffer((8, 8, 16, 16))
647+
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
648+
with T.block("B"):
649+
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
650+
B[vi, vj, vm, vn] = A[vi * 16 + vm, vj * 16 + vn] * 2.0
651+
for i, j in T.grid(127, 127):
652+
with T.block("C"):
653+
cvi, cvj = T.axis.remap("SS", [i, j])
654+
C[cvi, cvj] = B[cvi // 16, cvj // 16, cvi % 16, cvj % 16] + 1.0
655+
656+
657+
@T.prim_func
658+
def elementwise_overcomputed_producer_injective_load_reverse_inlined(
659+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((127, 127), "float32")
660+
) -> None:
661+
for i0, j0, i1, j1 in T.grid(8, 8, 16, 16):
662+
with T.block("B"):
663+
vi, vj, vm, vn = T.axis.remap("SSSS", [i0, j0, i1, j1])
664+
T.where(i0 * 16 + i1 < 127 and j0 * 16 + j1 < 127)
665+
C[vm + vi * 16, vn + vj * 16] = A[vi * 16 + vm, vj * 16 + vn] * 2.0 + 1.0
666+
667+
614668
@T.prim_func
615669
def elementwise_producer_not_cover_consumer(
616670
A: T.Buffer((128, 128), "float32"), D: T.Buffer((256, 128), "float32")
@@ -1025,6 +1079,26 @@ def test_reverse_compute_inline_overcomputed_producer(use_block_name):
10251079
)
10261080

10271081

1082+
def test_reverse_compute_inline_overcomputed_producer_simplify_predicate(use_block_name):
1083+
"""Test reverse compute inline overcomputed producer where the predicate should be simplified"""
1084+
sch = tir.Schedule(elementwise_overcomputed_producer_simplify_predicate, debug_mask="all")
1085+
compute = "C" if use_block_name else sch.get_block("C")
1086+
sch.reverse_compute_inline(compute)
1087+
tvm.ir.assert_structural_equal(
1088+
elementwise_overcomputed_producer_simplify_predicate_reverse_inlined, sch.mod["main"]
1089+
)
1090+
1091+
1092+
def test_reverse_compute_inline_overcomputed_producer_injective_load(use_block_name):
1093+
"""Test reverse compute inline overcomputed producer with injective buffer load"""
1094+
sch = tir.Schedule(elementwise_overcomputed_producer_injective_load, debug_mask="all")
1095+
compute = "C" if use_block_name else sch.get_block("C")
1096+
sch.reverse_compute_inline(compute)
1097+
tvm.ir.assert_structural_equal(
1098+
elementwise_overcomputed_producer_injective_load_reverse_inlined, sch.mod["main"]
1099+
)
1100+
1101+
10281102
def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
10291103
"""Test reverse compute inline failure when the inlined block iter domains are not covered by
10301104
its producer

0 commit comments

Comments
 (0)