From bbe87afdb29d57cf74849265275401e6be9ca461 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 25 Mar 2022 10:55:17 -0500 Subject: [PATCH 1/3] [Arith] Updated arith::DetectIterMap to keep extent=1 components Previously, arith::DetectIterMap simplified the output expression by replacing iteration variables with extent==1 with their value. This prevented the return value from being used in arith::InverseAffineIterMap to solve for the variable, as it no longer existed in the returned expressions. This commit changes arith::DetectIterMap to keep the iteration variable even if extent==1, and adds a motivating unit test that requires this updated behavior. --- src/arith/iter_affine_map.cc | 4 +-- .../unittest/test_arith_iter_affine_map.py | 6 ++-- .../python/unittest/test_transform_layout.py | 30 +++++++++++++++++++ 3 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7694300ce043..63c9ada1275a 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -178,9 +178,7 @@ class IterMapRewriter : public ExprMutator { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; - if (is_one(vrng->extent)) { - var_map_[var] = IterSumExpr({}, vrng->min); - } else if (is_zero(vrng->min)) { + if (is_zero(vrng->min)) { IterMark mark(var, vrng->extent); var_map_[var] = IterSplitExpr(mark); input_marks_.push_back(mark); diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 3dd6ee1c2b59..ec4740254b4b 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -51,7 +51,7 @@ def convert_iter_expr(expr): def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent == 1: + if extent is None: assert len(sum_expr.args) == 0 else: assert len(sum_expr.args) == 1 @@ -69,12 +69,12 @@ def test_trivial(): assert len(res) == 3 assert_iter_sum_pattern(res[0], 3, 0) assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], 1, 3) + assert_iter_sum_pattern(res[2], None, 3) res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) assert len(res) == 2 assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], 1, 3) + assert_iter_sum_pattern(res[1], None, 3) # not independent res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 28399498c784..e7d5f125dc68 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -545,5 +545,35 @@ def test_transform_with_reduction(): tvm.lower(s, [A, B]) +shape, transform = tvm.testing.parameters( + ([1, 8], lambda n, i: [i, n]), + ([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]), + ([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]), +) + + +def test_size_one_buffer(shape, transform): + # This test is to catch a failure mode that occurred if a + # transformation were applied to a te.compute buffer, and one of + # the dimensions of the buffer was 1. Prior to bugfix, + # arith::DetectIterMap would fold the variable as a constant, + # causing an error when attempting to solve for the variable using + # arith::InverseAffineIterMap. + + dtype = "int8" + A = te.placeholder(shape, dtype, name="A") + B = te.compute( + shape=A.shape, + fcompute=lambda *indices: A[indices].astype(dtype), + name="B", + ) + s = te.create_schedule(B.op) + + # If layout transformation is on the output buffer, and any + # dimension of the output buffer is 1, failure occurs in + # CheckFusePattern. + s[B].transform_layout(transform) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv)) From 98740d78015594500135854aa2dcf4bb43b1edc7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 13 Apr 2022 09:53:59 -0500 Subject: [PATCH 2/3] Updated to retain default behavior of DetectIterMap To avoid breaking existing test cases, updated to maintain the same default behavior, but a flag to maintain trivial iterators in the result. --- include/tvm/arith/iter_affine_map.h | 4 +++- src/arith/iter_affine_map.cc | 11 +++++++---- src/tir/ir/index_map.cc | 4 +++- tests/python/unittest/test_arith_iter_affine_map.py | 6 +++--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 8fcecb4cb429..ed59be32b6e2 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer); + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 63c9ada1275a..12c6fabb4834 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -173,12 +173,15 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + bool simplify_trivial_iterators) : analyzer_(analyzer) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; - if (is_zero(vrng->min)) { + if (simplify_trivial_iterators && is_one(vrng->extent)) { + var_map_[var] = IterSumExpr({}, vrng->min); + } else if (is_zero(vrng->min)) { IterMark mark(var, vrng->extent); var_map_[var] = IterSplitExpr(mark); input_marks_.push_back(mark); @@ -890,7 +893,7 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. @@ -912,7 +915,7 @@ Array DetectIterMap(const Array& indices, const Map initial_ranges) const { // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer); + auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* require_bijective = */ true, &analyzer, + /* simplify_trivial_iterators = */ false); CHECK(iter_map.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ec4740254b4b..3dd6ee1c2b59 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -51,7 +51,7 @@ def convert_iter_expr(expr): def assert_iter_sum_pattern(sum_expr, extent, base, scale=1): """Check the sum expr have the right pattern.""" assert isinstance(sum_expr, tvm.arith.IterSumExpr) - if extent is None: + if extent == 1: assert len(sum_expr.args) == 0 else: assert len(sum_expr.args) == 1 @@ -69,12 +69,12 @@ def test_trivial(): assert len(res) == 3 assert_iter_sum_pattern(res[0], 3, 0) assert_iter_sum_pattern(res[1], 4, 0) - assert_iter_sum_pattern(res[2], None, 3) + assert_iter_sum_pattern(res[2], 1, 3) res = tvm.arith.detect_iter_map([x[0], 3], var_dom([x, y])) assert len(res) == 2 assert_iter_sum_pattern(res[0], 3, 0) - assert_iter_sum_pattern(res[1], None, 3) + assert_iter_sum_pattern(res[1], 1, 3) # not independent res = tvm.arith.detect_iter_map([x[0], x[0], 3], var_dom([x, y])) From 10e7a9a4b6760e1e592f859305b4d62761f95335 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 15 Apr 2022 08:43:02 -0500 Subject: [PATCH 3/3] Updated FFI and Python API for DetectIterMap --- python/tvm/arith/iter_affine_map.py | 17 +++++++++++++++-- src/arith/iter_affine_map.cc | 6 ++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 85513ecae5c4..2be939a12277 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -88,7 +88,13 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) -def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False): +def detect_iter_map( + indices, + input_iters, + predicate=True, + require_bijective=False, + simplify_trivial_iterators=True, +): """Detect if indices can be written as mapped iters from input iters Parameters @@ -105,13 +111,20 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals require_bijective : bool A boolean flag that indicates whether the mapping should be bijective + simplify_trivial_iterators: bool + If true, iterators with extent of 1 will be replaced with a + constant value. + Returns ------- results : List[IterSumExpr] The iter map matching result. Empty array if no match can be found. + """ - return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) + return _ffi_api.DetectIterMap( + indices, input_iters, predicate, require_bijective, simplify_trivial_iterators + ) def normalize_iter_map_to_expr(expr): diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 12c6fabb4834..e7a73f4ea288 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -943,9 +943,11 @@ Array DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective) { + const PrimExpr& input_pred, bool is_bijective, + bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + simplify_trivial_iterators); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {