@@ -372,12 +372,12 @@ class IterMapRewriter : public ExprMutator {
372372 // IterSplit(k, scale=1)),
373373 // extent=9)
374374 // scale=1))
375- // Example(2): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
375+ // Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
376376 // predicate: 1 <= j*2 + k < 9
377- // Then, flattened form = IterSum(IterSplit(i, scale=9 ),
377+ // Then, flattened form = IterSum(IterSplit(i, scale=8 ),
378378 // IterSplit(j, scale=2),
379379 // IterSplit(k, scale=1))
380- // normal form = IterSum(IterSplit(i, scale=9 ),
380+ // normal form = IterSum(IterSplit(i, scale=8 ),
381381 // IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
382382 // IterSplit(k, scale=1), base=-1),
383383 // extent=9-1)
@@ -495,7 +495,7 @@ class IterMapRewriter : public ExprMutator {
495495 */
496496 IterSumExpr NormalizeToIterOnBoundExpr (IterSumExpr expr, PrimExpr predicate_induced_min,
497497 PrimExpr predicate_induced_max) {
498- // remove base temporarily since `TryFuseIters` require zero base iter sum
498+ // normalize to zero base
499499 PrimExpr base = expr->base ;
500500 if (!is_zero (base)) {
501501 expr.CopyOnWrite ()->base = 0 ;
@@ -506,39 +506,40 @@ class IterMapRewriter : public ExprMutator {
506506 ICHECK (!opt.defined () || opt.value ()->args .size () == 1 );
507507 // scale should be 1
508508 if (opt.defined () && is_one (opt.value ()->args [0 ]->scale )) {
509- IterSplitExpr fused_split = opt.value ()->args [0 ];
510- IterSumExpr sum = Downcast<IterSumExpr>(fused_split ->source ->source );
509+ const IterSplitExpr split = opt.value ()->args [0 ];
510+ IterSumExpr structured_form = Downcast<IterSumExpr>(split ->source ->source );
511511 // get the flattened form
512- auto it = flattened_map_.find (sum );
512+ auto it = flattened_map_.find (structured_form );
513513 ICHECK (it != flattened_map_.end ());
514514 IterSumExpr flattened_form = it->second ;
515- // get the mark
515+ // get the mark and offset of the structured_form
516516 auto it_mark = sum_fuse_map_.find (flattened_form);
517517 ICHECK (it_mark != sum_fuse_map_.end ());
518518 IterMark mark = it_mark->second .mark ;
519519 PrimExpr mark_offset = it_mark->second .offset ;
520- // update iter mark iter range to [0, mark->extent) ^ [pred_min, pred_max)
521- PrimExpr mark_min = 0 ;
522- PrimExpr mark_max = mark->extent ;
520+ PrimExpr iter_min = mark_offset;
521+ PrimExpr iter_max = iter_min + mark->extent ;
523522 if (predicate_induced_min.defined ()) {
524- mark_min = max (predicate_induced_min, mark_min );
523+ iter_min = max (predicate_induced_min, iter_min );
525524 }
526525 if (predicate_induced_max.defined ()) {
527- mark_max = min (predicate_induced_max, mark_max );
526+ iter_max = min (predicate_induced_max, iter_max );
528527 }
529- // mark.CopyOnWrite()->min = mark_min;
530- mark.CopyOnWrite ()->source = mark->source - mark_min;
531- mark.CopyOnWrite ()->extent = mark_max - mark_min;
532- mark_offset = mark_offset + mark_min;
533-
534- // update the bound of the lhs based on predicate_induced_extent
535- sum_fuse_map_[flattened_form] = {mark, mark_offset};
528+ if (!is_zero (iter_min)) {
529+ // structured form's offset should be updated
530+ flattened_map_.erase (structured_form);
531+ structured_form.CopyOnWrite ()->base = -iter_min;
532+ mark.CopyOnWrite ()->source = structured_form;
533+ flattened_map_[structured_form] = flattened_form;
534+ }
535+ mark.CopyOnWrite ()->extent = iter_max - iter_min;
536+ sum_fuse_map_[flattened_form] = {mark, iter_min};
536537
537538 // we need to note down the flattened form of constrained iterators
538539 // to check the validity of constraints, see also CheckConstraints()
539540 constrained_iters_flattened_.push_back (flattened_form);
540- expr.CopyOnWrite ()->args = Array<IterSplitExpr>({fused_split });
541- expr.CopyOnWrite ()->base = base + mark_min ;
541+ expr.CopyOnWrite ()->args = Array<IterSplitExpr>({split });
542+ expr.CopyOnWrite ()->base = base + iter_min ;
542543 return expr;
543544 }
544545 Fail (Diagnostic::Error (expr->span )
@@ -554,7 +555,7 @@ class IterMapRewriter : public ExprMutator {
554555 */
555556 IterSumExpr NormalizeToIterWithOffset (IterSumExpr expr) {
556557 // We are normalizing a regular iter
557- if (expr->args .size () <= 1 ) return expr;
558+ if (expr->args .size () < 1 ) return expr;
558559 Optional<IterSumExpr> opt = TryFuseIters (expr);
559560 if (opt.defined ()) {
560561 return opt.value ();
@@ -593,6 +594,7 @@ class IterMapRewriter : public ExprMutator {
593594 Optional<IterSumExpr> TryFuseIters (IterSumExpr expr) {
594595 // select the iterators in order
595596 std::vector<bool > visited (expr->args .size (), false );
597+ size_t num_visited = 0 ;
596598 std::vector<IterSplitExpr> flattened_iters, grouped_iters;
597599 // canonicalize the expression into two different forms: flattened form and structured form
599601 // step0. check whether a base scale can be found first
@@ -606,7 +608,11 @@ class IterMapRewriter : public ExprMutator {
606608 }
607609 }
608610 }
609- if (!base_scale) return NullOpt;
611+ if (!base_scale) {
612+ diag_ctx_.Emit (Diagnostic::Error (expr->span )
613+ << " Fuse iters failed, can not find a valid base scale" );
614+ return NullOpt;
615+ }
610616 // check if it can be remapped into a fused pattern.
611617 PrimExpr expected_extra_base = 0 ;
612618 PrimExpr expected_scale = base_scale.value ();
@@ -616,7 +622,11 @@ class IterMapRewriter : public ExprMutator {
616622 for (; j < expr->args .size (); ++j) {
617623 if (!visited[j] && analyzer_->CanProveEqual (expr->args [j]->scale , expected_scale)) break ;
618624 }
619- if (j == expr->args .size ()) return NullOpt;
625+ if (j == expr->args .size ()) {
626+ diag_ctx_.Emit (Diagnostic::Error (expr->span )
627+ << " Fuse iters failed, can not find expected scale " << expected_scale);
628+ return NullOpt;
629+ }
620630 // look for the longest constrained iter started from expr->args[j]
621631 // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
622632 // predicate: j*2 + k < 9
@@ -637,6 +647,8 @@ class IterMapRewriter : public ExprMutator {
637647 // Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
638648 // predicate = j*2 + k < 9
639649 // then j*2 + k matches the lower two splits of expr
650+ size_t flattened_iters_pos = flattened_iters.size ();
651+ bool match_constraint_suffix = false ;
640652 for (auto it = constraint_to_match.value ()->args .rbegin ();
641653 it != constraint_to_match.value ()->args .rend (); ++it) {
642654 size_t k = 0 ;
@@ -646,10 +658,32 @@ class IterMapRewriter : public ExprMutator {
646658 break ;
647659 }
648660 }
649- if (k == expr->args .size ()) return NullOpt;
661+ if (k == expr->args .size ()) {
662+ if (i == 0 && num_visited == visited.size ()) {
663+ // if the match failed because the iterators were exhausted rather than because of a scale
664+ // mismatch, and all iters have been visited during the current match round, fall back to
665+ // skipping the constraint. Example: exprs = [i * 2 + j, k], i in [0, 3), j in [0, 2), k in [0, 4)
666+ // predicate = i * 8 + j * 4 + k < 10
667+ for (size_t pos = flattened_iters_pos; pos < flattened_iters.size (); ++pos) {
668+ grouped_iters.push_back (flattened_iters[pos]);
669+ expected_scale *= flattened_iters[pos]->extent ;
670+ }
671+ match_constraint_suffix = true ;
672+ break ;
673+ }
674+ diag_ctx_.Emit (Diagnostic::Error (expr->span )
675+ << " Fuse iters failed, can not find flattened iter match constraint "
676+ << constraint_to_match.value ());
677+ return NullOpt;
678+ }
650679 visited[k] = true ;
680+ num_visited += 1 ;
651681 flattened_iters.push_back (expr->args [k]);
652682 }
683+ if (match_constraint_suffix) {
684+ // all iters are used to match the constraint, but only a suffix is matched.
685+ break ;
686+ }
653687 auto iter = sum_fuse_map_.find (constraint_to_match.value ());
654688 ICHECK (iter != sum_fuse_map_.end ());
655689 const IterMarkWithOffset& iter_matched = iter->second ;
@@ -661,6 +695,7 @@ class IterMapRewriter : public ExprMutator {
661695 } else {
662696 // constraint_to_match not found, skip this iterator
663697 visited[j] = true ;
698+ num_visited += 1 ;
664699 flattened_iters.push_back (expr->args [j]);
665700 grouped_iters.push_back (expr->args [j]);
666701 expected_scale *= expr->args [j]->extent ;
@@ -681,6 +716,8 @@ class IterMapRewriter : public ExprMutator {
681716 // old iter
682717 if (!analyzer_->CanProveEqual (expected_extra_base, it->second .offset * base_scale.value ())) {
683718 // the extra offset is not consistent with old
719+ diag_ctx_.Emit (Diagnostic::Error (expr->span )
720+ << " Fuse iters failed, the extra offset is not consistent with old" );
684721 return NullOpt;
685722 }
686723 return IterSumExpr ({IterSplitExpr (it->second .mark , base_scale.value ())},
0 commit comments