diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 0ce0d55f4397c..435c37bc95aac 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination".
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -204,9 +204,91 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
   return success();
 }
 
+/// If all incoming values for a block argument from all predecessors are the
+/// same SSA value, replace uses of the block argument with that value. This
+/// allows the block argument to be removed by dead code elimination.
+///
+///   %c = arith.constant 0 : i32
+///   cf.br ^bb1(%c : i32) // pred 1
+///   cf.br ^bb1(%c : i32) // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+///   use(%c) // %arg0 has no uses and can be removed
+///
+/// An argument is left untouched if any edge into `dest` comes from a
+/// terminator that does not implement BranchOpInterface, produces the operand
+/// itself, or forwards a value whose type differs from the argument's type.
+/// Returns success if at least one argument had its uses replaced.
+static LogicalResult simplifyUniformBlockArgs(Block *dest,
+                                              PatternRewriter &rewriter) {
+  if (dest->hasNoPredecessors() ||
+      llvm::hasSingleElement(dest->getPredecessors()))
+    return failure();
+
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    Value commonValue;
+    for (Block *pred : dest->getPredecessors()) {
+      auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
+      if (!branch) {
+        commonValue = Value();
+        break;
+      }
+
+      for (auto [i, succ] : llvm::enumerate(branch->getSuccessors())) {
+        if (succ != dest)
+          continue;
+
+        // Produced operands are modeled by BranchOpInterface as null Values.
+        // Such an edge has no SSA value to forward, so it must veto the
+        // rewrite even when it is the first edge seen. A forwarded value of a
+        // different type cannot stand in for the argument either.
+        Value val = branch.getSuccessorOperands(i)[arg.getArgNumber()];
+        if (!val || val.getType() != arg.getType() ||
+            (commonValue && commonValue != val)) {
+          commonValue = Value();
+          break;
+        }
+        commonValue = val;
+      }
+
+      if (!commonValue)
+        break;
+    }
+
+    if (commonValue && commonValue != arg) {
+      rewriter.replaceAllUsesWith(arg, commonValue);
+      changed = true;
+    }
+  }
+  return success(changed);
+}
+
+namespace {
+/// Replaces block arguments with a uniform incoming value across all
+/// predecessors, for any op implementing BranchOpInterface.
+struct SimplifyUniformBlockArguments
+    : public OpInterfaceRewritePattern<BranchOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(BranchOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (Block *succ : op->getSuccessors())
+      changed |= succeeded(simplifyUniformBlockArgs(succ, rewriter));
+    return success(changed);
+  }
+};
+} // namespace
+
 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
-                 succeeded(simplifyPassThroughBr(op, rewriter)));
+                 succeeded(simplifyPassThroughBr(op, rewriter)) ||
+                 succeeded(simplifyUniformBlockArgs(op.getDest(), rewriter)));
 }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -492,7 +574,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch,
+              SimplifyUniformBlockArguments>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
@@ -955,7 +1038,8 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
       .add(&simplifyConstSwitchValue)
       .add(&simplifyPassThroughSwitch)
       .add(&simplifySwitchFromSwitchOnSameCondition)
-      .add(&simplifySwitchFromDefaultSwitchOnSameCondition);
+      .add(&simplifySwitchFromDefaultSwitchOnSameCondition)
+      .add<SimplifyUniformBlockArguments>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 8ddfeb7b0841c..c8cf5931a1297 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -686,3 +686,102 @@ func.func @no_merge_self_arg_loop(%step: i1) -> i1 {
 ^exit(%result: i1):
   return %result : i1
 }
+
+// Verify that block arguments are replaced with a uniform incoming value
+// when all predecessors pass the same SSA value.
+
+// CHECK-LABEL: func @fold_uniform_branch_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[C:.*]]: i32
+func.func @fold_uniform_branch_block_arg(%cond: i1, %c: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%c : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3:
+  // CHECK: return %[[C]]
+  return %arg0 : i32
+}
+
+// Verify that block arguments are not folded when incoming values differ
+// across predecessors.
+
+// CHECK-LABEL: func @no_fold_non_uniform_block_arg
+// CHECK-SAME: %[[COND:.*]]: i1, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @no_fold_non_uniform_block_arg(%cond: i1, %a: i32, %b: i32) -> i32 {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%a : i32)
+^bb2:
+  "foo.op"() : () -> ()
+  cf.br ^bb3(%b : i32)
+^bb3(%arg0: i32):
+  // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify no folding when the same block appears multiple times as a
+// successor with different operands.
+
+// CHECK-LABEL: func @no_fold_same_dest_different_args
+func.func @no_fold_same_dest_different_args(%a: i32, %b: i32) -> i32 {
+  "test.producing_br"(%a, %b)[^bb1, ^bb1]
+    {operandSegmentSizes = array<i32: 1, 1>} : (i32, i32) -> i32
+^bb1(%arg0: i32):
+  // CHECK: ^bb1(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify folding when the same block appears multiple times as a
+// successor with the same operand.
+
+// CHECK-LABEL: func @fold_same_dest_same_args
+// CHECK-SAME: %[[A:.*]]: i32
+func.func @fold_same_dest_same_args(%a: i32) -> i32 {
+  "test.producing_br"(%a, %a)[^bb1, ^bb1]
+    {operandSegmentSizes = array<i32: 1, 1>} : (i32, i32) -> i32
+^bb1(%arg0: i32):
+  // CHECK: ^bb1:
+  // CHECK-NEXT: return %[[A]]
+  return %arg0 : i32
+}
+
+// Verify no folding when a predecessor has an unknown terminator.
+
+// CHECK-LABEL: func @no_fold_unknown_terminator
+func.func @no_fold_unknown_terminator(%a: i32) -> i32 {
+  cf.br ^bb1
+^bb1:
+  "foo.two_successors"()[^bb2, ^bb3] : () -> ()
+^bb2:
+  // CHECK: ^bb2(%[[ARG0:.*]]: i32):
+  cf.br ^bb3(%a : i32)
+^bb3(%arg0: i32):
+  // CHECK-NEXT: return %[[ARG0]]
+  return %arg0 : i32
+}
+
+// Verify that unused block arguments are skipped and only used arguments
+// with uniform incoming values are folded.
+
+// CHECK-LABEL: func @skip_unused_block_arg
+// CHECK-SAME: %{{.*}}: i32, %[[A:.*]]: i32, %[[B:.*]]: i32
+func.func @skip_unused_block_arg(%flag: i32, %a: i32, %b: i32) -> i32 {
+  "foo.pred"()[^bb1, ^bb2] : () -> ()
+^bb1:
+  cf.br ^bb3(%b, %a : i32, i32)
+^bb2:
+  cf.br ^bb3(%a, %a : i32, i32)
+^bb3(%arg0: i32, %arg1: i32):
+  "foo.use"(%arg0) : (i32) -> ()
+  // CHECK: ^bb3(%[[ARG0:.*]]: i32):
+  // CHECK-NEXT: "foo.use"(%[[ARG0]])
+  // CHECK-NEXT: return %[[A]]
+  return %arg1 : i32
+}