diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 8fcecb4cb429..ed59be32b6e2 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr { * \param predicate The predicate constraints on the input iterators * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. * \param analyzer Analyzer used to get context information. + * \param simplify_trivial_iterators If true, iterators with extent of + * 1 will be replaced with a constant value. * * \return The detected pattern if a match exists, * otherwise return an empty array. */ Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer); + arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); /*! * \brief Use IterVarMap detector to rewrite and simplify the indices * diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 85513ecae5c4..2be939a12277 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -88,7 +88,13 @@ def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) -def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False): +def detect_iter_map( + indices, + input_iters, + predicate=True, + require_bijective=False, + simplify_trivial_iterators=True, +): """Detect if indices can be written as mapped iters from input iters Parameters @@ -105,13 +111,20 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals require_bijective : bool A boolean flag that indicates whether the mapping should be bijective + simplify_trivial_iterators: bool + If true, iterators with extent of 1 will be replaced with a + constant value. + Returns ------- results : List[IterSumExpr] The iter map matching result. Empty array if no match can be found. + """ - return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective) + return _ffi_api.DetectIterMap( + indices, input_iters, predicate, require_bijective, simplify_trivial_iterators + ) def normalize_iter_map_to_expr(expr): diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 7694300ce043..e7a73f4ea288 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -173,12 +173,13 @@ class IterMapRewriter : public ExprMutator { public: using Parent = ExprMutator; - explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters) + explicit IterMapRewriter(Analyzer* analyzer, const Map& input_iters, + bool simplify_trivial_iterators) : analyzer_(analyzer) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; - if (is_one(vrng->extent)) { + if (simplify_trivial_iterators && is_one(vrng->extent)) { var_map_[var] = IterSumExpr({}, vrng->min); } else if (is_zero(vrng->min)) { IterMark mark(var, vrng->extent); @@ -892,7 +893,7 @@ bool IterRangeSanityCheck(const Map& iter_ranges) { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, - arith::Analyzer* analyzer) { + arith::Analyzer* analyzer, bool simplify_trivial_iterators) { // Overall detection algorithm is divided into two steps: // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns. // - Step1: IterIndependenceChecker checks if the iterator are independent. @@ -914,7 +915,7 @@ Array DetectIterMap(const Array& indices, const Map DetectIterMap(const Array& indices, const Map& indices, const Map& input_iters, - const PrimExpr& input_pred, bool is_bijective) { + const PrimExpr& input_pred, bool is_bijective, + bool simplify_trivial_iterators) { arith::Analyzer ana; - return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana); + return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana, + simplify_trivial_iterators); }); PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index 3f8f84f649d4..93f308b42d74 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -76,7 +76,9 @@ IndexMap IndexMap::Inverse(Array initial_ranges) const { // Unpack the output indices into linear combinations of the initial // indices. arith::Analyzer analyzer; - auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer); + auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1, + /* require_bijective = */ true, &analyzer, + /* simplify_trivial_iterators = */ false); CHECK(iter_map.size()) << "Index transformation was not bijective."; // Determine expressions for the input variables, in terms of the diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py index 28399498c784..e7d5f125dc68 100755 --- a/tests/python/unittest/test_transform_layout.py +++ b/tests/python/unittest/test_transform_layout.py @@ -545,5 +545,35 @@ def test_transform_with_reduction(): tvm.lower(s, [A, B]) +shape, transform = tvm.testing.parameters( + ([1, 8], lambda n, i: [i, n]), + ([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]), + ([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]), +) + + +def test_size_one_buffer(shape, transform): + # This test is to catch a failure mode that occurred if a + # transformation were applied to a te.compute buffer, and one of + # the dimensions of the buffer was 1. Prior to bugfix, + # arith::DetectIterMap would fold the variable as a constant, + # causing an error when attempting to solve for the variable using + # arith::InverseAffineIterMap. + + dtype = "int8" + A = te.placeholder(shape, dtype, name="A") + B = te.compute( + shape=A.shape, + fcompute=lambda *indices: A[indices].astype(dtype), + name="B", + ) + s = te.create_schedule(B.op) + + # If layout transformation is on the output buffer, and any + # dimension of the output buffer is 1, failure occurs in + # CheckFusePattern. + s[B].transform_layout(transform) + + if __name__ == "__main__": sys.exit(pytest.main(sys.argv))