Skip to content

Commit 7da3a79

Browse files
Fix bug when decompose padding wrt the single child subtree
1 parent 6161a8d commit 7da3a79

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/tir/schedule/primitive/decompose_padding.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
442442
if (!found_const_filling_pos) {
443443
if (cur_loop.same_as(const_filling_pos)) {
444444
found_const_filling_pos = true;
445+
found_in_bound_filling_pos = true;
445446
}
446447
}
447448

tests/python/unittest/test_tir_schedule_decompose_padding.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,5 +309,48 @@ def pooling_decompose_3(
309309
check_decompose_padding(sum_pool_2d, sch.mod["main"], pooling_decompose_3, check_run=True)
310310

311311

312+
def test_decompose_wrt_single_child_subtree():
313+
"""Test the case when the decompose position is under the same subtree"""
314+
315+
@T.prim_func
316+
def pad_op(
317+
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer([1, 16, 231, 231], dtype="int8")
318+
):
319+
for i0, i1, i2, i3 in T.grid(1, 16, 231, 231):
320+
with T.block("pad_temp"):
321+
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
322+
y[ax0, ax1, ax2, ax3] = T.if_then_else(
323+
3 <= ax2 and ax2 < 228 and 3 <= ax3 and ax3 < 228,
324+
x[ax0, ax1, ax2 - 3, ax3 - 3],
325+
T.int8(0),
326+
dtype="int8",
327+
)
328+
329+
@T.prim_func
330+
def pad_op_after(
331+
x: T.Buffer[(1, 16, 225, 225), "int8"], y: T.Buffer[(1, 16, 231, 231), "int8"]
332+
):
333+
for i0, i1 in T.grid(1, 16):
334+
for i2, i3 in T.grid(231, 231):
335+
with T.block("pad_temp_pad_const"):
336+
ax0 = T.axis.spatial(1, 0)
337+
ax1, ax2, ax3 = T.axis.remap("SSS", [i1, i2, i3])
338+
y[ax0, ax1, ax2, ax3] = T.int8(0)
339+
for i2, i3 in T.grid(231, 225):
340+
with T.block("pad_temp"):
341+
T.where(3 <= i2 and i2 < 228)
342+
ax0 = T.axis.spatial(1, 0)
343+
ax1 = T.axis.spatial(16, i1)
344+
ax2 = T.axis.spatial(225, i2 - 3)
345+
ax3 = T.axis.spatial(225, i3)
346+
y[ax0, ax1, ax2 + 3, ax3 + 3] = x[ax0, ax1, ax2, ax3]
347+
348+
sch = tir.Schedule(pad_op, debug_mask="all")
349+
pad = sch.get_block("pad_temp")
350+
_, _, h, _ = sch.get_loops(pad)
351+
sch.decompose_padding(pad, h)
352+
check_decompose_padding(pad_op, sch.mod["main"], pad_op_after, check_run=True)
353+
354+
312355
if __name__ == "__main__":
313356
tvm.testing.main()

0 commit comments

Comments
 (0)