diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 957621c16bf2b..2610fa3d1cffe 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -517,6 +517,16 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, // Do (2) BitVector successorNonLive = markLives(operandValues, nonLiveSet, la).flip(); + // A block argument should not be considered dead if the liveness analysis + // determines it is live. This can happen when a branch is in a statically + // unreachable (dead) block: the forwarded operand appears dead because it + // is in the dead block, but the successor block argument may still be live + // because it is also forwarded from other live predecessor branches. + for (auto [index, blockArg] : + llvm::enumerate(successorBlock->getArguments())) { + if (successorNonLive[index] && hasLive({blockArg}, nonLiveSet, la)) + successorNonLive.reset(index); + } collectNonLiveValues(nonLiveSet, successorBlock->getArguments(), successorNonLive); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 19bc6b2fddd66..343f7616f73fb 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -847,3 +847,33 @@ module @func_with_non_call_users { } spirv.EntryPoint "GLCompute" @callee } + + +// ----- + +// Verify that RemoveDeadValues does not crash when a branch in a statically +// dead block forwards a dead value to a block argument that is live (because +// another, reachable predecessor also forwards a value to it). +// The constant comparison %cmp is always true (2 < 5), so ^bb2 is unreachable. +// The block argument %arg0 in ^bb3 must not be incorrectly marked dead. +// CHECK-LABEL: func.func @branch_in_dead_block_live_successor_arg +func.func @branch_in_dead_block_live_successor_arg() -> i64 { + %c0 = arith.constant 0 : i64 + %c1 = arith.constant 1 : i64 + %sum = arith.addi %c0, %c1 : i64 + %c2 = arith.constant 2 : i64 + %mul = arith.muli %sum, %c2 : i64 + %c5 = arith.constant 5 : i64 + %cmp = arith.cmpi slt, %mul, %c5 : i64 + cf.cond_br %cmp, ^bb1, ^bb2 +^bb1: + %c10 = arith.constant 10 : i64 + cf.br ^bb3(%c10 : i64) +^bb2: + %c20 = arith.constant 20 : i64 + cf.br ^bb3(%c20 : i64) +^bb3(%arg0: i64): + // CHECK: arith.addi + %result = arith.addi %arg0, %sum : i64 + func.return %result : i64 +}