Revert "[mlir][scf] Fold away scf.for iter args cycles (#173436)"#173991
Revert "[mlir][scf] Fold away scf.for iter args cycles (#173436)"#173991googlewalt merged 2 commits intollvm:mainfrom
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf

Author: Walter Lee (googlewalt)

Changes: It causes issues with Triton usage. Also revert dependent "[mlir][SCF] index_switch results (#173560)".

Full diff: https://github.com/llvm/llvm-project/pull/173991.diff

2 Files Affected:
- mlir/lib/Dialect/SCF/IR/SCF.cpp
- mlir/test/Dialect/SCF/canonicalize.mlir
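For context, the reverted ForOpYieldCyclesFolder pattern folded away scf.for iter args that only permute among themselves and all start from the same init value. A minimal before/after sketch, adapted from the doc comment removed in the diff below (the `"use"` op and the SSA names are illustrative placeholders):

```mlir
// Before: %arg0 and %arg1 swap every iteration, but both start from the
// same %init, so each of them is always equal to %init.
%res:2 = scf.for %i = %lb to %ub step %step
    iter_args(%arg0 = %init, %arg1 = %init) -> (i32, i32) {
  "use"(%arg0, %arg1) : (i32, i32) -> ()
  scf.yield %arg1, %arg0 : i32, i32
}

// After: the cycle is folded away; uses of the iter args and of
// %res#0/%res#1 are all replaced by %init.
scf.for %i = %lb to %ub step %step {
  "use"(%init, %init) : (i32, i32) -> ()
}
```

A self-yielding iter arg is just the length-1 case of the same cycle walk; per the description above, the whole pattern is being reverted because it caused issues with Triton usage.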
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 46d09abd89d69..652414f6cbe54 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -990,8 +990,9 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
namespace {
// Fold away ForOp iter arguments when:
-// 1) The argument's corresponding outer region iterators (inputs) are yielded.
-// 2) The iter arguments have no use and the corresponding (operation) results
+// 1) The op yields the iter arguments.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
+// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
// These arguments must be defined outside of the ForOp region and can just be
@@ -1000,7 +1001,7 @@ namespace {
// The implementation uses `inlineBlockBefore` to steal the content of the
// original ForOp and avoid cloning.
struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
- using Base::Base;
+ using OpRewritePattern<scf::ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const final {
@@ -1029,11 +1030,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
forOp.getYieldedValues() // iter yield
)) {
// Forwarded is `true` when:
- // 1) The region `iter` argument the corresponding input is yielded.
- // 2) The region `iter` argument has no use, and the corresponding op
+ // 1) The region `iter` argument is yielded.
+ // 2) The region `iter` argument the corresponding input is yielded.
+ // 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded =
- (init == yielded) || (arg.use_empty() && result.use_empty());
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
if (forwarded) {
canonicalize = true;
keepMask.push_back(false);
@@ -1131,7 +1133,7 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// single-iteration loops with their bodies, and removes empty loops that
/// iterate at least once and only return values defined outside of the loop.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
- using Base::Base;
+ using OpRewritePattern<ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
@@ -1202,7 +1204,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
/// use_of(%1)
/// ```
struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
- using Base::Base;
+ using OpRewritePattern<ForOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
@@ -1234,100 +1236,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
}
};
-/// Rewriting pattern that folds away cycles in the yield of a scf.for op.
-///
-/// ```
-/// %res:2 = scf.for ... iter_args(%arg0 = %init, %arg1 = %init) {
-/// ...
-/// use %arg0, %arg1
-/// scf.yield %arg1, %arg0
-/// }
-/// return %res#0, %res#1
-/// ```
-///
-/// folds into:
-///
-/// ```
-/// scf.for ... iter_args() {
-/// ...
-/// use %init, %init
-/// scf.yield
-/// }
-/// return %init, %init
-/// ```
-struct ForOpYieldCyclesFolder : public OpRewritePattern<ForOp> {
- using Base::Base;
-
- LogicalResult matchAndRewrite(ForOp op,
- PatternRewriter &rewriter) const override {
- ValueRange yieldedValues = op.getYieldedValues();
- ValueRange initArgs = op.getInitArgs();
- ValueRange results = op.getResults();
- ValueRange regionIterArgs = op.getRegionIterArgs();
- Block *body = op.getBody();
-
- unsigned numYieldedValues = op.getNumRegionIterArgs();
-
- bool changed = false;
- SmallVector<unsigned> cycle;
- llvm::SmallBitVector visited(numYieldedValues, false);
-
- // Go through all possible start points for the cycle.
- for (auto start : llvm::seq(numYieldedValues)) {
- if (visited[start])
- continue;
-
- cycle.clear();
- unsigned current = start;
- bool validCycle = true;
- Value initValue = initArgs[start];
- // Go through yield -> block arg -> yield cycles and check if all values
- // are always equal to the init.
- while (!visited[current]) {
- cycle.push_back(current);
- visited[current] = true;
-
- // Find whether this yield is from a region iter arg.
- auto yieldedValue = yieldedValues[current];
- if (auto arg = dyn_cast<BlockArgument>(yieldedValue);
- !arg || arg.getOwner() != body) {
- validCycle = false;
- break;
- }
-
- // Next yield position.
- current = cast<BlockArgument>(yieldedValue).getArgNumber() -
- op.getNumInductionVars();
-
- // Check if next position has the same init value.
- if (initArgs[current] != initValue) {
- validCycle = false;
- break;
- }
- }
-
- // If we found a valid cycle (yielding own iter arg forms cycle of length
- // 1), all values in it are always equal to initValue.
- if (validCycle) {
- changed = true;
- for (unsigned idx : cycle) {
- // This will leave region args and results dead so other
- // canonicalization patterns can clean them up.
- rewriter.replaceAllUsesWith(regionIterArgs[idx], initValue);
- rewriter.replaceAllUsesWith(results[idx], initValue);
- }
- }
- }
- return success(changed);
- }
-};
-
} // namespace
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder,
- ForOpYieldCyclesFolder>(context);
+ results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+ context);
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -4797,59 +4711,9 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
-/// Canonicalization patterns that folds away dead results of
-/// "scf.index_switch" ops.
-struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
- using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(IndexSwitchOp op,
- PatternRewriter &rewriter) const override {
- // Find dead results.
- BitVector deadResults(op.getNumResults(), false);
- SmallVector<Type> newResultTypes;
- for (auto [idx, result] : llvm::enumerate(op.getResults())) {
- if (!result.use_empty()) {
- newResultTypes.push_back(result.getType());
- } else {
- deadResults[idx] = true;
- }
- }
- if (!deadResults.any())
- return rewriter.notifyMatchFailure(op, "no dead results to fold");
-
- // Create new op without dead results and inline case regions.
- auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
- op.getArg(), op.getCases(),
- op.getCaseRegions().size());
- auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
- rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
- // Remove respective operands from yield op.
- Operation *terminator = newRegion.front().getTerminator();
- assert(isa<YieldOp>(terminator) && "expected yield op");
- rewriter.modifyOpInPlace(
- terminator, [&]() { terminator->eraseOperands(deadResults); });
- };
- for (auto [oldRegion, newRegion] :
- llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
- inlineCaseRegion(oldRegion, newRegion);
- inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
-
- // Replace op with new op.
- SmallVector<Value> newResults(op.getNumResults(), Value());
- unsigned nextNewResult = 0;
- for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
- if (deadResults[idx])
- continue;
- newResults[idx] = newOp.getResult(nextNewResult++);
- }
- rewriter.replaceOp(op, newResults);
- return success();
- }
-};
-
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
+ results.add<FoldConstantCase>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 984ea10f7e540..ac590fc0c47b9 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1665,11 +1665,11 @@ func.func @func_execute_region_inline_multi_yield() {
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8> {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %1 = scf.execute_region -> memref<1x60xui8> no_inline {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1 = scf.execute_region -> memref<1x60xui8> no_inline {
func.call @foo():()->()
scf.yield %alloc: memref<1x60xui8>
- }
+ }
return %1 : memref<1x60xui8>
}
}
@@ -1688,12 +1688,12 @@ func.func private @execute_region_yeilding_external_value() -> memref<1x60xui8>
module {
func.func private @foo()->()
func.func private @execute_region_yeilding_external_and_local_values() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
func.call @foo():()->()
scf.yield %alloc, %alloc_1: memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1716,18 +1716,18 @@ func.func private @execute_region_yeilding_external_and_local_values() -> (memre
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_same_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- ^bb3:
- func.call @foo():()->()
+ ^bb3:
+ func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1746,19 +1746,19 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>, memref<1x120xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
- %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
+ %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<1x120xui8>
%1, %2 = scf.execute_region -> (memref<1x60xui8>, memref<1x120xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc, %alloc_1 : memref<1x60xui8>, memref<1x120xui8>
- ^bb3:
- func.call @foo():()->()
+ ^bb3:
+ func.call @foo():()->()
scf.yield %alloc, %alloc_2 : memref<1x60xui8>, memref<1x120xui8>
- }
+ }
return %1, %2 : memref<1x60xui8>, memref<1x120xui8>
}
}
@@ -1778,18 +1778,18 @@ module {
module {
func.func private @foo()->()
func.func private @execute_region_multiple_yields_different_operands() -> (memref<1x60xui8>) {
- %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
- %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
+ %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x60xui8>
%1 = scf.execute_region -> (memref<1x60xui8>) no_inline {
%c = "test.cmp"() : () -> i1
cf.cond_br %c, ^bb2, ^bb3
- ^bb2:
+ ^bb2:
func.call @foo():()->()
scf.yield %alloc : memref<1x60xui8>
- ^bb3:
+ ^bb3:
func.call @foo():()->()
scf.yield %alloc_1 : memref<1x60xui8>
- }
+ }
return %1 : memref<1x60xui8>
}
}
@@ -2171,70 +2171,3 @@ func.func @scf_for_all_step_size_0() {
}
return
}
-
-// -----
-
-func.func private @side_effect()
-
-// CHECK-LABEL: func @iter_args_cycles
-// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i64, %[[C:.*]]: f32)
-// CHECK: scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
-// CHECK: func.call @side_effect()
-// CHECK-NOT: yield
-// CHECK: return %[[A]], %[[B]], %[[A]], %[[B]], %[[B]], %[[C]] : i32, i64, i32, i64, i64, f32
-func.func @iter_args_cycles(%lb : index, %ub : index, %step : index, %a : i32, %b : i64, %c : f32) -> (i32, i64, i32, i64, i64, f32) {
- %res:6 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %a, %3 = %b, %4 = %b, %5 = %c) -> (i32, i64, i32, i64, i64, f32) {
- func.call @side_effect() : () -> ()
- scf.yield %2, %4, %0, %1, %3, %5 : i32, i64, i32, i64, i64, f32
- }
- return %res#0, %res#1, %res#2, %res#3, %res#4, %res#5 : i32, i64, i32, i64, i64, f32
-}
-
-// -----
-
-func.func private @side_effect(i32)
-
-// CHECK-LABEL: func @iter_args_cycles_non_cycle_start
-// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[A:.*]]: i32, %[[B:.*]]: i32)
-// CHECK: %[[RES:.*]] = scf.for %[[IV:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER_ARG:.*]] = %[[A]]) -> (i32) {
-// CHECK: func.call @side_effect(%[[ITER_ARG]])
-// CHECK: yield %[[B]] : i32
-// CHECK: return %[[RES]], %[[B]], %[[B]] : i32, i32, i32
-func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : index, %a : i32, %b : i32) -> (i32, i32, i32) {
- %res:3 = scf.for %i = %lb to %ub step %step iter_args(%0 = %a, %1 = %b, %2 = %b) -> (i32, i32, i32) {
- func.call @side_effect(%0) : (i32) -> ()
- scf.yield %1, %2, %1 : i32, i32, i32
- }
- return %res#0, %res#1, %res#2 : i32, i32, i32
-}
-
-// -----
-
-// CHECK-LABEL: func @dead_index_switch_result(
-// CHECK-SAME: %[[arg0:.*]]: index
-// CHECK-DAG: %[[c10:.*]] = arith.constant 10
-// CHECK-DAG: %[[c11:.*]] = arith.constant 11
-// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
-// CHECK: case 1 {
-// CHECK: memref.store %[[c10]]
-// CHECK: scf.yield %[[arg0]] : index
-// CHECK: }
-// CHECK: default {
-// CHECK: memref.store %[[c11]]
-// CHECK: scf.yield %[[arg0]] : index
-// CHECK: }
-// CHECK: return %[[switch]]
-func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
- %non_live, %live = scf.index_switch %arg0 -> i32, index
- case 1 {
- %c10 = arith.constant 10 : i32
- memref.store %c10, %arg1[] : memref<i32>
- scf.yield %c10, %arg0 : i32, index
- }
- default {
- %c11 = arith.constant 11 : i32
- memref.store %c11, %arg1[] : memref<i32>
- scf.yield %c11, %arg0 : i32, index
- }
- return %live : index
-}
I think "[mlir][SCF] index_switch results (#173560)" is not dependent on the other commit; they just conflict because they update the same test file. So after these two reverts we should reapply 173560, fixing the conflict. (I am testing locally just to be sure.) |
Sounds good. Thanks for checking.
Done. Re-landed. |
Gah, failed to paste the back-link to this PR in the commit message. Apologies. |
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/169/builds/18466. Here is the relevant piece of the build log for reference: