[mlir][cf] Canonicalize block args with uniform incoming values#183966
Merged
[mlir][cf] Canonicalize block args with uniform incoming values#183966
Conversation
Member
|
@llvm/pr-subscribers-mlir-cf @llvm/pr-subscribers-mlir Author: Fedor Nikolaev (felichita) Changes: Add a canonicalization pattern that replaces block arguments with a common SSA value when all predecessors pass the same value for that argument. Idea from #182711. cc: @matthias-springer, @joker-eph Full diff: https://github.com/llvm/llvm-project/pull/183966.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..cdc44122068b3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
-// This side effect models "program termination".
+// This side effect models "program termination".
void AssertOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -204,9 +204,85 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
return success();
}
+/// Forwards block arguments of `dest` that receive the same SSA value on
+/// every incoming edge.
+///
+///   %c = arith.constant 0 : i32
+///   cf.br ^bb1(%c : i32)  // pred 1
+///   cf.br ^bb1(%c : i32)  // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+///   use(%c)  // %arg0 has no uses and can be removed
+///
+/// If the incoming values for a block argument are the same SSA value on all
+/// incoming edges, uses of the block argument are replaced with that value.
+/// This allows the block argument to be removed by dead code elimination.
+///
+/// Every incoming edge is inspected individually: a single terminator may
+/// list `dest` as a successor more than once and forward different operands
+/// on each of those edges, in which case the argument is not uniform and must
+/// be left alone.
+///
+/// Returns true if the uses of at least one block argument were replaced.
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+  // Nothing to unify with fewer than two incoming edges.
+  if (dest->hasNoPredecessors() ||
+      llvm::hasSingleElement(dest->getPredecessors()))
+    return false;
+
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    const unsigned argIdx = arg.getArgNumber();
+    Value commonValue;
+    bool allSame = true;
+    // The predecessor iterator yields one entry per incoming edge and knows
+    // which successor slot of the terminator that edge corresponds to, so a
+    // terminator targeting `dest` several times is checked once per edge.
+    for (auto it = dest->pred_begin(), e = dest->pred_end(); it != e; ++it) {
+      auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
+      if (!branch) {
+        // Unknown terminator: the forwarded values cannot be inspected.
+        allSame = false;
+        break;
+      }
+
+      SuccessorOperands succOps =
+          branch.getSuccessorOperands(it.getSuccessorIndex());
+      // A missing or null operand (presumably one produced by the terminator
+      // itself rather than forwarded) cannot take part in the unification.
+      Value incoming = argIdx < succOps.size() ? succOps[argIdx] : Value();
+      if (!incoming || (commonValue && commonValue != incoming)) {
+        allSame = false;
+        break;
+      }
+      commonValue = incoming;
+    }
+
+    // Skip the degenerate case where the argument is only fed by itself.
+    if (allSame && commonValue && commonValue != arg) {
+      rewriter.replaceAllUsesWith(arg, commonValue);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+namespace {
+/// Canonicalization pattern for `cf.cond_br`: for each successor block of the
+/// branch, block arguments that receive one uniform value from all of the
+/// block's predecessors have their uses replaced by that value.
+struct SimplifyCondBranchBlockArgWithUniformIncomingValues
+    : public OpRewritePattern<CondBranchOp> {
+  using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CondBranchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Visit every successor unconditionally (no short-circuiting) so a single
+    // application of the pattern simplifies all destinations.
+    bool anySimplified = false;
+    for (Block *successor : op->getSuccessors()) {
+      if (simplifyUniformBlockArgs(successor, rewriter))
+        anySimplified = true;
+    }
+    return success(anySimplified);
+  }
+};
+} // namespace
+
// Tries the branch simplifications in order; `||` short-circuits, so a later
// simplification only runs when all earlier ones did not apply.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
- succeeded(simplifyPassThroughBr(op, rewriter)));
+ succeeded(simplifyPassThroughBr(op, rewriter)) ||
+ simplifyUniformBlockArgs(op.getDest(), rewriter));
}
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +568,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch,
+ SimplifyCondBranchBlockArgWithUniformIncomingValues>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..5bcd76badea59 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -131,8 +131,7 @@ func.func @cond_br_passthrough_weights(%arg0 : i32, %arg1 : i32, %cond : i1) ->
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[COND:.*]]: i1
func.func @cond_br_passthrough(%arg0 : i32, %arg1 : i32, %arg2 : i32, %cond : i1) -> (i32, i32) {
// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[ARG0]], %[[ARG2]]
- // CHECK: %[[RES2:.*]] = arith.select %[[COND]], %[[ARG1]], %[[ARG2]]
- // CHECK: return %[[RES]], %[[RES2]]
+ // CHECK: return %[[RES]], %[[ARG1]]
cf.cond_br %cond, ^bb1(%arg0 : i32), ^bb2(%arg2, %arg2 : i32, i32)
@@ -686,3 +685,38 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
^exit(%result: i1):
return %result : i1
}
+
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%c : i32) // first predecessor of ^bb3 forwards %c
+^bb2:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%c : i32) // second predecessor forwards the same %c
+^bb3(%arg0: i32): // %arg0 is uniform, so it is expected to fold to %c
+ // CHECK: ^bb3:
+ // CHECK: return %[[C]]
+ return %arg0 : i32
+}
+
+// Verify that a block argument is kept (not folded) when the values forwarded
+// by the predecessors differ.
+
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%a : i32) // this predecessor forwards %a
+^bb2:
+ "foo.op"() : () -> ()
+ cf.br ^bb3(%b : i32) // this predecessor forwards %b, so ^bb3's argument is not uniform
+^bb3(%arg0: i32):
+ // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+ // CHECK-NEXT: return %[[ARG0]]
+ return %arg0 : i32
+}
|
joker-eph
reviewed
Feb 28, 2026
joker-eph
reviewed
Feb 28, 2026
joker-eph
reviewed
Feb 28, 2026
824c76c to
5b59fa2
Compare
5b59fa2 to
84d3628
Compare
Member
matthias-springer
left a comment
There was a problem hiding this comment.
Looks good, just some minor comments.
84d3628 to
d317cd4
Compare
d317cd4 to
6bc3398
Compare
matthias-springer
approved these changes
Mar 4, 2026
|
|
||
| // CHECK-LABEL: func @no_fold_same_dest_different_args | ||
| func.func @no_fold_same_dest_different_args(%a: i32, %b: i32) -> i32 { | ||
| "test.producing_br"(%a, %b)[^bb1, ^bb1] |
Member
There was a problem hiding this comment.
Can you also add a test case where the folding actually takes place? "test.producing_br"(%a, %a)[^bb1, ^bb1]
Add a canonicalization pattern that replaces block arguments with a common SSA value when all predecessors pass the same value for that argument. This allows the block argument to be removed by dead code elimination. Ref llvm#182711
6bc3398 to
a2a391b
Compare
sahas3
pushed a commit
to sahas3/llvm-project
that referenced
this pull request
Mar 4, 2026
…#183966) Add a canonicalization pattern that replaces block arguments with a common SSA value when all predecessors pass the same value for that argument. This allows the block argument to be removed by dead code elimination. First iteration. Idea from llvm#182711
sujianIBM
pushed a commit
to sujianIBM/llvm-project
that referenced
this pull request
Mar 5, 2026
…#183966) Add a canonicalization pattern that replaces block arguments with a common SSA value when all predecessors pass the same value for that argument. This allows the block argument to be removed by dead code elimination. First iteration. Idea from llvm#182711
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add a canonicalization pattern that replaces block arguments with a
common SSA value when all predecessors pass the same value for that
argument. This allows the block argument to be removed by dead code
elimination. First iteration.
Idea from #182711