Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)


def detect_iter_map(indices, input_iters, predicate=True, require_bijective=False):
def detect_iter_map(
indices,
input_iters,
predicate=True,
require_bijective=False,
simplify_trivial_iterators=True,
):
"""Detect if indices can be written as mapped iters from input iters

Parameters
Expand All @@ -105,13 +111,20 @@ def detect_iter_map(indices, input_iters, predicate=True, require_bijective=Fals
require_bijective : bool
A boolean flag that indicates whether the mapping should be bijective

simplify_trivial_iterators: bool
If true, iterators with extent of 1 will be replaced with a
constant value.

Returns
-------
results : List[IterSumExpr]
The iter map matching result.
Empty array if no match can be found.

"""
return _ffi_api.DetectIterMap(indices, input_iters, predicate, require_bijective)
return _ffi_api.DetectIterMap(
indices, input_iters, predicate, require_bijective, simplify_trivial_iterators
)


def normalize_iter_map_to_expr(expr):
Expand Down
15 changes: 9 additions & 6 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ class IterMapRewriter : public ExprMutator {
public:
using Parent = ExprMutator;

explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
bool simplify_trivial_iterators)
: analyzer_(analyzer) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
if (is_one(vrng->extent)) {
if (simplify_trivial_iterators && is_one(vrng->extent)) {
var_map_[var] = IterSumExpr({}, vrng->min);
} else if (is_zero(vrng->min)) {
IterMark mark(var, vrng->extent);
Expand Down Expand Up @@ -892,7 +893,7 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {

Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer) {
arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
// Overall detection algorithm is divided into two steps:
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
Expand All @@ -914,7 +915,7 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, constrained_input_iters);
IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
Expand Down Expand Up @@ -942,9 +943,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,

TVM_REGISTER_GLOBAL("arith.DetectIterMap")
.set_body_typed([](const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool is_bijective) {
const PrimExpr& input_pred, bool is_bijective,
bool simplify_trivial_iterators) {
arith::Analyzer ana;
return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana);
return DetectIterMap(indices, input_iters, input_pred, is_bijective, &ana,
simplify_trivial_iterators);
});

PrimExpr IterMapRewriter::VisitExpr_(const VarNode* op) {
Expand Down
4 changes: 3 additions & 1 deletion src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
// Unpack the output indices into linear combinations of the initial
// indices.
arith::Analyzer analyzer;
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer);
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1,
/* require_bijective = */ true, &analyzer,
/* simplify_trivial_iterators = */ false);
CHECK(iter_map.size()) << "Index transformation was not bijective.";

// Determine expressions for the input variables, in terms of the
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,5 +545,35 @@ def test_transform_with_reduction():
tvm.lower(s, [A, B])


shape, transform = tvm.testing.parameters(
([1, 8], lambda n, i: [i, n]),
([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]),
([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]),
)


def test_size_one_buffer(shape, transform):
# This test is to catch a failure mode that occurred if a
# transformation were applied to a te.compute buffer, and one of
# the dimensions of the buffer was 1. Prior to bugfix,
# arith::DetectIterMap would fold the variable as a constant,
# causing an error when attempting to solve for the variable using
# arith::InverseAffineIterMap.

dtype = "int8"
A = te.placeholder(shape, dtype, name="A")
B = te.compute(
shape=A.shape,
fcompute=lambda *indices: A[indices].astype(dtype),
name="B",
)
s = te.create_schedule(B.op)

# If layout transformation is on the output buffer, and any
# dimension of the output buffer is 1, failure occurs in
# CheckFusePattern.
s[B].transform_layout(transform)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))