Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 14 additions & 150 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,8 +990,9 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,

namespace {
// Fold away ForOp iter arguments when:
// 1) The argument's corresponding outer region iterators (inputs) are yielded.
// 2) The iter arguments have no use and the corresponding (operation) results
// 1) The op yields the iter arguments.
// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// These arguments must be defined outside of the ForOp region and can just be
Expand All @@ -1000,7 +1001,7 @@ namespace {
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
using Base::Base;
using OpRewritePattern<scf::ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
Expand Down Expand Up @@ -1029,11 +1030,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
forOp.getYieldedValues() // iter yield
)) {
// Forwarded is `true` when:
// 1) The region `iter` argument the corresponding input is yielded.
// 2) The region `iter` argument has no use, and the corresponding op
// 1) The region `iter` argument is yielded.
// 2) The region `iter` argument the corresponding input is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
bool forwarded =
(init == yielded) || (arg.use_empty() && result.use_empty());
bool forwarded = (arg == yielded) || (init == yielded) ||
(arg.use_empty() && result.use_empty());
if (forwarded) {
canonicalize = true;
keepMask.push_back(false);
Expand Down Expand Up @@ -1131,7 +1133,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
using Base::Base;
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1202,7 +1204,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
using Base::Base;
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -1234,100 +1236,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};

/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
///
/// ```
/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
///   ...
///   use %arg0, %arg1
///   scf.yield %arg1, %arg0
/// }
/// return %res#0, %res#1
/// ```
///
/// folds into:
///
/// ```
/// scf.for ... iter_args() {
///   ...
///   use %init, %init
///   scf.yield
/// }
/// return %init, %init
/// ```
///
/// A group of iter_args is only folded when every member is initialized with
/// the same value and every member yields another member of the same group.
/// Chains that end in anything else (a value computed in the loop, the
/// induction variable, an iter_arg with a different init, or an iter_arg that
/// belongs to a previously rejected chain) are left untouched.
struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ForOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange yieldedValues = op.getYieldedValues();
    ValueRange initArgs = op.getInitArgs();
    ValueRange results = op.getResults();
    ValueRange regionIterArgs = op.getRegionIterArgs();
    Block *body = op.getBody();

    unsigned numInductionVars = op.getNumInductionVars();
    unsigned numYieldedValues = op.getNumRegionIterArgs();

    bool changed = false;
    SmallVector<unsigned> cycle;
    llvm::SmallBitVector visited(numYieldedValues, false);

    // Go through all possible start points for the cycle.
    for (auto start : llvm::seq(numYieldedValues)) {
      if (visited[start])
        continue;

      cycle.clear();
      unsigned current = start;
      bool validCycle = true;
      Value initValue = initArgs[start];
      // Go through yield -> block arg -> yield cycles and check if all values
      // are always equal to the init.
      while (!visited[current]) {
        cycle.push_back(current);
        visited[current] = true;

        // Find whether this yield is from a region iter arg. The induction
        // variable is also a block argument of the body, but it is not an
        // iter arg and must be rejected before computing the iter arg index
        // (the subtraction below would otherwise underflow).
        auto arg = dyn_cast<BlockArgument>(yieldedValues[current]);
        if (!arg || arg.getOwner() != body ||
            arg.getArgNumber() < numInductionVars) {
          validCycle = false;
          break;
        }

        // Next yield position.
        current = arg.getArgNumber() - numInductionVars;

        // Check if next position has the same init value.
        if (initArgs[current] != initValue) {
          validCycle = false;
          break;
        }
      }

      // The walk stopped on an already visited position. It only proves
      // invariance if that position belongs to the walk that was just
      // performed (i.e. the chain closes on itself). A position visited by an
      // earlier walk was part of a rejected chain, so its value may change
      // between iterations and nothing can be concluded from reaching it.
      if (validCycle && !llvm::is_contained(cycle, current))
        validCycle = false;

      if (!validCycle)
        continue;

      // All values of a closed chain with a common init (yielding its own iter
      // arg forms a chain of length 1) are always equal to initValue.
      changed = true;
      for (unsigned idx : cycle) {
        // This will leave region args and results dead so other
        // canonicalization patterns can clean them up.
        rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
        rewriter.replaceAllUsesWith(results[idx], initValue);
      }
    }
    return success(changed);
  }
};

} // namespace

/// Populates `results` with the canonicalization patterns of scf.for.
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // Register every pattern exactly once. Adding the same pattern type a second
  // time only makes the driver try the identical rewrite again.
  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
              ForOpYieldCyclesFolder>(context);
}

std::optional<APInt> ForOp::getConstantStep() {
Expand Down Expand Up @@ -4797,59 +4711,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};

/// Canonicalization pattern that removes the results of an
/// "scf.index_switch" op that have no uses. A new op with only the used
/// result types is created, the case/default regions are moved into it, and
/// the matching operands are dropped from each region's yield.
struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IndexSwitchOp op,
                                PatternRewriter &rewriter) const override {
    const unsigned numOldResults = op.getNumResults();

    // Mark the results without uses and remember the types of the others.
    BitVector unusedMask(numOldResults, false);
    SmallVector<Type> keptTypes;
    for (auto [position, value] : llvm::enumerate(op.getResults())) {
      if (value.use_empty()) {
        unusedMask[position] = true;
        continue;
      }
      keptTypes.push_back(value.getType());
    }
    if (!unusedMask.any())
      return rewriter.notifyMatchFailure(op, "no dead results to fold");

    // Build the replacement switch; its regions start out empty.
    auto trimmedOp =
        IndexSwitchOp::create(rewriter, op.getLoc(), keptTypes, op.getArg(),
                              op.getCases(), op.getCaseRegions().size());

    // Moves the blocks of `source` into `target` and strips the yield operands
    // that correspond to unused results.
    auto moveAndTrim = [&](Region &source, Region &target) {
      rewriter.inlineRegionBefore(source, target, target.begin());
      Operation *terminator = target.front().getTerminator();
      assert(isa<YieldOp>(terminator) && "expected yield op");
      rewriter.modifyOpInPlace(
          terminator, [&]() { terminator->eraseOperands(unusedMask); });
    };
    for (auto [source, target] :
         llvm::zip_equal(op.getCaseRegions(), trimmedOp.getCaseRegions()))
      moveAndTrim(source, target);
    moveAndTrim(op.getDefaultRegion(), trimmedOp.getDefaultRegion());

    // Map each used old result to the next result of the new op; unused
    // results keep a null replacement value.
    SmallVector<Value> replacements(numOldResults, Value());
    unsigned keptIndex = 0;
    for (unsigned position = 0; position < numOldResults; ++position) {
      if (!unusedMask[position])
        replacements[position] = trimmedOp.getResult(keptIndex++);
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};

/// Populates `results` with the canonicalization patterns of
/// scf.index_switch.
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Register every pattern exactly once; FoldConstantCase was previously added
  // by two consecutive calls.
  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
115 changes: 24 additions & 91 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> memref<1x60xui8> no_inline {
func.call @foo():()->()
scf.yield %alloc: memref<1x60xui8>
}
}
return %1 : memref<1x60xui8>
}
}
Expand All @@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
func.call @foo():()->()
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
}
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
Expand All @@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
^bb3:
func.call @foo():()->()
^bb3:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
}
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
Expand All @@ -1746,19 +1746,19 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
^bb3:
func.call @foo():()->()
^bb3:
func.call @foo():()->()
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
}
}
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
Expand All @@ -1778,18 +1778,18 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
^bb2:
^bb2:
func.call @foo():()->()
scf.yield %alloc : memref<1x60xui8>
^bb3:
^bb3:
func.call @foo():()->()
scf.yield %alloc_1 : memref<1x60xui8>
}
}
return %1 : memref<1x60xui8>
}
}
Expand Down Expand Up @@ -2171,70 +2171,3 @@ func.func @scf_for_all_step_size_0() {
}
return
}

// -----

func.func private @side_effect()

// The yield only permutes the iter_args: positions {0, 2} swap and both start
// from %a, positions 1 -> 4 -> 3 -> 1 rotate and all start from %b, and
// position 5 yields itself. The expected output is a loop without iter_args
// whose former results are replaced by the corresponding init values.
// CHECK-LABEL: func @iter_args_cycles
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: func.call @side_effect()
// CHECK-NOT: yield
// CHECK: return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
%res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
func.call @side_effect() : () -> ()
scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
}
return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
}

// -----

func.func private @side_effect(i32)

// Positions {1, 2} yield each other and both start from %b, so they fold to
// %b. Position 0 starts from %a but yields %1, so it is not part of a cycle
// with a common init: the expected output keeps it as the only iter_arg and
// makes it yield %b.
// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
// CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
// CHECK: func.call @side_effect(%[[ITER_ARG]])
// CHECK: yield %[[B]] : i32
// CHECK: return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
%res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
func.call @side_effect(%0) : (i32) -> ()
scf.yield %1, %2, %1 : i32, i32, i32
}
return %res#0, %res#1, %res#2 : i32, i32, i32
}

// -----

// The first result of the switch (%non_live) has no uses while the second
// (%live) is returned. The expected output is a switch with a single index
// result whose yields no longer carry the i32 operand; the memref.store ops
// in both regions are still expected.
// CHECK-LABEL: func @dead_index_switch_result(
// CHECK-SAME: %[[arg0:.*]]: index
// CHECK-DAG: %[[c10:.*]] = arith.constant 10
// CHECK-DAG: %[[c11:.*]] = arith.constant 11
// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
// CHECK: case 1 {
// CHECK: memref.store %[[c10]]
// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
// CHECK: default {
// CHECK: memref.store %[[c11]]
// CHECK: scf.yield %[[arg0]] : index
// CHECK: }
// CHECK: return %[[switch]]
func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
%non_live, %live = scf.index_switch %arg0 -> i32, index
case 1 {
%c10 = arith.constant 10 : i32
memref.store %c10, %arg1[] : memref<i32>
scf.yield %c10, %arg0 : i32, index
}
default {
%c11 = arith.constant 11 : i32
memref.store %c11, %arg1[] : memref<i32>
scf.yield %c11, %arg0 : i32, index
}
return %live : index
}