Skip to content

Commit df4919c

Browse files
committed
[Arith] Updated arith::DetectIterMap to keep extent=1 components
Previously, arith::DetectIterMap simplified the output expression by replacing iteration variables with extent==1 with their value. This prevented the return value from being used in arith::InverseAffineIterMap to solve for the variable, as it no longer existed in the returned expressions. This commit changes arith::DetectIterMap to keep the iteration variable even if extent==1, and adds a motivating unit test that requires this updated behavior.
1 parent 11d22bd commit df4919c

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def convert_iter_expr(expr):
5151
def assert_iter_sum_pattern(sum_expr, extent, base, scale=1):
5252
"""Check the sum expr have the right pattern."""
5353
assert isinstance(sum_expr, tvm.arith.IterSumExpr)
54-
if extent == 1:
54+
if extent is None:
5555
assert len(sum_expr.args) == 0
5656
else:
5757
assert len(sum_expr.args) == 1
@@ -69,12 +69,12 @@ def test_trivial():
6969
assert len(res) == 3
7070
assert_iter_sum_pattern(res[0], 3, 0)
7171
assert_iter_sum_pattern(res[1], 4, 0)
72-
assert_iter_sum_pattern(res[2], 1, 3)
72+
assert_iter_sum_pattern(res[2], None, 3)
7373

7474
res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y]))
7575
assert len(res) == 2
7676
assert_iter_sum_pattern(res[0], 3, 0)
77-
assert_iter_sum_pattern(res[1], 1, 3)
77+
assert_iter_sum_pattern(res[1], None, 3)
7878

7979
# not independent
8080
res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y]))

tests/python/unittest/test_transform_layout.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,5 +545,35 @@ def test_transform_with_reduction():
545545
tvm.lower(s, [A, B])
546546

547547

548+
shape, transform = tvm.testing.parameters(
549+
([1, 8], lambda n, i: [i, n]),
550+
([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]),
551+
([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]),
552+
)
553+
554+
555+
def test_size_one_buffer(shape, transform):
556+
# This test is to catch a failure mode that occurred if a
557+
# transformation were applied to a te.compute buffer, and one of
558+
# the dimensions of the buffer was 1. Prior to bugfix,
559+
# arith::DetectIterMap would fold the variable as a constant,
560+
# causing an error when attempting to solve for the variable using
561+
# arith::InverseAffineIterMap.
562+
563+
dtype = "int8"
564+
A = te.placeholder(shape, dtype, name="A")
565+
B = te.compute(
566+
shape=A.shape,
567+
fcompute=lambda *indices: A[indices].astype(dtype),
568+
name="B",
569+
)
570+
s = te.create_schedule(B.op)
571+
572+
# If layout transformation is on the output buffer, and any
573+
# dimension of the output buffer is 1, failure occurs in
574+
# CheckFusePattern.
575+
s[B].transform_layout(transform)
576+
577+
548578
if __name__ == "__main__":
549579
sys.exit(pytest.main(sys.argv))

0 commit comments

Comments
 (0)