diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..652414f6cbe54 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -990,8 +990,9 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
 
 namespace {
 // Fold away ForOp iter arguments when:
-// 1) The argument's corresponding outer region iterators (inputs) are yielded.
-// 2) The iter arguments have no use and the corresponding (operation) results
+// 1) The op yields the iter arguments.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
+// 3) The iter arguments have no use and the corresponding (operation) results
 // have no use.
 //
 // These arguments must be defined outside of the ForOp region and can just be
@@ -1000,7 +1001,7 @@ namespace {
 // The implementation uses `inlineBlockBefore` to steal the content of the
 // original ForOp and avoid cloning.
 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
-  using Base::Base;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(scf::ForOp forOp,
                                 PatternRewriter &rewriter) const final {
@@ -1029,11 +1030,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
                    forOp.getYieldedValues()   // iter yield
                    )) {
       // Forwarded is `true` when:
-      // 1) The region `iter` argument the corresponding input is yielded.
-      // 2) The region `iter` argument has no use, and the corresponding op
+      // 1) The region `iter` argument is yielded.
+      // 2) The region `iter` argument the corresponding input is yielded.
+      // 3) The region `iter` argument has no use, and the corresponding op
       // result has no use.
-      bool forwarded =
-          (init == yielded) || (arg.use_empty() && result.use_empty());
+      bool forwarded = (arg == yielded) || (init == yielded) ||
+                       (arg.use_empty() && result.use_empty());
       if (forwarded) {
         canonicalize = true;
         keepMask.push_back(false);
@@ -1131,7 +1133,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
 /// single-iteration loops with their bodies, and removes empty loops that
 /// iterate at least once and only return values defined outside of the loop.
 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
-  using Base::Base;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1202,7 +1204,7 @@
 /// use_of(%1)
 /// ```
 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
-  using Base::Base;
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1234,100 +1236,12 @@
   }
 };
 
-/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
-///
-/// ```
-/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
-///   ...
-///   use %arg0, %arg1
-///   scf.yield %arg1, %arg0
-/// }
-/// return %res#0, %res#1
-/// ```
-///
-/// folds into:
-///
-/// ```
-/// scf.for ... iter_args() {
-///   ...
-///   use %init, %init
-///   scf.yield
-/// }
-/// return %init, %init
-/// ```
-struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
-  using Base::Base;
-
-  LogicalResult matchAndRewrite(ForOp op,
-                                PatternRewriter &rewriter) const override {
-    ValueRange yieldedValues = op.getYieldedValues();
-    ValueRange initArgs = op.getInitArgs();
-    ValueRange results = op.getResults();
-    ValueRange regionIterArgs = op.getRegionIterArgs();
-    Block *body = op.getBody();
-
-    unsigned numYieldedValues = op.getNumRegionIterArgs();
-
-    bool changed = false;
-    SmallVector<unsigned> cycle;
-    llvm::SmallBitVector visited(numYieldedValues, false);
-
-    // Go through all possible start points for the cycle.
-    for (auto start : llvm::seq(numYieldedValues)) {
-      if (visited[start])
-        continue;
-
-      cycle.clear();
-      unsigned current = start;
-      bool validCycle = true;
-      Value initValue = initArgs[start];
-      // Go through yield -> block arg -> yield cycles and check if all values
-      // are always equal to the init.
-      while (!visited[current]) {
-        cycle.push_back(current);
-        visited[current] = true;
-
-        // Find whether this yield is from a region iter arg.
-        auto yieldedValue = yieldedValues[current];
-        if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
-            !arg || arg.getOwner() != body) {
-          validCycle = false;
-          break;
-        }
-
-        // Next yield position.
-        current = cast<BlockArgument>(yieldedValue).getArgNumber() -
-                  op.getNumInductionVars();
-
-        // Check if next position has the same init value.
-        if (initArgs[current] != initValue) {
-          validCycle = false;
-          break;
-        }
-      }
-
-      // If we found a valid cycle (yielding own iter arg forms cycle of length
-      // 1), all values in it are always equal to initValue.
-      if (validCycle) {
-        changed = true;
-        for (unsigned idx : cycle) {
-          // This will leave region args and results dead so other
-          // canonicalization patterns can clean them up.
-          rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
-          rewriter.replaceAllUsesWith(results[idx], initValue);
        }
-      }
-    }
-    return success(changed);
-  }
-};
-
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, ForOpYieldCyclesFolder, SimplifyTrivialLoops,
-              ForOpTensorCastFolder>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+      context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -4797,59 +4711,9 @@ struct FoldConstantCase : OpRewritePattern<IndexSwitchOp> {
   }
 };
 
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(IndexSwitchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Find dead results.
-    BitVector deadResults(op.getNumResults(), false);
-    SmallVector<Type> newResultTypes;
-    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
-      if (!result.use_empty()) {
-        newResultTypes.push_back(result.getType());
-      } else {
-        deadResults[idx] = true;
-      }
-    }
-    if (!deadResults.any())
-      return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
-    // Create new op without dead results and inline case regions.
-    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
-                                       op.getArg(), op.getCases(),
-                                       op.getCaseRegions().size());
-    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
-      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
-      // Remove respective operands from yield op.
-      Operation *terminator = newRegion.front().getTerminator();
-      assert(isa<scf::YieldOp>(terminator) && "expected yield op");
-      rewriter.modifyOpInPlace(
-          terminator, [&]() { terminator->eraseOperands(deadResults); });
-    };
-    for (auto [oldRegion, newRegion] :
-         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
-      inlineCaseRegion(oldRegion, newRegion);
-    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
-    // Replace op with new op.
-    SmallVector<Value> newResults(op.getNumResults(), Value());
-    unsigned nextNewResult = 0;
-    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
-      if (deadResults[idx])
-        continue;
-      newResults[idx] = newOp.getResult(nextNewResult++);
-    }
-    rewriter.replaceOp(op, newResults);
-    return success();
-  }
-};
-
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+  results.add<FoldConstantCase>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 984ea10f7e540..ac590fc0c47b9 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
 module {
   func.func private @foo()->()
   func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
-    %1 = scf.execute_region -> memref<1x60xui8> no_inline { 
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %1 = scf.execute_region -> memref<1x60xui8> no_inline {
       func.call @foo():()->()
      scf.yield %alloc: memref<1x60xui8>
-    } 
+    }
     return %1 : memref<1x60xui8>
   }
 }
@@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
 module {
   func.func private @foo()->()
   func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
-    %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline { 
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
       func.call @foo():()->()
      scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
-    } 
+    }
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> 
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2: 
+    ^bb2:
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3: 
-      func.call @foo():()->() 
+    ^bb3:
+      func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    } 
+    }
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1746,19 +1746,19 @@ module {
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> 
-    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8> 
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
     %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
       %c = "test.cmp"() : () -> i1
       cf.cond_br %c, ^bb2, ^bb3
-    ^bb2: 
+    ^bb2:
       func.call @foo():()->()
       scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
-    ^bb3: 
-      func.call @foo():()->() 
+    ^bb3:
+      func.call @foo():()->()
       scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
-    } 
+    }
     return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
   }
 }
@@ -1778,18 +1778,18 @@ module {
 module {
   func.func private @foo()->()
   func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
-    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
-    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8> 
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
     %1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
      %c = "test.cmp"() : () -> i1
      cf.cond_br %c, ^bb2, ^bb3
-    ^bb2: 
+    ^bb2:
      func.call @foo():()->()
      scf.yield %alloc : memref<1x60xui8>
-    ^bb3: 
+    ^bb3:
      func.call @foo():()->()
      scf.yield %alloc_1 : memref<1x60xui8>
-    } 
+    }
     return %1 : memref<1x60xui8>
   }
 }
@@ -2171,70 +2171,3 @@ func.func @scf_for_all_step_size_0() {
   }
   return
 }
-
-// -----
-
-func.func private @side_effect()
-
-// CHECK-LABEL: func @iter_args_cycles
-// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
-// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
-// CHECK: func.call @side_effect()
-// CHECK-NOT: yield
-// CHECK: return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
-func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
-  %res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
-    func.call @side_effect() : () -> ()
-    scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
-  }
-  return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
-}
-
-// -----
-
-func.func private @side_effect(i32)
-
-// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
-// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
-// CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
-// CHECK: func.call @side_effect(%[[ITER_ARG]])
-// CHECK: yield %[[B]] : i32
-// CHECK: return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
-func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
-  %res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
-    func.call @side_effect(%0) : (i32) -> ()
-    scf.yield %1, %2, %1 : i32, i32, i32
-  }
-  return %res#0, %res#1, %res#2 : i32, i32, i32
-}
-
-// -----
-
-// CHECK-LABEL: func @dead_index_switch_result(
-// CHECK-SAME: %[[arg0:.*]]: index
-// CHECK-DAG: %[[c10:.*]] = arith.constant 10
-// CHECK-DAG: %[[c11:.*]] = arith.constant 11
-// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
-// CHECK: case 1 {
-// CHECK: memref.store %[[c10]]
-// CHECK: scf.yield %[[arg0]] : index
-// CHECK: }
-// CHECK: default {
-// CHECK: memref.store %[[c11]]
-// CHECK: scf.yield %[[arg0]] : index
-// CHECK: }
-// CHECK: return %[[switch]]
-func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
-  %non_live, %live = scf.index_switch %arg0 -> i32, index
-  case 1 {
-    %c10 = arith.constant 10 : i32
-    memref.store %c10, %arg1[] : memref<i32>
-    scf.yield %c10, %arg0 : i32, index
-  }
-  default {
-    %c11 = arith.constant 11 : i32
-    memref.store %c11, %arg1[] : memref<i32>
-    scf.yield %c11, %arg0 : i32, index
-  }
-  return %live : index
-}
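For reference, a minimal MLIR sketch (not taken from the patch or the test suite; `test.use` is a placeholder unregistered op) of the length-1 cycle that the `(arg == yielded)` check added to ForOpIterArgsFolder covers: an iter_arg that is yielded back unchanged always equals its init, so the pattern can rewrite uses of the arg and of the loop result to the init value and drop the arg/result pair.

    func.func @self_yielded_iter_arg(%lb: index, %ub: index, %step: index, %init: i32) -> i32 {
      // %acc is yielded back unchanged, so it equals %init on every iteration.
      %r = scf.for %i = %lb to %ub step %step iter_args(%acc = %init) -> (i32) {
        "test.use"(%acc) : (i32) -> ()
        scf.yield %acc : i32
      }
      // Canonicalization is expected to replace uses of %acc and %r with %init.
      return %r : i32
    }

Longer cycles such as `scf.yield %arg1, %arg0` (the removed @iter_args_cycles test above) are not caught by this check; they were only handled by the now-removed ForOpYieldCyclesFolder.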