Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 31 additions & 102 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
// terminator operands of region branch ops, and,
// (D) Removes simple and region branch ops that have all non-live results and
// don't affect memory in any way,
// (E) Replaces dead operands of branch ops with `ub.poison`, relying on the
// canonicalizer to remove the corresponding block arguments.
//
// iff
//
Expand Down Expand Up @@ -101,24 +103,11 @@ struct OperandsToCleanup {
bool replaceWithPoison = false;
};

/// A block paired with the set of its arguments that the final cleanup should
/// erase.
struct BlockArgsToCleanup {
// Block whose arguments are to be erased.
Block *b;
// One bit per block argument; a set bit marks the argument at that index for
// erasure. The cleanup skips an entry whose bit count no longer matches the
// block's current number of arguments.
BitVector nonLiveArgs;
};

/// Identifies one successor of a branch op together with the successor
/// operands that the final cleanup should erase.
struct SuccessorOperandsToCleanup {
// Branch op that owns the successor operands.
BranchOpInterface branch;
// Index of the successor, as passed to `getSuccessorOperands`.
unsigned successorIndex;
// One bit per successor operand; a set bit marks the operand at that index
// for erasure. The cleanup skips an entry whose bit count no longer matches
// the current number of successor operands.
BitVector nonLiveOperands;
};

/// Work lists filled in while processing ops and consumed afterwards by
/// `cleanUpDeadVals`, which handles each list in its own numbered step.
struct RDVFinalCleanupList {
// Operations to erase.
SmallVector<Operation *> operations;
// Functions to clean up (e.g. their non-live results are erased).
SmallVector<FunctionToCleanUp> functions;
// Per-operation operand lists to clean up.
SmallVector<OperandsToCleanup> operands;
// Per-operation result lists whose non-live results are dropped and erased.
SmallVector<ResultsToCleanup> results;
// Blocks with non-live arguments to erase.
SmallVector<BlockArgsToCleanup> blocks;
// Branch successors with non-live forwarded operands to erase.
SmallVector<SuccessorOperandsToCleanup> successorOperands;
};

// Some helper functions...
Expand Down Expand Up @@ -476,11 +465,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
///
/// Otherwise, iterate through each successor block of `branchOp`.
/// (1) For each successor block, gather all operands from all successors.
/// (2) Fetch their associated liveness analysis data and collect for future
/// removal.
/// (3) Identify and collect the dead operands from the successor block
/// as well as their corresponding arguments.

/// (2) Determine which operands are dead using liveness analysis.
/// (3) Replace dead successor operands with ub.poison instead of erasing them.
/// Block arguments are left intact — the canonicalizer will remove them
/// once it sees all incoming operands are poison.
static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
Expand All @@ -490,7 +478,8 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
BitVector deadNonForwardedOperands =
markLives(branchOp->getOperands(), nonLiveSet, la).flip();
unsigned numSuccessors = branchOp->getNumSuccessors();
for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {

for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
SuccessorOperands successorOperands =
branchOp.getSuccessorOperands(succIdx);
// Remove all non-forwarded operands from the bit vector.
Expand All @@ -502,28 +491,20 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
return;
}

for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
Block *successorBlock = branchOp->getSuccessor(succIdx);

// Do (1)
SuccessorOperands successorOperands =
branchOp.getSuccessorOperands(succIdx);
SmallVector<Value> operandValues;
for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
++operandIdx) {
operandValues.push_back(successorOperands[operandIdx]);
// For each successor, find dead forwarded operands and
// schedule them for replacement with ub.poison.
BitVector opNonLive(branchOp->getNumOperands(), false);
for (unsigned succIdx : llvm::seq<unsigned>(0, numSuccessors)) {
for (OpOperand &opOperand :
branchOp.getSuccessorOperands(succIdx).getMutableForwardedOperands()) {
if (!hasLive(opOperand.get(), nonLiveSet, la))
opNonLive.set(opOperand.getOperandNumber());
}

// Do (2)
BitVector successorNonLive =
markLives(operandValues, nonLiveSet, la).flip();
collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
successorNonLive);

// Do (3)
cl.blocks.push_back({successorBlock, successorNonLive});
cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
}

if (opNonLive.any())
cl.operands.push_back({branchOp.getOperation(), opNonLive,
/*callee=*/nullptr, /*replaceWithPoison=*/true});
}

/// Create ub.poison ops for the given values. If a value has no uses, return
Expand Down Expand Up @@ -564,56 +545,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
TrackingListener listener;
IRRewriter rewriter(ctx, &listener);

// 1. Blocks. We must remove the block arguments and successor operands before
// deleting the operation, as they may reside in the region operation.
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// Blocks that are reached via multiple code paths are processed only once.
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
LDBG_OS([&](raw_ostream &os) {
os << "Erasing non-live arguments [";
llvm::interleaveComma(b.nonLiveArgs.set_bits(), os);
os << "] from block #" << b.b->computeBlockNumber() << " in region #"
<< b.b->getParent()->getRegionNumber() << " of operation "
<< OpWithFlags(b.b->getParent()->getParentOp(),
OpPrintingFlags().skipRegions().printGenericOpForm());
});
// Note: Iterate from the end to make sure that the indices of not yet
// processed arguments do not change.
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}

// 2. Successor Operands
LDBG() << "Cleaning up " << list.successorOperands.size()
<< " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// Blocks that are reached via multiple code paths are processed only once.
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
LDBG_OS([&](raw_ostream &os) {
os << "Erasing non-live successor operands [";
llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
os << "] from successor " << op.successorIndex << " of branch: "
<< OpWithFlags(op.branch.getOperation(),
OpPrintingFlags().skipRegions().printGenericOpForm());
});
// Iterate backwards because erasing invalidates all successor indices.
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
successorOperands.erase(i);
}
}

// 3. Functions
// 1. Functions
LDBG() << "Cleaning up " << list.functions.size() << " functions";
// Record which function arguments were erased so we can shrink call-site
// argument segments for CallOpInterface operations (e.g. ops using
Expand Down Expand Up @@ -647,7 +579,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
(void)f.funcOp.eraseResults(f.nonLiveRets);
}

// 4. Operands
// 2. Operands
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperandsToCleanup &o : list.operands) {
// Handle call-specific cleanup only when we have a cached callee reference.
Expand Down Expand Up @@ -705,7 +637,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
}
}

// 5. Results
// 3. Results
LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
LDBG_OS([&](raw_ostream &os) {
Expand All @@ -718,7 +650,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
}

// 6. Operations
// 4. Operations
LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (Operation *op : list.operations) {
LDBG() << "Erasing operation: "
Expand Down Expand Up @@ -755,7 +687,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
rewriter.eraseOp(op);
}

// 7. Remove all dead poison ops.
// 5. Remove all dead poison ops.
for (ub::PoisonOp poisonOp : listener.poisonOps) {
if (poisonOp.use_empty())
poisonOp.erase();
Expand Down Expand Up @@ -808,20 +740,17 @@ void RemoveDeadValues::runOnOperation() {
if (!canonicalize)
return;

// Canonicalize all region branch ops.
SmallVector<Operation *> opsToCanonicalize;
module->walk([&](RegionBranchOpInterface regionBranchOp) {
opsToCanonicalize.push_back(regionBranchOp.getOperation());
});
// Collect all canonicalization patterns for region branch ops.
// Canonicalize all region branch ops and branch ops.
RewritePatternSet owningPatterns(context);
DenseSet<RegisteredOperationName> populatedPatterns;
for (Operation *op : opsToCanonicalize)
module->walk([&](Operation *op) {
if (!isa<RegionBranchOpInterface, BranchOpInterface>(op))
return;
if (std::optional<RegisteredOperationName> info = op->getRegisteredInfo())
if (populatedPatterns.insert(*info).second)
info->getCanonicalizationPatterns(owningPatterns, context);
if (failed(applyOpPatternsGreedily(opsToCanonicalize,
std::move(owningPatterns)))) {
});
if (failed(applyPatternsGreedily(module, std::move(owningPatterns)))) {
module->emitError("greedy pattern rewrite failed to converge");
signalPassFailure();
}
Expand Down
6 changes: 5 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/return-ops.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --remove-dead-values | FileCheck %s
// RUN: mlir-opt %s -remove-dead-values="canonicalize=0" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -remove-dead-values="canonicalize=1" -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE

// Make sure that the return value op is considered as a return-like op and
// remains live.
Expand All @@ -8,6 +9,9 @@
// CHECK-NEXT: %[[BITCAST0:.*]] = spirv.Bitcast %[[ARG1]] : vector<2xi32> to vector<2xf32>
// CHECK-NEXT: %[[BITCAST1:.*]] = spirv.Bitcast %[[BITCAST0]] : vector<2xf32> to vector<2xi32>
// CHECK-NEXT: spirv.ReturnValue %[[BITCAST1]] : vector<2xi32>
// CHECK-CANONICALIZE-LABEL: @preserve_return_value
// CHECK-CANONICALIZE-SAME: (%[[ARG0:.*]]: vector<2xi32>, %[[ARG1:.*]]: vector<2xi32>) -> vector<2xi32>
// CHECK-CANONICALIZE-NEXT: spirv.ReturnValue %[[ARG1]] : vector<2xi32>

spirv.func @preserve_return_value(%arg0: vector<2xi32>, %arg1: vector<2xi32>) -> vector<2xi32> "None" {
%0 = spirv.Bitcast %arg0 : vector<2xi32> to vector<2xf32>
Expand Down
80 changes: 61 additions & 19 deletions mlir/test/Transforms/remove-dead-values.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,24 @@ module @named_module_acceptable {
func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
%non_live = arith.constant 0 : i32
// CHECK-NOT: arith.constant
// CHECK-CANONICALIZE-NOT: arith.constant
cf.br ^bb1(%non_live : i32)
// CHECK: cf.br ^[[BB1:bb[0-9]+]]
// CHECK: cf.br ^[[BB1:bb[0-9]+]](%{{.*}} : i32)
// CHECK-CANONICALIZE: cf.br ^[[BB1:bb[0-9]+]]
^bb1(%non_live_1 : i32):
// CHECK: ^[[BB1]]:
// CHECK: ^[[BB1]](%{{.*}}: i32):
// CHECK-CANONICALIZE: ^[[BB1]]:
%non_live_5 = arith.constant 1 : i32
cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
// CHECK: cf.br ^[[BB3:bb[0-9]+]]
// CHECK-NOT: i32
// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}}, %{{.*}} : i32, i32)
// CHECK-CANONICALIZE: cf.cond_br %arg0, ^[[BB1]], ^[[BB2:bb[0-9]+]]
^bb3(%non_live_2 : i32, %non_live_6 : i32):
// CHECK: ^[[BB3]]:
// CHECK: ^[[BB3]](%{{.*}}: i32, %{{.*}}: i32):
cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
// CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
// CHECK: cf.cond_br %arg0, ^[[BB1]](%{{.*}} : i32), ^[[BB4:bb[0-9]+]](%{{.*}} : i32)
^bb4(%non_live_4 : i32):
// CHECK: ^[[BB4]]:
// CHECK: ^[[BB4]](%{{.*}}: i32):
// CHECK-CANONICALIZE: ^[[BB2]]:
return
}

Expand Down Expand Up @@ -345,9 +349,9 @@ func.func private @identity(%arg1 : i32) -> (i32) {
// Note that this cleanup cannot be done by the `canonicalize` pass.
//
// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
// CHECK-CANONICALIZE: %[[c10:.*]] = arith.constant 10
// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]]
// CHECK-CANONICALIZE-NEXT: case 1 {
// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10
// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][]
// CHECK-CANONICALIZE: scf.yield
// CHECK-CANONICALIZE-NEXT: }
Expand Down Expand Up @@ -476,6 +480,47 @@ func.func @kernel(%arg0: memref<18xf32>) {

// -----

// Test that RemoveDeadValues does not crash when gpu.launch appears in a block
// with multiple predecessors. The dead branch operand (%c20) must be replaced
// with ub.poison, gpu.launch and its grid/block size operands must be
// preserved, and the live block argument must remain intact.
//
// CHECK-LABEL: func.func @gpu_launch_in_multi_predecessor_block
// CHECK: arith.constant true
// CHECK: cf.cond_br
// CHECK: arith.constant 10
// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}} : i64)
// CHECK: ub.poison
// CHECK: cf.br ^[[BB3]](%{{.*}} : i64)
// CHECK: ^[[BB3]](%{{.*}}: i64):
// CHECK: return
// CHECK-NOT: arith.constant 20
//
// CHECK-CANONICALIZE-LABEL: func.func @gpu_launch_in_multi_predecessor_block
// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 : i64
// CHECK-CANONICALIZE-NEXT: return %[[c10]]
func.func @gpu_launch_in_multi_predecessor_block() -> i64 {
%cond = arith.constant true
cf.cond_br %cond, ^bb1, ^bb2
^bb1:
%c10 = arith.constant 10 : i64
cf.br ^bb3(%c10 : i64)
^bb2:
%c20 = arith.constant 20 : i64
cf.br ^bb3(%c20 : i64)
^bb3(%arg0: i64):
%c1 = arith.constant 1 : index
gpu.launch
blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
threads(%tx, %ty, %tz) in (%bsx = %c1, %bsy = %c1, %bsz = %c1) {
%blk_x = gpu.block_id x
%thr_x = gpu.thread_id x
gpu.terminator
}
func.return %arg0 : i64
}

// -----

// CHECK-LABEL: llvm_unreachable
// CHECK-LABEL: @fn_with_llvm_unreachable
Expand Down Expand Up @@ -768,14 +813,14 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
// CHECK: return %[[while]]#0

// CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
// CHECK-CANONICALIZE: %[[c5:.*]] = arith.constant 5 : i32
// CHECK-CANONICALIZE: %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
// CHECK-CANONICALIZE: vector.print %[[arg0]]
// CHECK-CANONICALIZE: %[[cmpi:.*]] = arith.cmpi
// CHECK-CANONICALIZE: scf.condition(%[[cmpi]]) %[[arg0]]
// CHECK-CANONICALIZE: } do {
// CHECK-CANONICALIZE: ^bb0(%[[arg1:.*]]: i32):
// CHECK-CANONICALIZE: %[[p0:.*]] = ub.poison : i32
// CHECK-CANONICALIZE: scf.yield %[[p0]]
// CHECK-CANONICALIZE: }
// CHECK-CANONICALIZE: return %[[while]]
Expand All @@ -799,7 +844,13 @@ func.func @scf_while_dead_iter_args() -> i32 {

// -----

// CHECK-LABEL: func.func @replace_dead_operation_results_with_poison
// Check that this prevents a crash in the canonicalization phase which
// happens after the dead value removal phase. Also check that only used
// results of an erased op are replaced with ub.poison.

// CHECK-CANONICALIZE-LABEL: func.func @replace_dead_operation_results_with_poison
// CHECK-CANONICALIZE-NEXT: %[[p:.*]] = ub.poison : vector<1xindex>
// CHECK-CANONICALIZE-NEXT: return %[[p]]
func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> vector<1xindex> {
%1 = scf.while (%arg0 = %0) : (vector<1xindex>) -> vector<1xindex> {
%cond = arith.constant true
Expand All @@ -809,15 +860,6 @@ func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> ve
scf.yield %arg0 : vector<1xindex>
}
%2 = scf.while (%arg0 = %1) : (vector<1xindex>) -> vector<1xindex> {
// Check that the binary value in condition is replaced with poison, and
// the condition itself is well-formed IR. This prevents a crash in the
// canonicalization phase which happens after the dead value removal phase.
// Also check that only used results of an erased op are replaced with ub.poison.
// CHECK-CANONICALIZE: %[[COND:.*]] = ub.poison : i1
// CHECK-CANONICALIZE-NEXT: %[[NEXT:.*]] = ub.poison : vector<1xindex>
// CHECK-CANONICALIZE-NEXT: scf.condition(%[[COND]]) %[[NEXT]]
// CHECK-CANONICALIZE-NOT: ub.poison : i32
// CHECK-CANONICALIZE-NOT: "test.three"
%cond, %unused, %next = "test.three"(%1) : (vector<1xindex>) -> (i1, i32, vector<1xindex>)
scf.condition(%cond) %next : vector<1xindex>
} do {
Expand Down