Skip to content

Commit 422ca28

Browse files
[TIR][Schedule] Fix reverse_compute_inline (#14263)
We can not reverse compute inline a block whose producer is an output block, since its content is visible to the caller.
1 parent 594bc0f commit 422ca28

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,9 +844,11 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
844844
NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
845845
// Step 2. Check completeness
846846
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
847-
// Step 3. Check if the consumer has a single complete producer
847+
// Step 3. Check if the consumer has a single complete producer, and the producer is not an output
848+
// block
848849
StmtSRef producer_block_sref =
849850
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
851+
CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
850852
// Step 4. Analyze the block body
851853
ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
852854
consumer_block_realize, scope_root_sref, self->mod);

tests/python/unittest/test_tir_schedule_compute_inline.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,21 @@ def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None
503503
compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
504504

505505

506+
@T.prim_func
507+
def elementwise_output(a: T.handle, b: T.handle, c: T.handle) -> None:
508+
A = T.match_buffer(a, (128, 128))
509+
B = T.match_buffer(b, (128, 128))
510+
C = T.match_buffer(c, (128, 128))
511+
for i, j in T.grid(128, 128):
512+
with T.block("B"):
513+
vi, vj = T.axis.remap("SS", [i, j])
514+
B[vi, vj] = A[vi, vj] * 2.0
515+
for i, j in T.grid(128, 128):
516+
with T.block("C"):
517+
vi, vj = T.axis.remap("SS", [i, j])
518+
C[vi, vj] = B[vi, vj] + 1.0
519+
520+
506521
@T.prim_func
507522
def inline_block_with_init(
508523
A: T.Buffer((1, 512, 7, 7), "float32"),
@@ -1027,6 +1042,15 @@ def test_output_block(use_block_name):
10271042
with pytest.raises(tvm.tir.ScheduleError):
10281043
sch.compute_inline(block)
10291044

1045+
sch = tir.Schedule(elementwise_output, debug_mask="all")
1046+
block = sch.get_block("B")
1047+
with pytest.raises(tvm.tir.ScheduleError):
1048+
sch.compute_inline(block)
1049+
1050+
block = sch.get_block("C")
1051+
with pytest.raises(tvm.tir.ScheduleError):
1052+
sch.reverse_compute_inline(block)
1053+
10301054

10311055
def test_compute_inline_predicate(use_block_name):
10321056
sch = tir.Schedule(elementwise_predicate, debug_mask="all")

0 commit comments

Comments
 (0)