[mlir][SCF] Fold unused index_switch results#173560
Conversation
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) Changes: Add a canonicalization pattern to fold unused `scf.index_switch` results. Full diff: https://github.com/llvm/llvm-project/pull/173560.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ // Find dead results (results without uses); collect types of live ones.
+ BitVector deadResults(op.getNumResults(), false);
+ SmallVector<Type> newResultTypes;
+ for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+ if (!result.use_empty()) {
+ newResultTypes.push_back(result.getType());
+ } else {
+ deadResults[idx] = true;
+ }
+ }
+ if (!deadResults.any())
+ return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+ // Create a new op with only the live result types and move the regions
+ auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+ op.getArg(), op.getCases(),
+ op.getCaseRegions().size());
+ auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+ rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+ // Remove the yield operands that correspond to dead results.
+ Operation *terminator = newRegion.front().getTerminator();
+ assert(isa<YieldOp>(terminator) && "expected yield op");
+ rewriter.modifyOpInPlace(
+ terminator, [&]() { terminator->eraseOperands(deadResults); });
+ };
+ for (auto [oldRegion, newRegion] :
+ llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+ inlineCaseRegion(oldRegion, newRegion);
+ inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+ // Replace op with new op; dead results keep a null replacement value.
+ SmallVector<Value> newResults(op.getNumResults(), Value());
+ unsigned nextNewResult = 0;
+ for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+ if (deadResults[idx])
+ continue;
+ newResults[idx] = newOp.getResult(nextNewResult++);
+ }
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
+ results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0() {
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10
+// CHECK-DAG: %[[c11:.*]] = arith.constant 11
+// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: case 1 {
+// CHECK: memref.store %[[c10]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: default {
+// CHECK: memref.store %[[c11]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+ %non_live, %live = scf.index_switch %arg0 -> i32, index
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10, %arg0 : i32, index
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ memref.store %c11, %arg1[] : memref<i32>
+ scf.yield %c11, %arg0 : i32, index
+ }
+ return %live : index
+}
|
|
@llvm/pr-subscribers-mlir-scf Author: Matthias Springer (matthias-springer) Changes: Add a canonicalization pattern to fold unused `scf.index_switch` results. Full diff: https://github.com/llvm/llvm-project/pull/173560.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
}
};
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+ using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IndexSwitchOp op,
+ PatternRewriter &rewriter) const override {
+ // Find dead results (results without uses); collect types of live ones.
+ BitVector deadResults(op.getNumResults(), false);
+ SmallVector<Type> newResultTypes;
+ for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+ if (!result.use_empty()) {
+ newResultTypes.push_back(result.getType());
+ } else {
+ deadResults[idx] = true;
+ }
+ }
+ if (!deadResults.any())
+ return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+ // Create a new op with only the live result types and move the regions
+ auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+ op.getArg(), op.getCases(),
+ op.getCaseRegions().size());
+ auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+ rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+ // Remove the yield operands that correspond to dead results.
+ Operation *terminator = newRegion.front().getTerminator();
+ assert(isa<YieldOp>(terminator) && "expected yield op");
+ rewriter.modifyOpInPlace(
+ terminator, [&]() { terminator->eraseOperands(deadResults); });
+ };
+ for (auto [oldRegion, newRegion] :
+ llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+ inlineCaseRegion(oldRegion, newRegion);
+ inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+ // Replace op with new op; dead results keep a null replacement value.
+ SmallVector<Value> newResults(op.getNumResults(), Value());
+ unsigned nextNewResult = 0;
+ for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+ if (deadResults[idx])
+ continue;
+ newResults[idx] = newOp.getResult(nextNewResult++);
+ }
+ rewriter.replaceOp(op, newResults);
+ return success();
+ }
+};
+
void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<FoldConstantCase>(context);
+ results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0() {
}
return
}
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+// CHECK-SAME: %[[arg0:.*]]: index
+// CHECK-DAG: %[[c10:.*]] = arith.constant 10
+// CHECK-DAG: %[[c11:.*]] = arith.constant 11
+// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+// CHECK: case 1 {
+// CHECK: memref.store %[[c10]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: default {
+// CHECK: memref.store %[[c11]]
+// CHECK: scf.yield %[[arg0]] : index
+// CHECK: }
+// CHECK: return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+ %non_live, %live = scf.index_switch %arg0 -> i32, index
+ case 1 {
+ %c10 = arith.constant 10 : i32
+ memref.store %c10, %arg1[] : memref<i32>
+ scf.yield %c10, %arg0 : i32, index
+ }
+ default {
+ %c11 = arith.constant 11 : i32
+ memref.store %c11, %arg1[] : memref<i32>
+ scf.yield %c11, %arg0 : i32, index
+ }
+ return %live : index
+}
|
|
This pattern consists of two parts.
|
This is not part of the pattern. I just needed a side-effecting op in the test case.
|
f5dadff to
3018e9f
Compare
529a159 to
da738b4
Compare
Whatever a canonicalizer can do without a dataflow analysis should be done: the fact that another, more complex pass can do it isn't a reason not to canonicalize: it is less heavyweight and more generally applicable. |
Makes sense. |
9bef674 to
cd480a2
Compare
| /// Canonicalization patterns that folds away dead results of | ||
| /// "scf.index_switch" ops. | ||
| struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { | ||
| using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; |
This reverts commit 5f5560f.
This reverts commit 85bfb54f9dfcb323f7a8cbb38a264a596aa1a3d3, i.e. it reapplies #173560 which was temporarily reverted in
Add a canonicalization pattern to fold unused `scf.index_switch` results.
…llvm#173991) It causes issues with Triton usage. Also revert dependent "[mlir][SCF] index_switch results (llvm#173560)".
This reverts commit 85bfb54f9dfcb323f7a8cbb38a264a596aa1a3d3, i.e. it reapplies llvm#173560 which was temporarily reverted in
Add a canonicalization pattern to fold unused
`scf.index_switch` results. Depends on #173542.