Skip to content

Commit c3464fe

Browse files
spectrometerHBHpfk-beta
authored and committed
[MetaSchedule] Arithmetic analysis (apache#10403)
This PR changes the normal form of the affine detector and supports a single var predicate. It also enhances ModularSet detector to enable floor mod patterns.
1 parent 8a30172 commit c3464fe

10 files changed

+164
-94
lines changed

src/arith/iter_affine_map.cc

Lines changed: 100 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
201201
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
202202
}
203203

204-
IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
205-
const PrimExpr& predicate_induced_max) {
204+
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
205+
const Optional<PrimExpr>& predicate_induced_min,
206+
const Optional<PrimExpr>& predicate_induced_max) {
206207
return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
207208
predicate_induced_max);
208209
}
@@ -494,16 +495,17 @@ class IterMapRewriter : public ExprMutator {
494495
* \param predicate_induced_max Open upper bound from iter constraint, may be undefined.
495496
* \return The Normalized expression.
496497
*/
497-
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
498-
PrimExpr predicate_induced_max) {
498+
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
499+
Optional<PrimExpr> predicate_induced_max) {
499500
// normalize to zero base
500501
PrimExpr base = expr->base;
501502
if (!is_zero(base)) {
502503
expr.CopyOnWrite()->base = 0;
503-
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
504-
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
504+
if (predicate_induced_min.defined())
505+
predicate_induced_min = predicate_induced_min.value() - base;
506+
if (predicate_induced_max.defined())
507+
predicate_induced_max = predicate_induced_max.value() - base;
505508
}
506-
if (expr->args.size() < 1) return expr;
507509
Optional<IterSumExpr> opt = TryFuseIters(expr);
508510
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
509511
// scale should be 1
@@ -522,10 +524,10 @@ class IterMapRewriter : public ExprMutator {
522524
PrimExpr iter_min = mark_offset;
523525
PrimExpr iter_max = iter_min + mark->extent;
524526
if (predicate_induced_min.defined()) {
525-
iter_min = max(predicate_induced_min, iter_min);
527+
iter_min = max(predicate_induced_min.value(), iter_min);
526528
}
527529
if (predicate_induced_max.defined()) {
528-
iter_max = min(predicate_induced_max, iter_max);
530+
iter_max = min(predicate_induced_max.value(), iter_max);
529531
}
530532
if (!is_zero(iter_min)) {
531533
// structured form's offset should be updated
@@ -536,7 +538,6 @@ class IterMapRewriter : public ExprMutator {
536538
}
537539
mark.CopyOnWrite()->extent = iter_max - iter_min;
538540
sum_fuse_map_[flattened_form] = {mark, iter_min};
539-
540541
// we need to note down the flattened form of constrained iterators
541542
// to check the validity of constraints, see also CheckConstraints()
542543
constrained_iters_flattened_.push_back(flattened_form);
@@ -771,14 +772,15 @@ class IterMapRewriter : public ExprMutator {
771772
struct IterConstraint {
772773
// The expr of the iter
773774
PrimExpr iter;
774-
// The expr of the lower_bound
775-
PrimExpr lower_bound;
776-
// The expr of the upper_bound
777-
PrimExpr upper_bound;
775+
// The expr of the lower_bound, may be undefined
776+
Optional<PrimExpr> lower_bound;
777+
// The expr of the upper_bound, may be undefined
778+
Optional<PrimExpr> upper_bound;
778779
// The size of the iter, which is the number of nodes
779780
size_t expr_size = 0;
780781

781-
IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
782+
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
783+
size_t size)
782784
: iter(std::move(iter)),
783785
lower_bound(std::move(lower_bound)),
784786
upper_bound(std::move(upper_bound)),
@@ -788,11 +790,12 @@ struct IterConstraint {
788790
/*!
789791
* \brief Split the predicate into `(a < b) && (c < d) && ...`
790792
* \param pred The predicate to be split.
793+
* \param input_iters The input iterators.
794+
* \param result The result of predicate split.
791795
* \return Whether the predicate was split successfully; the constraints are written to `result`.
792796
*/
793-
std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
794-
const Map<Var, Range>& input_iters) {
795-
std::vector<IterConstraint> result;
797+
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
798+
std::vector<IterConstraint>* result) {
796799
arith::PVar<PrimExpr> lhs, rhs, rest;
797800
for (;;) {
798801
// try to extract comparisons
@@ -821,78 +824,94 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
821824
is_equal = true;
822825
is_finish = true;
823826
} else {
824-
return std::vector<IterConstraint>();
827+
return false;
825828
}
826829
PrimExpr lhs_expr = lhs.Eval();
827830
PrimExpr rhs_expr = rhs.Eval();
828831
// we only accept predicate of integers
829832
if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
830833
(rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
831-
return std::vector<IterConstraint>();
834+
return false;
832835
}
833836
// determine iter and bound, if we can not distinguish them simply,
834837
// try divide (lhs - rhs) into itervar aware and itervar free parts
835838
auto f_use_itervar = [&input_iters](const VarNode* v) {
836-
return input_iters.count(GetRef<Var>(v));
839+
return input_iters->count(GetRef<Var>(v));
837840
};
838841
bool bound_at_left;
839-
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
840-
bound_at_left = true;
841-
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
842-
bound_at_left = false;
843-
} else {
844-
bound_at_left = false; // accumulate bound to rhs
845-
PrimExpr sum_parts = lhs_expr - rhs_expr;
846-
lhs_expr = 0;
847-
rhs_expr = 0;
848-
std::function<void(const PrimExpr&, bool)> f_extract =
849-
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
850-
if (const AddNode* add = part.as<AddNode>()) {
851-
f_extract(add->a, sign);
852-
f_extract(add->b, sign);
853-
} else if (const SubNode* sub = part.as<SubNode>()) {
854-
f_extract(sub->a, sign);
855-
f_extract(sub->b, !sign);
856-
} else if (UsesVar(part, f_use_itervar)) {
857-
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
858-
} else {
859-
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
860-
}
861-
};
862-
f_extract(sum_parts, true);
863-
arith::Analyzer analyzer;
864-
lhs_expr = analyzer.Simplify(lhs_expr);
865-
rhs_expr = analyzer.Simplify(rhs_expr);
866-
}
867-
PrimExpr lower_bound, upper_bound, iter;
868-
if (is_greater) {
869-
if (bound_at_left) {
870-
// bound > iter
871-
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
872-
iter = rhs_expr;
842+
if (UsesVar(lhs_expr, f_use_itervar) || UsesVar(rhs_expr, f_use_itervar)) {
843+
// It uses at least one input iter
844+
if (is_const_int(lhs_expr) || !UsesVar(lhs_expr, f_use_itervar)) {
845+
bound_at_left = true;
846+
} else if (is_const_int(rhs_expr) || !UsesVar(rhs_expr, f_use_itervar)) {
847+
bound_at_left = false;
873848
} else {
874-
// iter > bound
875-
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
876-
iter = lhs_expr;
849+
bound_at_left = false; // accumulate bound to rhs
850+
PrimExpr sum_parts = lhs_expr - rhs_expr;
851+
lhs_expr = 0;
852+
rhs_expr = 0;
853+
std::function<void(const PrimExpr&, bool)> f_extract =
854+
[&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
855+
if (const AddNode* add = part.as<AddNode>()) {
856+
f_extract(add->a, sign);
857+
f_extract(add->b, sign);
858+
} else if (const SubNode* sub = part.as<SubNode>()) {
859+
f_extract(sub->a, sign);
860+
f_extract(sub->b, !sign);
861+
} else if (UsesVar(part, f_use_itervar)) {
862+
lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
863+
} else {
864+
rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
865+
}
866+
};
867+
f_extract(sum_parts, true);
868+
arith::Analyzer analyzer;
869+
lhs_expr = analyzer.Simplify(lhs_expr);
870+
rhs_expr = analyzer.Simplify(rhs_expr);
877871
}
878-
} else {
879-
if (bound_at_left) {
880-
// bound < iter
881-
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
882-
iter = rhs_expr;
872+
Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
873+
PrimExpr iter;
874+
if (is_greater) {
875+
if (bound_at_left) {
876+
// bound > iter / bound >= iter
877+
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
878+
iter = rhs_expr;
879+
} else {
880+
// iter > bound / iter >= bound
881+
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
882+
iter = lhs_expr;
883+
}
883884
} else {
884-
// iter < bound
885-
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
886-
iter = lhs_expr;
885+
if (bound_at_left) {
886+
// bound < iter / bound <= iter
887+
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
888+
iter = rhs_expr;
889+
} else {
890+
// iter < bound / iter <= bound
891+
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
892+
iter = lhs_expr;
893+
}
894+
}
895+
// If it is a predicate for a single input iter
896+
if (const auto* var_ptr = iter.as<VarNode>()) {
897+
auto it = input_iters->find(GetRef<Var>(var_ptr));
898+
if (it != input_iters->end()) {
899+
PrimExpr iter_min = (*it).second->min;
900+
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
901+
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
902+
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
903+
input_iters->Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
904+
}
905+
} else {
906+
result->emplace_back(iter, lower_bound, upper_bound, 0);
887907
}
888908
}
889-
result.emplace_back(iter, lower_bound, upper_bound, 0);
890909
if (is_finish) {
891910
break;
892911
}
893912
pred = rest.Eval();
894913
}
895-
return result;
914+
return true;
896915
}
897916

898917
bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
@@ -912,13 +931,14 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
912931
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
913932
// - Step1: IterIndependenceChecker checks if the iterators are independent.
914933
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
915-
std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
916-
if (!is_one(predicate) && constraints.empty()) {
934+
Map<Var, Range> constrained_input_iters = input_iters;
935+
std::vector<IterConstraint> constraints;
936+
if (!is_one(predicate) &&
937+
!MatchBoundConstraints(predicate, &constrained_input_iters, &constraints)) {
917938
diag_ctx.Emit(Diagnostic::Error(predicate->span)
918939
<< "Fail to collect constraints from iteration predicate: " << predicate);
919940
return Array<IterSumExpr>();
920941
}
921-
922942
// We have to make sure when we visit an iterator, all the constraints related with its successors
923943
// in the iter var graph has been visited, where the expression of this iterator will contain the
924944
// expression of its successor, so we sort them by their sizes.
@@ -930,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
930950
constraints.begin(), constraints.end(),
931951
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
932952

933-
IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
953+
IterMapRewriter rewriter(analyzer, constrained_input_iters, diag_ctx);
934954
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
935955
for (const IterConstraint& constraint : constraints) {
936-
rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
956+
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
957+
constraint.upper_bound);
937958
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
938959
}
939960
if (!rewriter.CheckConstraints()) {
@@ -945,7 +966,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
945966
Array<IterSumExpr> results;
946967
for (PrimExpr value : indices) {
947968
results.push_back(rewriter.Rewrite(value));
948-
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
969+
if (rewriter.unresolved_count() != 0) {
970+
diag_ctx.Emit(Diagnostic::Error(predicate->span) << "Affine mapping detection failed");
971+
return Array<IterSumExpr>();
972+
}
949973
}
950974
// Step1: IterIndependenceChecker checks if the iterator are independent.
951975
if (!rewriter.CheckMapping(results, require_bijective)) {
@@ -1306,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator {
13061330
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor * expr->extent)) {
13071331
return floordiv(source, expr->lower_factor) * expr->scale;
13081332
} else {
1309-
return floormod(floordiv(source, expr->lower_factor), expr->extent) * expr->scale;
1333+
return floordiv(floormod(source, expr->lower_factor * expr->extent), expr->lower_factor) *
1334+
expr->scale;
13101335
}
13111336
}
13121337

src/arith/modular_set.cc

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
196196
return Everything();
197197
}
198198

199+
Entry VisitExpr_(const FloorModNode* op) final {
200+
Entry b = VisitExpr(op->b);
201+
if (b.is_const()) {
202+
int64_t c2 = b.base;
203+
ICHECK(c2 != 0) << "MathError: the divisor is 0";
204+
Entry a = VisitExpr(op->a);
205+
int64_t coeff = ZeroAwareGCD(a.coeff, c2);
206+
return Entry(coeff, a.base % c2);
207+
}
208+
return Everything();
209+
}
210+
199211
Entry VisitExpr_(const MinNode* op) final {
200212
Entry a = VisitExpr(op->a);
201213
Entry b = VisitExpr(op->b);

src/arith/rewrite_simplify.cc

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
192192
TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
193193
// floor div
194194
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
195+
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
196+
c2.Eval()->value > 0);
195197

196198
// canonicalization rule
197199
// will try rewrite again after canonicalization.
@@ -785,6 +787,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
785787
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), x * floordiv(c1, c2) + floordiv(y, c2),
786788
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
787789

790+
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y, c2), floordiv(x, floordiv(c2, c1)),
791+
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
792+
c2.Eval()->value % c1.Eval()->value == 0 &&
793+
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
794+
788795
TVM_TRY_REWRITE_IF(floordiv(min(x * c1, y), c2), min(x * floordiv(c1, c2), floordiv(y, c2)),
789796
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
790797

@@ -794,6 +801,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
794801
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(y, c2) + x * floordiv(c1, c2),
795802
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
796803

804+
TVM_TRY_REWRITE_IF(floordiv(y + x * c1, c2), floordiv(x, floordiv(c2, c1)),
805+
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
806+
c2.Eval()->value % c1.Eval()->value == 0 &&
807+
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
808+
797809
TVM_TRY_REWRITE_IF(floordiv(min(y, x * c1), c2), min(floordiv(y, c2), x * floordiv(c1, c2)),
798810
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
799811

src/tir/schedule/primitive/loop_transformation.cc

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -413,7 +413,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
413413
for (int i = 0; i < n; i++) {
414414
const PrimExpr& factor = factors[i];
415415
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
416-
substitute_value = substitute_value * factor + var;
416+
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
417417
analyzer.Bind(var, Range::FromMinExtent(0, factor));
418418
new_loop_vars.emplace_back(std::move(var));
419419
}
@@ -505,11 +505,14 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
505505
Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix);
506506
Array<PrimExpr> substitute_value;
507507
substitute_value.resize(loops.size());
508-
PrimExpr tot = fused_var;
509-
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
510-
substitute_value.Set(i, floormod(tot, loops[i]->extent));
511-
tot = floordiv(tot, loops[i]->extent);
512-
}
508+
PrimExpr lower = 1;
509+
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
510+
substitute_value.Set(i, is_one(loops[i]->extent)
511+
? 0
512+
: floordiv(floormod(fused_var, lower * loops[i]->extent), lower));
513+
lower = lower * loops[i]->extent;
514+
}
515+
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower));
513516
Stmt new_stmt = loops.back()->body;
514517
Map<Block, Block> opaque_block_reuse;
515518
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
@@ -534,6 +537,7 @@ StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
534537
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
535538
return self->stmt2ref.at(new_stmt.get());
536539
}
540+
537541
/*!
538542
* \brief Collect an array of loop srefs into a set
539543
* \param self The schedule state

tests/python/unittest/test_arith_intset.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -105,7 +105,10 @@ def test_mod():
105105
ck.verify(
106106
flm(y, 8),
107107
{y: tvm.arith.IntervalSet(z * 8 + x * 4, z * 8 + x * 4 + 3)},
108-
(x * 4 - 8 * fld(x * 4, 8), x * 4 - 8 * fld(x * 4, 8) + 3),
108+
(
109+
z * 8 + x * 4 - 8 * fld(z * 8 + x * 4, 8),
110+
z * 8 + x * 4 + 3 - 8 * fld(z * 8 + x * 4, 8),
111+
),
109112
)
110113
ck1 = IntSetChecker()
111114
ck1.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 2))

0 commit comments

Comments
 (0)