Skip to content

[mlir][cf] Canonicalize block args with uniform incoming values#183966

Merged
joker-eph merged 1 commit intollvm:mainfrom
felichita:cf-canonicalize-uniform-block-args
Mar 4, 2026
Merged

[mlir][cf] Canonicalize block args with uniform incoming values#183966
joker-eph merged 1 commit intollvm:mainfrom
felichita:cf-canonicalize-uniform-block-args

Conversation

@felichita
Copy link
Contributor

@felichita felichita commented Feb 28, 2026

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. First itteration

Idea from #182711

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2026

@llvm/pr-subscribers-mlir-cf

@llvm/pr-subscribers-mlir

Author: Fedor Nikolaev (felichita)

Changes

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. First itteration

Idea from #182711

cc: @matthias-springer , @joker-eph


Full diff: https://github.com/llvm/llvm-project/pull/183966.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+80-3)
  • (modified) mlir/test/Dialect/ControlFlow/canonicalize.mlir (+36-2)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..cdc44122068b3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination". 
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -204,9 +204,85 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
   return success();
 }
 
+///   %c = arith.constant 0 : i32
+///   cf.br ^bb1(%c : i32)      // pred 1
+///   cf.br ^bb1(%c : i32)      // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+///   use(%c)                   // %arg0 has no uses and can be removed
+///
+/// If all incoming values for a block argument from all predecessors are the
+/// same SSA value, replace uses of the block argument with that value. This
+/// allows the block argument to be removed by dead code elimination.
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+  if (dest->hasNoPredecessors() ||
+      llvm::hasSingleElement(dest->getPredecessors()))
+    return false;
+
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    Value commonValue;
+    bool allSame = true;
+    for (Block *pred : dest->getPredecessors()) {
+      auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
+      if (!branch) {
+        allSame = false;
+        break;
+      }
+
+      Value incoming;
+      for (unsigned i = 0; i < branch->getNumSuccessors(); ++i) {
+        if (branch->getSuccessor(i) != dest)
+          continue;
+        SuccessorOperands succOps = branch.getSuccessorOperands(i);
+        if (arg.getArgNumber() >= succOps.size()) {
+          allSame = false;
+          break;
+        }
+        incoming = succOps[arg.getArgNumber()];
+        break;
+      }
+      if (!incoming || (commonValue && commonValue != incoming)) {
+        allSame = false;
+        break;
+      }
+      commonValue = incoming;
+    }
+
+    if (allSame && commonValue && commonValue != arg) {
+      rewriter.replaceAllUsesWith(arg, commonValue);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+namespace {
+/// Replaces block arguments with a uniform incoming value across all
+/// predecessors of a CondBranchOp successor.
+struct SimplifyCondBranchBlockArgWithUniformIncomingValues
+    : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (unsigned i = 0; i < op->getNumSuccessors(); ++i)
+      changed |= simplifyUniformBlockArgs(op->getSuccessor(i), rewriter);
+    return success(changed);
+  }
+};
+} // namespace
+
 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
-                 succeeded(simplifyPassThroughBr(op, rewriter)));
+                 succeeded(simplifyPassThroughBr(op, rewriter)) ||
+                 simplifyUniformBlockArgs(op.getDest(), rewriter));
 }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +568,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch,
+              SimplifyCondBranchBlockArgWithUniformIncomingValues>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..5bcd76badea59 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -131,8 +131,7 @@ func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) ->
 // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
 func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
   // CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG2]]
-  // CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG2]]
-  // CHECK: return %[[RES]], %[[RES2]]
+  // CHECK: return %[[RES]], %[[ARG1]]
 
   cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
 
@@ -686,3 +685,38 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
 ^exit(%result: i1):
   return %result : i1
 }
+
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3:
+  // CHECK: return %[[C]]
+  return %arg0 : i32
+}
+
+// Verify that block arguments are not folded when incoming values differ
+// across predecessors.
+
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%a : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%b : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}

@felichita felichita force-pushed the cf-canonicalize-uniform-block-args branch from 824c76c to 5b59fa2 Compare March 1, 2026 12:28
@felichita felichita requested a review from joker-eph March 1, 2026 12:31
@felichita felichita force-pushed the cf-canonicalize-uniform-block-args branch from 5b59fa2 to 84d3628 Compare March 3, 2026 13:11
Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just some minor comments.

@felichita felichita force-pushed the cf-canonicalize-uniform-block-args branch from 84d3628 to d317cd4 Compare March 3, 2026 22:30
@felichita felichita force-pushed the cf-canonicalize-uniform-block-args branch from d317cd4 to 6bc3398 Compare March 3, 2026 23:11

// CHECK-LABEL: func @no_fold_same_dest_different_args
func.func @no_fold_same_dest_different_args(%a: i32, %b: i32) -> i32 {
"test.producing_br"(%a, %b)[^bb1, ^bb1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test case where the folding actually takes places? "test.producing_br"(%a, %a)[^bb1, ^bb1]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination.

Ref llvm#182711
@felichita felichita force-pushed the cf-canonicalize-uniform-block-args branch from 6bc3398 to a2a391b Compare March 4, 2026 10:29
@joker-eph joker-eph merged commit f1aa7c3 into llvm:main Mar 4, 2026
10 checks passed
sahas3 pushed a commit to sahas3/llvm-project that referenced this pull request Mar 4, 2026
…#183966)

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. First itteration

Idea from llvm#182711
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…#183966)

Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. First itteration

Idea from llvm#182711
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants