Skip to content

Commit 7ea21ef

Browse files
add more affine check testcases, fix bug for single iter and duplicate constraints on iter
1 parent 2142f91 commit 7ea21ef

File tree

4 files changed

+246
-56
lines changed

4 files changed

+246
-56
lines changed

src/arith/iter_affine_map.cc

Lines changed: 63 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,12 @@ class IterMapRewriter : public ExprMutator {
372372
// IterSplit(k, scale=1)),
373373
// extent=9)
374374
// scale=1))
375-
// Example(2): expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
375+
// Example(2): expr = i*8 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
376376
// predicate: 1 <= j*2 + k < 9
377-
// Then, flattened form = IterSum(IterSplit(i, scale=9),
377+
// Then, flattened form = IterSum(IterSplit(i, scale=8),
378378
// IterSplit(j, scale=2),
379379
// IterSplit(k, scale=1))
380-
// normal form = IterSum(IterSplit(i, scale=9),
380+
// normal form = IterSum(IterSplit(i, scale=8),
381381
// IterSplit(IterMark(IterSum(IterSplit(j, scale=2),
382382
// IterSplit(k, scale=1), base=-1),
383383
// extent=9-1)
@@ -495,7 +495,7 @@ class IterMapRewriter : public ExprMutator {
495495
*/
496496
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
497497
PrimExpr predicate_induced_max) {
498-
// remove base temporarily since `TryFuseIters` require zero base iter sum
498+
// normalize to zero base
499499
PrimExpr base = expr->base;
500500
if (!is_zero(base)) {
501501
expr.CopyOnWrite()->base = 0;
@@ -506,39 +506,40 @@ class IterMapRewriter : public ExprMutator {
506506
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
507507
// scale should be 1
508508
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
509-
IterSplitExpr fused_split = opt.value()->args[0];
510-
IterSumExpr sum = Downcast<IterSumExpr>(fused_split->source->source);
509+
const IterSplitExpr split = opt.value()->args[0];
510+
IterSumExpr structured_form = Downcast<IterSumExpr>(split->source->source);
511511
// get the flattened form
512-
auto it = flattened_map_.find(sum);
512+
auto it = flattened_map_.find(structured_form);
513513
ICHECK(it != flattened_map_.end());
514514
IterSumExpr flattened_form = it->second;
515-
// get the mark
515+
// get the mark and offset of the structured_form
516516
auto it_mark = sum_fuse_map_.find(flattened_form);
517517
ICHECK(it_mark != sum_fuse_map_.end());
518518
IterMark mark = it_mark->second.mark;
519519
PrimExpr mark_offset = it_mark->second.offset;
520-
// update iter mark iter range to [0, mark->extent) ^ [pred_min, pred_max)
521-
PrimExpr mark_min = 0;
522-
PrimExpr mark_max = mark->extent;
520+
PrimExpr iter_min = mark_offset;
521+
PrimExpr iter_max = iter_min + mark->extent;
523522
if (predicate_induced_min.defined()) {
524-
mark_min = max(predicate_induced_min, mark_min);
523+
iter_min = max(predicate_induced_min, iter_min);
525524
}
526525
if (predicate_induced_max.defined()) {
527-
mark_max = min(predicate_induced_max, mark_max);
526+
iter_max = min(predicate_induced_max, iter_max);
528527
}
529-
// mark.CopyOnWrite()->min = mark_min;
530-
mark.CopyOnWrite()->source = mark->source - mark_min;
531-
mark.CopyOnWrite()->extent = mark_max - mark_min;
532-
mark_offset = mark_offset + mark_min;
533-
534-
// update the bound of the lhs based on predicate_induced_extent
535-
sum_fuse_map_[flattened_form] = {mark, mark_offset};
528+
if (!is_zero(iter_min)) {
529+
// structured form's offset should be updated
530+
flattened_map_.erase(structured_form);
531+
structured_form.CopyOnWrite()->base = -iter_min;
532+
mark.CopyOnWrite()->source = structured_form;
533+
flattened_map_[structured_form] = flattened_form;
534+
}
535+
mark.CopyOnWrite()->extent = iter_max - iter_min;
536+
sum_fuse_map_[flattened_form] = {mark, iter_min};
536537

537538
// we need to note down the flattened form of constrained iterators
538539
// to check the validity of constraints, see also CheckConstraints()
539540
constrained_iters_flattened_.push_back(flattened_form);
540-
expr.CopyOnWrite()->args = Array<IterSplitExpr>({fused_split});
541-
expr.CopyOnWrite()->base = base + mark_min;
541+
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
542+
expr.CopyOnWrite()->base = base + iter_min;
542543
return expr;
543544
}
544545
Fail(Diagnostic::Error(expr->span)
@@ -554,7 +555,7 @@ class IterMapRewriter : public ExprMutator {
554555
*/
555556
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
556557
// We are normalizing a regular iter
557-
if (expr->args.size() <= 1) return expr;
558+
if (expr->args.size() < 1) return expr;
558559
Optional<IterSumExpr> opt = TryFuseIters(expr);
559560
if (opt.defined()) {
560561
return opt.value();
@@ -593,6 +594,7 @@ class IterMapRewriter : public ExprMutator {
593594
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
594595
// select the iterators in order
595596
std::vector<bool> visited(expr->args.size(), false);
597+
size_t num_visited = 0;
596598
std::vector<IterSplitExpr> flattened_iters, grouped_iters;
597599
// canonicalize the expression into two different forms: flattened form and structured form
598600
// step0. check whether a base scale can be found first
@@ -606,7 +608,11 @@ class IterMapRewriter : public ExprMutator {
606608
}
607609
}
608610
}
609-
if (!base_scale) return NullOpt;
611+
if (!base_scale) {
612+
diag_ctx_.Emit(Diagnostic::Error(expr->span)
613+
<< "Fuse iters failed, can not find a valid base scale");
614+
return NullOpt;
615+
}
610616
// check if it can be remapped into a fused pattern.
611617
PrimExpr expected_extra_base = 0;
612618
PrimExpr expected_scale = base_scale.value();
@@ -616,7 +622,11 @@ class IterMapRewriter : public ExprMutator {
616622
for (; j < expr->args.size(); ++j) {
617623
if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
618624
}
619-
if (j == expr->args.size()) return NullOpt;
625+
if (j == expr->args.size()) {
626+
diag_ctx_.Emit(Diagnostic::Error(expr->span)
627+
<< "Fuse iters failed, can not find expected scale " << expected_scale);
628+
return NullOpt;
629+
}
620630
// look for the longest constrained iter starting from expr->args[j]
621631
// Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
622632
// predicate: j*2 + k < 9
@@ -637,6 +647,8 @@ class IterMapRewriter : public ExprMutator {
637647
// Example: expr = i*9 + j*2 + k, i in [0, 4) j in [0, 5) k in [0, 2)
638648
// predicate = j*2 + k < 9
639649
// then j*2 + k matches the lower two splits of expr
650+
size_t flattened_iters_pos = flattened_iters.size();
651+
bool match_constraint_suffix = false;
640652
for (auto it = constraint_to_match.value()->args.rbegin();
641653
it != constraint_to_match.value()->args.rend(); ++it) {
642654
size_t k = 0;
@@ -646,10 +658,32 @@ class IterMapRewriter : public ExprMutator {
646658
break;
647659
}
648660
}
649-
if (k == expr->args.size()) return NullOpt;
661+
if (k == expr->args.size()) {
662+
if (i == 0 && num_visited == visited.size()) {
663+
// if the match failed because the iters ran out rather than due to a scale mismatch,
664+
// and all used iters were visited during the current match round, fall back to skipping the
665+
// constraint. Example: exprs = [i * 2 + j, k], i in [0, 3), j in [0, 2), k in [0, 4)
666+
// predicate = i * 8 + j * 4 + k < 10
667+
for (size_t pos = flattened_iters_pos; pos < flattened_iters.size(); ++pos) {
668+
grouped_iters.push_back(flattened_iters[pos]);
669+
expected_scale *= flattened_iters[pos]->extent;
670+
}
671+
match_constraint_suffix = true;
672+
break;
673+
}
674+
diag_ctx_.Emit(Diagnostic::Error(expr->span)
675+
<< "Fuse iters failed, can not find flattened iter match constraint "
676+
<< constraint_to_match.value());
677+
return NullOpt;
678+
}
650679
visited[k] = true;
680+
num_visited += 1;
651681
flattened_iters.push_back(expr->args[k]);
652682
}
683+
if (match_constraint_suffix) {
684+
// all iters are used to match the constraint, but only a suffix is matched.
685+
break;
686+
}
653687
auto iter = sum_fuse_map_.find(constraint_to_match.value());
654688
ICHECK(iter != sum_fuse_map_.end());
655689
const IterMarkWithOffset& iter_matched = iter->second;
@@ -661,6 +695,7 @@ class IterMapRewriter : public ExprMutator {
661695
} else {
662696
// constraint_to_match not found, skip this iterator
663697
visited[j] = true;
698+
num_visited += 1;
664699
flattened_iters.push_back(expr->args[j]);
665700
grouped_iters.push_back(expr->args[j]);
666701
expected_scale *= expr->args[j]->extent;
@@ -681,6 +716,8 @@ class IterMapRewriter : public ExprMutator {
681716
// old iter
682717
if (!analyzer_->CanProveEqual(expected_extra_base, it->second.offset * base_scale.value())) {
683718
// the extra offset is not consistent with the old one
719+
diag_ctx_.Emit(Diagnostic::Error(expr->span)
720+
<< "Fuse iters failed, the extra offset is not consistent with old");
684721
return NullOpt;
685722
}
686723
return IterSumExpr({IterSplitExpr(it->second.mark, base_scale.value())},

tests/python/unittest/test_arith_intset.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# under the License.
1717
import tvm
1818
from tvm import te
19+
from tvm import tir
1920
from tvm.ir.base import structural_equal
2021

2122

@@ -218,14 +219,9 @@ def test_region_lower_bound_for_non_perfect_tile():
218219
h1 = tvm.tir.Var("h1", "int32")
219220
h2 = tvm.tir.Var("h2", "int32")
220221
h3 = tvm.tir.Var("h3", "int32")
221-
# h1, h2 are bounded, h3 is free
222-
var_dom = {
223-
h2: tvm.ir.Range(begin=0, end=2),
224-
h1: tvm.ir.Range(begin=0, end=5),
225-
}
226222
analyzer = tvm.arith.Analyzer()
227223

228-
def do_test_point_access(point, predicates, expect):
224+
def do_test_point_access(point, predicates, var_dom, expect):
229225
regions = tvm.arith.estimate_region_lower_bound(
230226
region=[
231227
tvm.ir.Range.from_min_extent(min_value=point, extent=1),
@@ -237,29 +233,68 @@ def do_test_point_access(point, predicates, expect):
237233
assert regions is None
238234
else:
239235
assert len(regions) == 1
240-
assert structural_equal(
241-
analyzer.simplify(expect[0], 3), analyzer.simplify(regions[0].min_value, 3)
242-
)
243-
assert structural_equal(
244-
analyzer.simplify(expect[1], 3), analyzer.simplify(regions[0].max_value, 3)
245-
)
246-
247-
# normal case of a non-uniform tiling
236+
for binding, expect_min, expect_max in expect:
237+
min_diff = expect_min - regions[0].min_value
238+
assert analyzer.simplify(tir.stmt_functor.substitute(min_diff, binding), 3) == 0
239+
max_diff = expect_max - regions[0].max_value
240+
assert analyzer.simplify(tir.stmt_functor.substitute(max_diff, binding), 3) == 0
241+
242+
# non-uniform tiling, single inner variable
248243
# h3 == 0: region is [1, 9]
249244
# 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 9]
250245
# h3 > 26: region is [h3 * 8, 223]
246+
do_test_point_access(
247+
point=h3 * 8 + h2,
248+
predicates=[1 <= h3 * 8 + h2, h3 * 8 + h2 < 224],
249+
var_dom={
250+
h2: tvm.ir.Range(begin=0, end=10),
251+
},
252+
expect=[
253+
(
254+
{},
255+
tvm.tir.max(h3 * 8, 1),
256+
tvm.tir.max(h3 * 8, 1)
257+
- tvm.tir.max(h3 * 8, 214)
258+
- tvm.tir.max(1 - h3 * 8, 0)
259+
+ 223,
260+
),
261+
({h3: 0}, 1, 9),
262+
({h3: 10}, h3 * 8, h3 * 8 + 9),
263+
({h3: 27}, h3 * 8, 223),
264+
],
265+
)
266+
267+
# non-uniform tiling, two inner variables
251268
do_test_point_access(
252269
point=h3 * 8 + h2 * 5 + h1,
253270
predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224],
254-
expect=(
255-
tvm.tir.max(h3 * 8, 1),
256-
tvm.tir.max(h3 * 8, 1) - tvm.tir.max(h3 * 8, 214) - tvm.tir.max(1 - h3 * 8, 0) + 223,
257-
),
271+
var_dom={
272+
h2: tvm.ir.Range(begin=0, end=2),
273+
h1: tvm.ir.Range(begin=0, end=5),
274+
},
275+
expect=[
276+
(
277+
{},
278+
tvm.tir.max(h3 * 8, 1),
279+
tvm.tir.max(h3 * 8, 1)
280+
- tvm.tir.max(h3 * 8, 214)
281+
- tvm.tir.max(1 - h3 * 8, 0)
282+
+ 223,
283+
),
284+
({h3: 0}, 1, 9),
285+
({h3: 10}, h3 * 8, h3 * 8 + 9),
286+
({h3: 27}, h3 * 8, 223),
287+
],
258288
)
289+
259290
# should fail on incompatible predicates
260291
do_test_point_access(
261292
point=h3 * 8 + h2 * 5 + h1,
262293
predicates=[1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224],
294+
var_dom={
295+
h2: tvm.ir.Range(begin=0, end=2),
296+
h1: tvm.ir.Range(begin=0, end=5),
297+
},
263298
expect=None,
264299
)
265300

0 commit comments

Comments
 (0)