@@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
201201 return NormalizeToIterWithOffset (ToIterSumExpr (DirectMutate (expr)));
202202 }
203203
204- IterSumExpr RewriteIterConstraint (const PrimExpr& expr, const PrimExpr& predicate_induced_min,
205- const PrimExpr& predicate_induced_max) {
204+ IterSumExpr RewriteIterConstraint (const PrimExpr& expr,
205+ const Optional<PrimExpr>& predicate_induced_min,
206+ const Optional<PrimExpr>& predicate_induced_max) {
206207 return NormalizeToIterOnBoundExpr (ToIterSumExpr (DirectMutate (expr)), predicate_induced_min,
207208 predicate_induced_max);
208209 }
@@ -494,16 +495,17 @@ class IterMapRewriter : public ExprMutator {
494495 * \param predicate_induced_max Open upper bound from iter constraint, may be undefined.
495496 * \return The Normalized expression.
496497 */
497- IterSumExpr NormalizeToIterOnBoundExpr (IterSumExpr expr, PrimExpr predicate_induced_min,
498- PrimExpr predicate_induced_max) {
498+ IterSumExpr NormalizeToIterOnBoundExpr (IterSumExpr expr, Optional< PrimExpr> predicate_induced_min,
499+ Optional< PrimExpr> predicate_induced_max) {
499500 // normalize to zero base
500501 PrimExpr base = expr->base ;
501502 if (!is_zero (base)) {
502503 expr.CopyOnWrite ()->base = 0 ;
503- if (predicate_induced_min.defined ()) predicate_induced_min = predicate_induced_min - base;
504- if (predicate_induced_max.defined ()) predicate_induced_max = predicate_induced_max - base;
504+ if (predicate_induced_min.defined ())
505+ predicate_induced_min = predicate_induced_min.value () - base;
506+ if (predicate_induced_max.defined ())
507+ predicate_induced_max = predicate_induced_max.value () - base;
505508 }
506- if (expr->args .size () < 1 ) return expr;
507509 Optional<IterSumExpr> opt = TryFuseIters (expr);
508510 ICHECK (!opt.defined () || opt.value ()->args .size () == 1 );
509511 // scale should be 1
@@ -522,10 +524,10 @@ class IterMapRewriter : public ExprMutator {
522524 PrimExpr iter_min = mark_offset;
523525 PrimExpr iter_max = iter_min + mark->extent ;
524526 if (predicate_induced_min.defined ()) {
525- iter_min = max (predicate_induced_min, iter_min);
527+ iter_min = max (predicate_induced_min. value () , iter_min);
526528 }
527529 if (predicate_induced_max.defined ()) {
528- iter_max = min (predicate_induced_max, iter_max);
530+ iter_max = min (predicate_induced_max. value () , iter_max);
529531 }
530532 if (!is_zero (iter_min)) {
531533 // structured form's offset should be updated
@@ -536,7 +538,6 @@ class IterMapRewriter : public ExprMutator {
536538 }
537539 mark.CopyOnWrite ()->extent = iter_max - iter_min;
538540 sum_fuse_map_[flattened_form] = {mark, iter_min};
539-
540541 // we need to note down the flattened form of constrained iterators
541542 // to check the validity of constraints, see also CheckConstraints()
542543 constrained_iters_flattened_.push_back (flattened_form);
@@ -771,14 +772,15 @@ class IterMapRewriter : public ExprMutator {
771772struct IterConstraint {
772773 // The expr of the iter
773774 PrimExpr iter;
774- // The expr of the lower_bound
775- PrimExpr lower_bound;
776- // The expr of the upper_bound
777- PrimExpr upper_bound;
775+ // The expr of the lower_bound, may be undefined
776+ Optional< PrimExpr> lower_bound;
777+ // The expr of the upper_bound, may be undefined
778+ Optional< PrimExpr> upper_bound;
778779 // The size of the iter, which is the number of nodes
779780 size_t expr_size = 0 ;
780781
781- IterConstraint (PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
782+ IterConstraint (PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
783+ size_t size)
782784 : iter(std::move(iter)),
783785 lower_bound (std::move(lower_bound)),
784786 upper_bound(std::move(upper_bound)),
@@ -788,11 +790,12 @@ struct IterConstraint {
788790/* !
789791 * \brief Split the predicate into `(a < b) && (c < d) && ...`
790792 * \param pred The predicate to be split.
793+ * \param input_iters The input iterators.
794+ * \param result The result of predicate split.
791795 * \return True if the predicate was split successfully, false if the split failed.
792796 */
793- std::vector<IterConstraint> MatchBoundConstraints (PrimExpr pred,
794- const Map<Var, Range>& input_iters) {
795- std::vector<IterConstraint> result;
797+ bool MatchBoundConstraints (PrimExpr pred, Map<Var, Range>* input_iters,
798+ std::vector<IterConstraint>* result) {
796799 arith::PVar<PrimExpr> lhs, rhs, rest;
797800 for (;;) {
798801 // try extract comparisons
@@ -821,78 +824,94 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
821824 is_equal = true ;
822825 is_finish = true ;
823826 } else {
824- return std::vector<IterConstraint>() ;
827+ return false ;
825828 }
826829 PrimExpr lhs_expr = lhs.Eval ();
827830 PrimExpr rhs_expr = rhs.Eval ();
828831 // we only accept predicate of integers
829832 if (!((lhs_expr->dtype .is_int () || lhs_expr->dtype .is_uint ()) &&
830833 (rhs_expr->dtype .is_int () || rhs_expr->dtype .is_uint ()))) {
831- return std::vector<IterConstraint>() ;
834+ return false ;
832835 }
833836 // determine iter and bound, if we can not distinguish them simply,
834837 // try divide (lhs - rhs) into itervar aware and itervar free parts
835838 auto f_use_itervar = [&input_iters](const VarNode* v) {
836- return input_iters. count (GetRef<Var>(v));
839+ return input_iters-> count (GetRef<Var>(v));
837840 };
838841 bool bound_at_left;
839- if (is_const_int (lhs_expr) || !UsesVar (lhs_expr, f_use_itervar)) {
840- bound_at_left = true ;
841- } else if (is_const_int (rhs_expr) || !UsesVar (rhs_expr, f_use_itervar)) {
842- bound_at_left = false ;
843- } else {
844- bound_at_left = false ; // accumulate bound to rhs
845- PrimExpr sum_parts = lhs_expr - rhs_expr;
846- lhs_expr = 0 ;
847- rhs_expr = 0 ;
848- std::function<void (const PrimExpr&, bool )> f_extract =
849- [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
850- if (const AddNode* add = part.as <AddNode>()) {
851- f_extract (add->a , sign);
852- f_extract (add->b , sign);
853- } else if (const SubNode* sub = part.as <SubNode>()) {
854- f_extract (sub->a , sign);
855- f_extract (sub->b , !sign);
856- } else if (UsesVar (part, f_use_itervar)) {
857- lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
858- } else {
859- rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
860- }
861- };
862- f_extract (sum_parts, true );
863- arith::Analyzer analyzer;
864- lhs_expr = analyzer.Simplify (lhs_expr);
865- rhs_expr = analyzer.Simplify (rhs_expr);
866- }
867- PrimExpr lower_bound, upper_bound, iter;
868- if (is_greater) {
869- if (bound_at_left) {
870- // bound > iter
871- upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
872- iter = rhs_expr;
842+ if (UsesVar (lhs_expr, f_use_itervar) || UsesVar (rhs_expr, f_use_itervar)) {
843+ // At least one side of the comparison uses an input iter
844+ if (is_const_int (lhs_expr) || !UsesVar (lhs_expr, f_use_itervar)) {
845+ bound_at_left = true ;
846+ } else if (is_const_int (rhs_expr) || !UsesVar (rhs_expr, f_use_itervar)) {
847+ bound_at_left = false ;
873848 } else {
874- // iter > bound
875- lower_bound = is_equal ? rhs_expr : rhs_expr + 1 ;
876- iter = lhs_expr;
849+ bound_at_left = false ; // accumulate bound to rhs
850+ PrimExpr sum_parts = lhs_expr - rhs_expr;
851+ lhs_expr = 0 ;
852+ rhs_expr = 0 ;
853+ std::function<void (const PrimExpr&, bool )> f_extract =
854+ [&lhs_expr, &rhs_expr, f_use_itervar, &f_extract](const PrimExpr& part, bool sign) {
855+ if (const AddNode* add = part.as <AddNode>()) {
856+ f_extract (add->a , sign);
857+ f_extract (add->b , sign);
858+ } else if (const SubNode* sub = part.as <SubNode>()) {
859+ f_extract (sub->a , sign);
860+ f_extract (sub->b , !sign);
861+ } else if (UsesVar (part, f_use_itervar)) {
862+ lhs_expr = sign ? lhs_expr + part : lhs_expr - part;
863+ } else {
864+ rhs_expr = sign ? rhs_expr - part : rhs_expr + part;
865+ }
866+ };
867+ f_extract (sum_parts, true );
868+ arith::Analyzer analyzer;
869+ lhs_expr = analyzer.Simplify (lhs_expr);
870+ rhs_expr = analyzer.Simplify (rhs_expr);
877871 }
878- } else {
879- if (bound_at_left) {
880- // bound < iter
881- lower_bound = is_equal ? lhs_expr : lhs_expr + 1 ;
882- iter = rhs_expr;
872+ Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
873+ PrimExpr iter;
874+ if (is_greater) {
875+ if (bound_at_left) {
876+ // bound > iter / bound >= iter
877+ upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
878+ iter = rhs_expr;
879+ } else {
880+ // iter > bound / iter >= bound
881+ lower_bound = is_equal ? rhs_expr : rhs_expr + 1 ;
882+ iter = lhs_expr;
883+ }
883884 } else {
884- // iter < bound
885- upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
886- iter = lhs_expr;
885+ if (bound_at_left) {
886+ // bound < iter / bound <= iter
887+ lower_bound = is_equal ? lhs_expr : lhs_expr + 1 ;
888+ iter = rhs_expr;
889+ } else {
890+ // iter < bound / iter <= bound
891+ upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
892+ iter = lhs_expr;
893+ }
894+ }
895+ // If it is a predicate for a single input iter
896+ if (const auto * var_ptr = iter.as <VarNode>()) {
897+ auto it = input_iters->find (GetRef<Var>(var_ptr));
898+ if (it != input_iters->end ()) {
899+ PrimExpr iter_min = (*it).second ->min ;
900+ PrimExpr iter_max = (*it).second ->min + (*it).second ->extent ;
901+ if (lower_bound.defined ()) iter_min = max (iter_min, lower_bound.value ());
902+ if (upper_bound.defined ()) iter_max = min (iter_max, upper_bound.value ());
903+ input_iters->Set (GetRef<Var>(var_ptr), Range (iter_min, iter_max));
904+ }
905+ } else {
906+ result->emplace_back (iter, lower_bound, upper_bound, 0 );
887907 }
888908 }
889- result.emplace_back (iter, lower_bound, upper_bound, 0 );
890909 if (is_finish) {
891910 break ;
892911 }
893912 pred = rest.Eval ();
894913 }
895- return result ;
914+ return true ;
896915}
897916
898917bool IterRangeSanityCheck (const Map<Var, Range>& iter_ranges) {
@@ -912,13 +931,14 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
912931 // - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
913932 // - Step1: IterIndependenceChecker checks if the iterators are independent.
914933 if (!IterRangeSanityCheck (input_iters)) return Array<IterSumExpr>();
915- std::vector<IterConstraint> constraints = MatchBoundConstraints (predicate, input_iters);
916- if (!is_one (predicate) && constraints.empty ()) {
934+ Map<Var, Range> constrained_input_iters = input_iters;
935+ std::vector<IterConstraint> constraints;
936+ if (!is_one (predicate) &&
937+ !MatchBoundConstraints (predicate, &constrained_input_iters, &constraints)) {
917938 diag_ctx.Emit (Diagnostic::Error (predicate->span )
918939 << " Fail to collect constraints from iteration predicate: " << predicate);
919940 return Array<IterSumExpr>();
920941 }
921-
922942 // We have to make sure when we visit an iterator, all the constraints related to its successors
923943 // in the iter var graph have been visited, where the expression of this iterator will contain the
924944 // expression of its successor, so we sort them by their sizes.
@@ -930,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
930950 constraints.begin (), constraints.end (),
931951 [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size ; });
932952
933- IterMapRewriter rewriter (analyzer, input_iters , diag_ctx);
953+ IterMapRewriter rewriter (analyzer, constrained_input_iters , diag_ctx);
934954 // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
935955 for (const IterConstraint& constraint : constraints) {
936- rewriter.RewriteIterConstraint (constraint.iter , constraint.lower_bound , constraint.upper_bound );
956+ auto res = rewriter.RewriteIterConstraint (constraint.iter , constraint.lower_bound ,
957+ constraint.upper_bound );
937958 if (rewriter.unresolved_count () != 0 ) return Array<IterSumExpr>();
938959 }
939960 if (!rewriter.CheckConstraints ()) {
@@ -945,7 +966,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
945966 Array<IterSumExpr> results;
946967 for (PrimExpr value : indices) {
947968 results.push_back (rewriter.Rewrite (value));
948- if (rewriter.unresolved_count () != 0 ) return Array<IterSumExpr>();
969+ if (rewriter.unresolved_count () != 0 ) {
970+ diag_ctx.Emit (Diagnostic::Error (predicate->span ) << " Affine mapping detection failed" );
971+ return Array<IterSumExpr>();
972+ }
949973 }
950974 // Step1: IterIndependenceChecker checks if the iterators are independent.
951975 if (!rewriter.CheckMapping (results, require_bijective)) {
@@ -1306,7 +1330,8 @@ class IterMapToExprNormalizer : public ExprMutator {
13061330 } else if (analyzer_->CanProve (expr->source ->extent == expr->lower_factor * expr->extent )) {
13071331 return floordiv (source, expr->lower_factor ) * expr->scale ;
13081332 } else {
1309- return floormod (floordiv (source, expr->lower_factor ), expr->extent ) * expr->scale ;
1333+ return floordiv (floormod (source, expr->lower_factor * expr->extent ), expr->lower_factor ) *
1334+ expr->scale ;
13101335 }
13111336 }
13121337
0 commit comments