Skip to content

[mlir][SCF] Fold unused index_switch results#173560

Merged
matthias-springer merged 1 commit intomainfrom
users/matthias-springer/fold_index_switch
Dec 28, 2025
Merged

[mlir][SCF] Fold unused index_switch results#173560
matthias-springer merged 1 commit intomainfrom
users/matthias-springer/fold_index_switch

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 25, 2025

Add a canonicalization pattern to fold unused scf.index_switch results.

Depends on #173542.

@llvmbot
Copy link
Member

llvmbot commented Dec 25, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add a canonicalization pattern to fold unused scf.index_switch results.


Full diff: https://github.com/llvm/llvm-project/pull/173560.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+51-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+31)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
+/// Canonicalization patterns that folds away dead results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Find dead results.
+    BitVector deadResults(op.getNumResults(), false);
+    SmallVector<Type> newResultTypes;
+    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+      if (!result.use_empty()) {
+        newResultTypes.push_back(result.getType());
+      } else {
+        deadResults[idx] = true;
+      }
+    }
+    if (!deadResults.any())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Create new op without dead results and inline case regions.
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+                                       op.getArg(), op.getCases(),
+                                       op.getCaseRegions().size());
+    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+      // Remove respective operands from yield op.
+      Operation *terminator = newRegion.front().getTerminator();
+      assert(isa<YieldOp>(terminator) && "expected yield op");
+      rewriter.modifyOpInPlace(
+          terminator, [&]() { terminator->eraseOperands(deadResults); });
+    };
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+      inlineCaseRegion(oldRegion, newRegion);
+    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+    // Replace op with new op.
+    SmallVector<Value> newResults(op.getNumResults(), Value());
+    unsigned nextNewResult = 0;
+    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+      if (deadResults[idx])
+        continue;
+      newResults[idx] = newOp.getResult(nextNewResult++);
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
+  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
+//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   case 1 {
+//       CHECK:     memref.store %[[c10]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   } 
+//       CHECK:   default {
+//       CHECK:     memref.store %[[c11]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+  %non_live, %live = scf.index_switch %arg0 -> i32, index
+  case 1 {
+    %c10 = arith.constant 10 : i32
+    memref.store %c10, %arg1[] : memref<i32>
+    scf.yield %c10, %arg0 : i32, index
+  }
+  default {
+    %c11 = arith.constant 11 : i32
+    memref.store %c11, %arg1[] : memref<i32>
+    scf.yield %c11, %arg0 : i32, index
+  }
+  return %live : index
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 25, 2025

@llvm/pr-subscribers-mlir-scf

Author: Matthias Springer (matthias-springer)

Changes

Add a canonicalization pattern to fold unused scf.index_switch results.


Full diff: https://github.com/llvm/llvm-project/pull/173560.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+51-1)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+31)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
+/// Canonicalization patterns that folds away dead results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Find dead results.
+    BitVector deadResults(op.getNumResults(), false);
+    SmallVector<Type> newResultTypes;
+    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+      if (!result.use_empty()) {
+        newResultTypes.push_back(result.getType());
+      } else {
+        deadResults[idx] = true;
+      }
+    }
+    if (!deadResults.any())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Create new op without dead results and inline case regions.
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+                                       op.getArg(), op.getCases(),
+                                       op.getCaseRegions().size());
+    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+      // Remove respective operands from yield op.
+      Operation *terminator = newRegion.front().getTerminator();
+      assert(isa<YieldOp>(terminator) && "expected yield op");
+      rewriter.modifyOpInPlace(
+          terminator, [&]() { terminator->eraseOperands(deadResults); });
+    };
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+      inlineCaseRegion(oldRegion, newRegion);
+    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+    // Replace op with new op.
+    SmallVector<Value> newResults(op.getNumResults(), Value());
+    unsigned nextNewResult = 0;
+    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+      if (deadResults[idx])
+        continue;
+      newResults[idx] = newOp.getResult(nextNewResult++);
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
+  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
+//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   case 1 {
+//       CHECK:     memref.store %[[c10]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   } 
+//       CHECK:   default {
+//       CHECK:     memref.store %[[c11]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index {
+  %non_live, %live = scf.index_switch %arg0 -> i32, index
+  case 1 {
+    %c10 = arith.constant 10 : i32
+    memref.store %c10, %arg1[] : memref<i32>
+    scf.yield %c10, %arg0 : i32, index
+  }
+  default {
+    %c11 = arith.constant 11 : i32
+    memref.store %c11, %arg1[] : memref<i32>
+    scf.yield %c11, %arg0 : i32, index
+  }
+  return %live : index
+}

@linuxlonelyeagle
Copy link
Member

This pattern consists of two parts.

  • Part of it is the functionality of remove-dead-values, but I believe it ought to be implemented within remove-dead-values.
  • Part of it is folding constantOp into memref.store.
    For this part, I believe you should perhaps write a fold function for memref.store.

@matthias-springer
Copy link
Member Author

matthias-springer commented Dec 25, 2025

Part of it is folding constantOp into memref.store.

This is not part of the pattern. I just needed a side-effecting op in the test case.

Part of it is the functionality of remove-dead-values, but I believe it ought to be implemented within remove-dead-values.

remove-dead-values indeed performs this kind of folding. I'm trying to remove it as part of #173505 and this PR is in preparation of that change. I sent an RFC on Discourse.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/rdv_3 branch from f5dadff to 3018e9f Compare December 26, 2025 11:29
Base automatically changed from users/matthias-springer/rdv_3 to main December 26, 2025 11:44
@matthias-springer matthias-springer force-pushed the users/matthias-springer/fold_index_switch branch from 529a159 to da738b4 Compare December 26, 2025 11:47
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/rdv_2 December 26, 2025 11:48
@joker-eph
Copy link
Collaborator

Part of it is the functionality of remove-dead-values, but I believe it ought to be implemented within remove-dead-values.

Whatever a canonicalizer can do without a dataflow analysis should be done: the fact that another more complex pass can do it isn't a reason to not canonicalize: it less heavyweight and more generally applicable.

Copy link
Member

@linuxlonelyeagle linuxlonelyeagle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good.

@linuxlonelyeagle
Copy link
Member

Part of it is the functionality of remove-dead-values, but I believe it ought to be implemented within remove-dead-values.

Whatever a canonicalizer can do without a dataflow analysis should be done: the fact that another more complex pass can do it isn't a reason to not canonicalize: it less heavyweight and more generally applicable.

make semse.

@matthias-springer matthias-springer changed the base branch from users/matthias-springer/rdv_2 to main December 28, 2025 18:24
@matthias-springer matthias-springer enabled auto-merge (squash) December 28, 2025 18:36
@matthias-springer matthias-springer merged commit 5f5560f into main Dec 28, 2025
38 of 42 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/fold_index_switch branch December 28, 2025 19:02
/// Canonicalization patterns that folds away dead results of
/// "scf.index_switch" ops.
struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can do using Base::Base

googlewalt added a commit to googlewalt/llvm-project that referenced this pull request Dec 30, 2025
googlewalt added a commit that referenced this pull request Dec 30, 2025
…73991)

It causes issues with Triton usage.

Also revert dependent "[mlir][SCF] index_switch results (#173560)".
cota added a commit that referenced this pull request Dec 30, 2025
This reverts commit 85bfb54f9dfcb323f7a8cbb38a264a596aa1a3d3,
i.e. it reapplies #173560 which was temporarily reverted in
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
Add a canonicalization pattern to fold unused `scf.index_switch`
results.
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
…llvm#173991)

It causes issues with Triton usage.

Also revert dependent "[mlir][SCF] index_switch results (llvm#173560)".
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jan 6, 2026
This reverts commit 85bfb54f9dfcb323f7a8cbb38a264a596aa1a3d3,
i.e. it reapplies llvm#173560 which was temporarily reverted in
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants