Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ class IterMapRewriter : public ExprMutator {
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_, false);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
// scale should be 1
if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
Expand Down Expand Up @@ -722,11 +722,7 @@ class IterMapRewriter : public ExprMutator {
IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
// We are normalizing a regular iter
if (expr->args.size() < 1) return expr;
if (auto opt = TryCombineSplitFromSameSource(expr)) {
expr = opt.value();
if (expr->args.size() < 1) return expr;
}
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_, true);
if (opt.defined()) {
return opt.value();
} else {
Expand Down Expand Up @@ -996,9 +992,18 @@ class IterMapRewriter : public ExprMutator {
* Try to normalize IterSum into a fused IterMark
* \param expr The input sum.
* \param check_level The check level if iter mapping.
* \return The sum with the fused IterMark and extra offset if succeed.
* \param allow_early_skip Whether do we allow early skip if expr is simple
* (this may cause us to return parameters that are not canonically wrapped as
* IterSum(IterMark)) \return The sum with the fused IterMark and extra offset if succeed.
*/
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level,
bool allow_early_skip) {
if (auto opt = TryCombineSplitFromSameSource(expr)) {
expr = opt.value();
if (expr->args.size() <= 1 && allow_early_skip) {
return expr;
}
}
// select the iterators in order
std::vector<bool> visited(expr->args.size(), false);
int base_index = FindBaseIter(expr, visited, NullOpt);
Expand Down Expand Up @@ -1554,12 +1559,8 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
return IterSumExpr();
} else if (sum->args.size() == 1) {
return sum;
} else if (auto opt = TryCombineSplitFromSameSource(sum)) {
if (opt.value()->args.size() == 1) {
return opt.value();
}
}
auto opt_fused = TryFuseIters(sum, check_level_);
auto opt_fused = TryFuseIters(sum, check_level_, true);
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << original_dividend
<< ", can't be written as a single fused IterSum";
Expand Down