From f7d2c487207a57c9c2b4c03f23bba95475689621 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 29 Jun 2019 16:12:43 -0700 Subject: [PATCH] [ARITH] Improve min/max/div cases in RewriteSimplify (#3463) [PASS] Use new infra for lower warp memory [ARITH] EvalSet recursively evaluates set in case dom_map contains set that need to be relaxed. --- src/arithmetic/int_set.cc | 29 +++++-- src/arithmetic/rewrite_simplify.cc | 22 ++++++ src/pass/lower_warp_memory.cc | 77 +++++++++++++++---- .../unittest/test_arith_rewrite_simplify.py | 11 +++ ...implify.py => test_arith_stmt_simplify.py} | 0 .../unittest/test_pass_inject_copy_intrin.py | 1 - 6 files changed, 117 insertions(+), 23 deletions(-) rename tests/python/unittest/{test_pass_simplify.py => test_arith_stmt_simplify.py} (100%) diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index 75a4aaf83ab6..b81deb40ee2a 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -305,6 +305,16 @@ class IntervalSetEvaluator : IntervalSet Eval(const Expr& val) { return this->VisitExpr(val); } + // evaluate and relax the set + IntervalSet Eval(IntervalSet val) { + // avoid recursive indefinite recursive expansion. + if (static_cast(recur_depth_) >= dom_map_.size()) return val; + ++recur_depth_; + IntervalSet min_set = this->Eval(val->min_value); + IntervalSet max_set = this->Eval(val->max_value); + --recur_depth_; + return IntervalSet(min_set->min_value, max_set->max_value); + } IntervalSet VisitExpr_(const IntImm* op) final { return IntervalSet::SinglePoint(GetRef(op)); @@ -318,7 +328,14 @@ class IntervalSetEvaluator : Var var = GetRef(op); auto it = dom_map_.find(var); if (it != dom_map_.end()) { - return ToIntervalSet((*it).second); + IntervalSet res = ToIntervalSet((*it).second); + if (res->min_value.same_as(var) && + res->max_value.same_as(var)) { + return res; + } + // recursively evaluate mapped result + // in case the domain contains variables to be relaxed. + return Eval(res); } else { return IntervalSet::SinglePoint(var); } @@ -440,6 +457,9 @@ class IntervalSetEvaluator : return Combine(analyzer_, a, b); } + // recursive depth + int recur_depth_{0}; + // analyzer Analyzer* analyzer_; const Map& dom_map_; bool eval_vec_{false}; @@ -662,13 +682,10 @@ IntSet EvalSet(Range r, const Map& dom_map) { Analyzer ana; IntervalSetEvaluator m(&ana, dom_map); - IntervalSet min_set = m.Eval(r->min); // Simplifying first can give tighter bounds if r->min and r->extent share variables Expr sum = r->min + r->extent - 1; - IntervalSet max_set = m.Eval(Simplify(sum)); - if (!min_set->HasLowerBound()) return IntSet::everything(); - if (!max_set->HasUpperBound()) return IntSet::everything(); - return IntervalSet(min_set->min_value, max_set->max_value); + auto res = m.Eval(IntervalSet(r->min, Simplify(sum))); + return res; } IntSet EvalSet(Range r, diff --git a/src/arithmetic/rewrite_simplify.cc b/src/arithmetic/rewrite_simplify.cc index ea6530631880..28581e47a02a 100644 --- a/src/arithmetic/rewrite_simplify.cc +++ b/src/arithmetic/rewrite_simplify.cc @@ -147,6 +147,17 @@ Mutate_(const Add* op, const Expr& self) { TVM_TRY_REWRITE(min(x - z, y) + z, min(x, y + z)); TVM_TRY_REWRITE(max(x, y - z) + z, max(x + z, y)); TVM_TRY_REWRITE(max(x - z, y) + z, max(x, y + z)); + + + TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); TVM_TRY_REWRITE(max(x, y) + min(y, x), x + y); @@ -265,6 +276,11 @@ Mutate_(const Sub* op, const Expr& self) { TVM_TRY_REWRITE(min(z, x + y) - x, min(z - x, y)); TVM_TRY_REWRITE(min(z, y + x) - x, min(z - x, y)); + TVM_TRY_REWRITE(max(x + y, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(y + x, z) - x, max(y, z - x)); + TVM_TRY_REWRITE(max(z, x + y) - x, max(z - x, y)); + TVM_TRY_REWRITE(max(z, y + x) - x, max(z - x, y)); + TVM_TRY_REWRITE(x - min(x + y, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(y + x, z), max(0 - y, x - z)); TVM_TRY_REWRITE(x - min(z, x + y), max(x - z, 0 - y)); @@ -397,6 +413,12 @@ Mutate_(const Div* op, const Expr& self) { // Pattern var for lanes in broadcast and ramp PVar lanes; + // x / 2.0 = x * 0.5 + if (const FloatImm* ptr = op->b.as()) { + CHECK(op->type.is_float()); + return op->a * make_const(op->b.type(), 1.0 / ptr->value); + } + // Vector rules if (op->type.lanes() != 1) { TVM_TRY_REWRITE(broadcast(x, lanes) / broadcast(y, lanes), diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index 2c795ba5dab7..7d9d48600f71 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -80,8 +80,11 @@ namespace ir { class WarpStoreCoeffFinder : private IRVisitor { public: WarpStoreCoeffFinder(const Variable* buffer, - Var warp_index) - : buffer_(buffer), warp_index_(warp_index) { + Var warp_index, + arith::Analyzer* analyzer) + : buffer_(buffer), + warp_index_(warp_index), + analyzer_(analyzer) { } // find the warp co-efficient in the statement given the warp size int Find(const Stmt& stmt) { @@ -113,7 +116,7 @@ class WarpStoreCoeffFinder : private IRVisitor { CHECK_EQ(m.size(), 2U) << "LowerWarpMemory failed due to store index=" << index; int coeff = 0; - Expr mcoeff = ir::Simplify(m[0]); + Expr mcoeff = analyzer_->canonical_simplify(m[0]); CHECK(arith::GetConstInt(mcoeff, &coeff) && coeff > 0) << "LowerWarpMemory failed due to store index=" << index @@ -134,6 +137,8 @@ class WarpStoreCoeffFinder : private IRVisitor { Var warp_index_; // the coefficient int warp_coeff_{0}; + // analyzer. + arith::Analyzer* analyzer_; }; @@ -184,8 +189,8 @@ class WarpIndexFinder : private IRVisitor { // Mutator to change the read pattern class WarpAccessRewriter : protected IRMutator { public: - explicit WarpAccessRewriter(int warp_size) - : warp_size_(warp_size) {} + explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) + : warp_size_(warp_size), analyzer_(analyzer) {} // Rewrite the allocate statement which transforms // warp memory to local memory. Stmt Rewrite(const Allocate* op, const Stmt& stmt) { @@ -196,7 +201,7 @@ class WarpAccessRewriter : protected IRMutator { alloc_size *= op->type.lanes(); warp_index_ = WarpIndexFinder(warp_size_).Find(op->body)->var; warp_coeff_ = WarpStoreCoeffFinder( - buffer_, warp_index_).Find(op->body); + buffer_, warp_index_, analyzer_).Find(op->body); CHECK_EQ(alloc_size % (warp_size_ * warp_coeff_), 0) << "Warp memory must be multiple of warp size"; warp_group_ = alloc_size / (warp_size_ * warp_coeff_); @@ -258,21 +263,19 @@ class WarpAccessRewriter : protected IRMutator { return std::make_pair(local_index, group); } Expr m = make_const(index.type(), warp_coeff_); - Range rng = Range::make_by_min_extent( - make_zero(index.type()), make_const(index.type(), warp_size_)); - Map vrange({{warp_index_, rng}}); // simple case, warp index is on the highest. if (warp_group_ == 1) { - Expr x = Simplify(index % m, vrange); - Expr z = Simplify(index / m, vrange); + Expr x = analyzer_->canonical_simplify(index % m); + Expr z = analyzer_->canonical_simplify(index / m); return std::make_pair(x, z); } else { - Expr x = Simplify(index % m, vrange); + Expr x = analyzer_->canonical_simplify(index % m); Expr y = index / make_const(index.type(), warp_coeff_ * warp_size_); y = y * m + x; Expr z = index % make_const(index.type(), warp_coeff_ * warp_size_) / m; - return std::make_pair(Simplify(y, vrange), Simplify(z, vrange)); + return std::make_pair(analyzer_->canonical_simplify(y), + analyzer_->canonical_simplify(z)); } } @@ -287,6 +290,44 @@ class WarpAccessRewriter : protected IRMutator { int warp_coeff_{0}; // the coefficient n int warp_group_{0}; + // Internal analyzer + arith::Analyzer* analyzer_; +}; + + +// Bind bound information of variables to make analyzer more effective +// TODO(tqchen): consider a pass to inline the bound info into the expr +// so analysis can be context independent. +class BindVarBoundInfo : public IRVisitor { + public: + explicit BindVarBoundInfo(arith::Analyzer* analyzer) + : analyzer_(analyzer) {} + + void Visit_(const For* op) final { + Var loop_var(op->loop_var.node_); + analyzer_->Bind(loop_var, Range::make_by_min_extent(op->min, op->extent)); + IRVisitor::Visit_(op); + } + + void Visit_(const AttrStmt* op) { + if (op->attr_key == attr::thread_extent || + op->attr_key == attr::virtual_thread) { + IterVar iv(op->node.node_); + CHECK_NE(iv->thread_tag.length(), 0U); + if (!var_dom_.count(iv->var.get())) { + Range dom = Range::make_by_min_extent(0, op->value); + var_dom_[iv->var.get()] = dom; + analyzer_->Bind(iv->var, dom); + } + } + IRVisitor::Visit_(op); + } + + protected: + // internal analyzer. + arith::Analyzer* analyzer_; + // variable domain + std::unordered_map var_dom_; }; // Mutator to change the read pattern @@ -298,6 +339,7 @@ class WarpMemoryRewriter : private IRMutator { Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; + BindVarBoundInfo(&analyzer_).Visit(stmt); stmt = this->Mutate(stmt); stmt = CanonicalSimplify(stmt); return stmt; @@ -306,7 +348,7 @@ class WarpMemoryRewriter : private IRMutator { private: Stmt Mutate_(const Allocate* op, const Stmt& stmt) { if (warp_buffer_.count(op->buffer_var.get())) { - WarpAccessRewriter rewriter(warp_size_); + WarpAccessRewriter rewriter(warp_size_, &analyzer_); return rewriter.Rewrite(op, stmt); } else { return IRMutator::Mutate_(op, stmt); @@ -331,6 +373,9 @@ class WarpMemoryRewriter : private IRMutator { int warp_size_{0}; std::unordered_set warp_buffer_; + arith::Analyzer analyzer_; + // variable domain + std::unordered_map var_dom_; }; LoweredFunc diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 07d460eee7fe..1e9f302834d0 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -151,6 +151,12 @@ def test_add_index_simplify(): ck.verify(tvm.min(x, y + 2) + (-2), tvm.min(x + (-2), y)); ck.verify(tvm.min(x + 2, y + 3) + (-2), tvm.min(x, y + 1)); + ck.verify(tvm.max(0, 1 - x * 4) + x * 4, tvm.max(x * 4, 1)) + ck.verify(tvm.max(2 - x * 4, 0) + x * 4, tvm.max(x * 4, 2)) + + ck.verify(tvm.min(0, 1 - x * 4) + x * 4, tvm.min(x * 4, 1)) + ck.verify(tvm.min(2 - x * 4, 0) + x * 4, tvm.min(x * 4, 2)) + ck.verify(x * y + x * 10, x * (y + 10)) ck.verify(y * x + x * 10, x * (y + 10)) ck.verify(y * x + 10 * x, x * (y + 10)) @@ -212,6 +218,11 @@ def test_sub_index_simplify(): ck.verify(tvm.min(z, x + y) - x, tvm.min(z - x, y)) ck.verify(tvm.min(z, y + x) - x, tvm.min(z - x, y)) + ck.verify(tvm.max(x + y, z) - x, tvm.max(y, z - x)) + ck.verify(tvm.max(y + x, z) - x, tvm.max(y, z - x)) + ck.verify(tvm.max(z, x + y) - x, tvm.max(z - x, y)) + ck.verify(tvm.max(z, y + x) - x, tvm.max(z - x, y)) + ck.verify(x - tvm.min(x + y, z), tvm.max(0 - y, x - z)) ck.verify(x - tvm.min(y + x, z), tvm.max(0 - y, x - z)) ck.verify(x - tvm.min(z, x + y), tvm.max(x - z, 0 - y)) diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_arith_stmt_simplify.py similarity index 100% rename from tests/python/unittest/test_pass_simplify.py rename to tests/python/unittest/test_arith_stmt_simplify.py diff --git a/tests/python/unittest/test_pass_inject_copy_intrin.py b/tests/python/unittest/test_pass_inject_copy_intrin.py index 83b0f718f824..858b1e8a9153 100644 --- a/tests/python/unittest/test_pass_inject_copy_intrin.py +++ b/tests/python/unittest/test_pass_inject_copy_intrin.py @@ -80,7 +80,6 @@ def cb(src, dst, pad_before, pad_after, pad_value): stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb) def assert_expr_equal(a, b): - print(a, b) assert tvm.ir_pass.Simplify(a - b).value == 0 def test_copy_pad_split():