Skip to content

Commit 8bfe3bb

Browse files
authored
[Arith] Updated arith::DetectIterMap to keep extent=1 components (#10980)
* [Arith] Updated arith::DetectIterMap to keep extent=1 components. Previously, arith::DetectIterMap simplified the output expression by replacing each iteration variable whose extent is 1 with its constant value. This prevented the return value from being used in arith::InverseAffineIterMap to solve for the variable, as the variable no longer existed in the returned expressions. This commit changes arith::DetectIterMap to keep the iteration variable even if extent==1, and adds a motivating unit test that requires this updated behavior. * Updated to retain default behavior of DetectIterMap. To avoid breaking existing test cases, DetectIterMap keeps the same default behavior, and a new flag (simplify_trivial_iterators) was added that, when set to false, retains trivial iterators in the result. * Updated FFI and Python API for DetectIterMap.
1 parent f238900 commit 8bfe3bb

File tree

5 files changed

+60
-10
lines changed

5 files changed

+60
-10
lines changed

include/tvm/arith/iter_affine_map.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr {
276276
* \param predicate The predicate constraints on the input iterators
277277
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
278278
* \param analyzer Analyzer used to get context information.
279+
* \param simplify_trivial_iterators If true, iterators with extent of
280+
* 1 will be replaced with a constant value.
279281
*
280282
* \return The detected pattern if a match exists,
281283
* otherwise return an empty array.
282284
*/
283285
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
284286
const PrimExpr& predicate, bool require_bijective,
285-
arith::Analyzer* analyzer);
287+
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
286288
/*!
287289
* \brief Use IterVarMap detector to rewrite and simplify the indices
288290
*

python/tvm/arith/iter_affine_map.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,13 @@ def __init__(self, args, base):
8888
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)
8989

9090

91-
def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False):
91+
def detect_iter_map(
92+
indices,
93+
input_iters,
94+
predicate=True,
95+
require_bijective=False,
96+
simplify_trivial_iterators=True,
97+
):
9298
"""Detect if indices can be written as mapped iters from input iters
9399
94100
Parameters
@@ -105,13 +111,20 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals
105111
require_bijective : bool
106112
A boolean flag that indicates whether the mapping should be bijective
107113
114+
simplify_trivial_iterators: bool
115+
If true, iterators with extent of 1 will be replaced with a
116+
constant value.
117+
108118
Returns
109119
-------
110120
results : List[IterSumExpr]
111121
The iter map matching result.
112122
Empty array if no match can be found.
123+
113124
"""
114-
return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective)
125+
return _ffi_api.DetectIterMap(
126+
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
127+
)
115128

116129

117130
def normalize_iter_map_to_expr(expr):

src/arith/iter_affine_map.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,13 @@ class IterMapRewriter : public ExprMutator {
173173
public:
174174
using Parent = ExprMutator;
175175

176-
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
176+
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
177+
bool simplify_trivial_iterators)
177178
: analyzer_(analyzer) {
178179
for (auto kv : input_iters) {
179180
const Var& var = kv.first;
180181
const Range& vrng = kv.second;
181-
if (is_one(vrng->extent)) {
182+
if (simplify_trivial_iterators && is_one(vrng->extent)) {
182183
var_map_[var] = IterSumExpr({}, vrng->min);
183184
} else if (is_zero(vrng->min)) {
184185
IterMark mark(var, vrng->extent);
@@ -892,7 +893,7 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
892893

893894
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
894895
const PrimExpr& predicate, bool require_bijective,
895-
arith::Analyzer* analyzer) {
896+
arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
896897
// Overall detection algorithm is divided into two steps:
897898
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
898899
// - Step1: IterIndependenceChecker checks if the iterator are independent.
@@ -914,7 +915,7 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
914915
constraints.begin(), constraints.end(),
915916
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
916917

917-
IterMapRewriter rewriter(analyzer, constrained_input_iters);
918+
IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators);
918919
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
919920
for (const IterConstraint& constraint : constraints) {
920921
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
@@ -942,9 +943,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
942943

943944
TVM_REGISTER_GLOBAL("arith.DetectIterMap")
944945
.set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
945-
const PrimExpr& input_pred, bool is_bijective) {
946+
const PrimExpr& input_pred, bool is_bijective,
947+
bool simplify_trivial_iterators) {
946948
arith::Analyzer ana;
947-
return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana);
949+
return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana,
950+
simplify_trivial_iterators);
948951
});
949952

950953
PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {

src/tir/ir/index_map.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
7676
// Unpack the output indices into linear combinations of the initial
7777
// indices.
7878
arith::Analyzer analyzer;
79-
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer);
79+
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1,
80+
/* require_bijective = */ true, &analyzer,
81+
/* simplify_trivial_iterators = */ false);
8082
CHECK(iter_map.size()) << "Index transformation was not bijective.";
8183

8284
// Determine expressions for the input variables, in terms of the

tests/python/unittest/test_transform_layout.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,5 +545,35 @@ def test_transform_with_reduction():
545545
tvm.lower(s, [A, B])
546546

547547

548+
shape, transform = tvm.testing.parameters(
549+
([1, 8], lambda n, i: [i, n]),
550+
([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]),
551+
([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]),
552+
)
553+
554+
555+
def test_size_one_buffer(shape, transform):
556+
# This test is to catch a failure mode that occurred if a
557+
# transformation were applied to a te.compute buffer, and one of
558+
# the dimensions of the buffer was 1. Prior to bugfix,
559+
# arith::DetectIterMap would fold the variable as a constant,
560+
# causing an error when attempting to solve for the variable using
561+
# arith::InverseAffineIterMap.
562+
563+
dtype = "int8"
564+
A = te.placeholder(shape, dtype, name="A")
565+
B = te.compute(
566+
shape=A.shape,
567+
fcompute=lambda *indices: A[indices].astype(dtype),
568+
name="B",
569+
)
570+
s = te.create_schedule(B.op)
571+
572+
# If layout transformation is on the output buffer, and any
573+
# dimension of the output buffer is 1, failure occurs in
574+
# CheckFusePattern.
575+
s[B].transform_layout(transform)
576+
577+
548578
if __name__ == "__main__":
549579
sys.exit(pytest.main(sys.argv))

0 commit comments

Comments
 (0)