Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,15 @@ class IterSumExpr : public IterMapExpr {
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with extent of
* 1 will be replaced with a constant value.
*
* \return The detected pattern if a match exists,
* otherwise return an empty array.
*/
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
Expand Down
9 changes: 5 additions & 4 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,13 @@ class IterMapRewriter : public ExprMutator {
public:
using Parent = ExprMutator;

explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters)
explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
bool simplify_trivial_iterators)
: analyzer_(analyzer) {
for (auto kv : input_iters) {
const Var& var = kv.first;
const Range& vrng = kv.second;
if (is_one(vrng->extent)) {
if (simplify_trivial_iterators && is_one(vrng->extent)) {
var_map_[var] = IterSumExpr({}, vrng->min);
} else if (is_zero(vrng->min)) {
IterMark mark(var, vrng->extent);
Expand Down Expand Up @@ -892,7 +893,7 @@ bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {

Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer) {
arith::Analyzer* analyzer, bool simplify_trivial_iterators) {
// Overall detection algorithm is divided into two steps:
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterator are independent.
Expand All @@ -914,7 +915,7 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, constrained_input_iters);
IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
Expand Down
4 changes: 3 additions & 1 deletion src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
// Unpack the output indices into linear combinations of the initial
// indices.
arith::Analyzer analyzer;
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, 1, true, &analyzer);
auto iter_map = DetectIterMap((*this)->final_indices, input_iters, /* predicate = */ 1,
/* require_bijective = */ true, &analyzer,
/* simplify_trivial_iterators = */ false);
CHECK(iter_map.size()) << "Index transformation was not bijective.";

// Determine expressions for the input variables, in terms of the
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,5 +545,35 @@ def test_transform_with_reduction():
tvm.lower(s, [A, B])


shape, transform = tvm.testing.parameters(
([1, 8], lambda n, i: [i, n]),
([1, 1, 8], lambda i, j, k: [j, te.AXIS_SEPARATOR, i, k]),
([1, 1, 8], lambda i, j, k: [i, te.AXIS_SEPARATOR, j, k]),
)


def test_size_one_buffer(shape, transform):
# This test is to catch a failure mode that occurred if a
# transformation were applied to a te.compute buffer, and one of
# the dimensions of the buffer was 1. Prior to bugfix,
# arith::DetectIterMap would fold the variable as a constant,
# causing an error when attempting to solve for the variable using
# arith::InverseAffineIterMap.

dtype = "int8"
A = te.placeholder(shape, dtype, name="A")
B = te.compute(
shape=A.shape,
fcompute=lambda *indices: A[indices].astype(dtype),
name="B",
)
s = te.create_schedule(B.op)

# If layout transformation is on the output buffer, and any
# dimension of the output buffer is 1, failure occurs in
# CheckFusePattern.
s[B].transform_layout(transform)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))