-
Notifications
You must be signed in to change notification settings - Fork 16k
[mlir][Transforms] remove-dead-values: Rely on canonicalizer for region simplification
#173505
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
88c7c3a to
36f2c32
Compare
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
36f2c32 to
552b819
Compare
63d57b9 to
db70228
Compare
7e3d572 to
7937621
Compare
remove-dead-valuesremove-dead-values: Rely on canonicalizer for region simplification
7937621 to
0ba7790
Compare
|
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit simplifies the
Region-based ops are difficult. The liveness analysis may determine that an SSA value is dead. However, that does not mean that the value can actually be removed. Doing so may violate an region data flow (as modeled by the Before this commit, there used to be complex logic to determine when it is safe to erase an SSA value. That logic was broken. The new implementation does not remove any block arguments or op results of region-based ops. Instead, operands of region-based ops and region branch terminators are replaced with Depends on #173560. Patch is 38.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/173505.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..a347f335c9c1e 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
struct OperandsToCleanup {
Operation *op;
BitVector nonLive;
- Operation *callee =
- nullptr; // Optional: For CallOpInterface ops, stores the callee function
+ // Optional: For CallOpInterface ops, stores the callee function.
+ Operation *callee = nullptr;
+ // Determines whether the operand should be replaced with a ub.poison result
+ // or erased entirely.
+ bool replaceWithPoison = false;
};
struct BlockArgsToCleanup {
@@ -199,9 +202,9 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
}
}
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+/// Erase the i-th result of `op` iff toErase[i] is 1.
+static void eraseResults(RewriterBase &rewriter, Operation *op,
+ BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
@@ -210,7 +213,6 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
for (OpResult result : op->getResults())
if (!toErase[result.getResultNumber()])
newResultTypes.push_back(result.getType());
- IRRewriter rewriter(op);
rewriter.setInsertionPointAfter(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
@@ -226,14 +228,12 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
unsigned indexOfNextNewCallOpResultToReplace = 0;
for (auto [index, result] : llvm::enumerate(op->getResults())) {
assert(result && "expected result to be non-null");
- if (toErase[index]) {
- result.dropAllUses();
- } else {
+ if (!toErase[index]) {
result.replaceAllUsesWith(
newOp->getResult(indexOfNextNewCallOpResultToReplace++));
}
}
- op->erase();
+ rewriter.eraseOp(op);
}
/// Convert a list of `Operand`s to a list of `OpOperand`s.
@@ -404,30 +404,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
///
/// Scenario 1: If the operation has no memory effects and none of its results
/// are live:
-/// (1') Enqueue all its uses for deletion.
-/// (2') Enqueue the branch itself for deletion.
+/// 1.1. Enqueue all its uses for deletion.
+/// 1.2. Enqueue the branch itself for deletion.
///
/// Scenario 2: Otherwise:
-/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
-/// results or arguments).
-/// (2) Process each of its regions.
-/// (3) Collect the uses of its unnecessary results (results forwarded from
-/// unnecessary operands
-/// or terminator operands).
-/// (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-/// from unnecessary operands
-/// or terminator operands).
-/// (b) Collecting these unnecessary arguments.
-/// (c) Collecting its unnecessary terminator operands (terminator operands
-/// forwarded to unnecessary results
-/// or arguments).
+/// 2.1. Collect block arguments and op results that we would like to keep,
+/// based on their liveness.
+/// 2.2. Find all operands that are forwarded to only dead region successor
+/// inputs. I.e., forwarded to block arguments / op results that we do
+/// not want to keep.
+/// 2.3. Enqueue all such operands for replacement with ub.poison.
///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
@@ -441,284 +431,103 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
// It could never be live because of this op but its liveness could have been
// attributed to something else.
- // Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
- // Mark live results of `regionBranchOp` in `liveResults`.
- auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
- };
-
- // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
- auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- liveArgs[®ion] = regionLiveArgs;
+ // Compute values that are alive.
+ DenseSet<Value> valuesToKeep;
+ for (Value result : regionBranchOp->getResults()) {
+ if (hasLive(result, nonLiveSet, la))
+ valuesToKeep.insert(result);
+ }
+ for (Region ®ion : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ for (Value arg : region.front().getArguments()) {
+ if (hasLive(arg, nonLiveSet, la))
+ valuesToKeep.insert(arg);
}
- };
+ }
- // Return the successors of `region` if the latter is not null. Else return
- // the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
+ // Mapping from operands to forwarded successor inputs. An operand can be
+ // forwarded to multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ auto helper = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
- return successors;
- };
-
- // Return the operands of `terminator` that are forwarded to `successor` if
- // the former is not null. Else return the operands of `regionBranchOp`
- // forwarded to `successor`.
- auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
- Operation *terminator = nullptr) {
- OperandRange operands =
- terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
- return opOperands;
- };
-
- // Mark the non-forwarded operands of `regionBranchOp` in
- // `nonForwardedOperands`.
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
- nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
- nonForwardedOperands.reset(opOperand->getOperandNumber());
+ for (const RegionSuccessor &successor : successors) {
+ // Handle branch from point --> successor.
+ ValueRange argsOrResults = successor.getSuccessorInputs();
+ OperandRange operands =
+ point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor)
+ : cast<RegionBranchTerminatorOpInterface>(
+ point.getTerminatorPredecessorOrNull())
+ .getSuccessorOperands(successor);
+ assert(
+ argsOrResults.size() == operands.size() &&
+ "expected the same number of successor inputs as forwarded operands");
+
+ for (auto [opOperand, input] :
+ llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) {
+ operandToSuccessorInputs[opOperand].push_back(input);
+ }
}
};
- // Mark the non-forwarded terminator operands of the various regions of
- // `regionBranchOp` in `nonForwardedRets`.
- auto markNonForwardedReturnValues =
- [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
- nonForwardedRets[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator))
- nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
- }
- }
- };
-
- // Update `valuesToKeep` (which is expected to correspond to operands or
- // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
- // `region`. When `valuesToKeep` correspond to operands, `region` is null.
- // Else, `region` is the parent region of the terminator.
- auto updateOperandsOrTerminatorOperandsToKeep =
- [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
- DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
- Operation *terminator =
- region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
-
- for (const RegionSuccessor &successor : getSuccessors(point)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- size_t operandNum = opOperand->getOperandNumber();
- bool updateBasedOn =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
- }
- }
- };
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
- // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
- // value is modified, else, false.
- auto recomputeResultsAndArgsToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
- bool &resultsOrArgsToKeepChanged) {
- resultsOrArgsToKeepChanged = false;
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- operandsToKeep[opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
-
- // Recompute `resultsToKeep` and `argsToKeep` based on
- // `terminatorOperandsToKeep`.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- terminatorOperandsToKeep[region.back().getTerminator()]
- [opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
- }
- };
-
- // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
- // `operandsToKeep`, and `terminatorOperandsToKeep`.
- auto markValuesToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
- bool resultsOrArgsToKeepChanged = true;
- // We keep updating and recomputing the values until we reach a point
- // where they stop changing.
- while (resultsOrArgsToKeepChanged) {
- // Update the operands that need to be kept.
- updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
- resultsToKeep, argsToKeep);
-
- // Update the terminator operands that need to be kept.
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- updateOperandsOrTerminatorOperandsToKeep(
- terminatorOperandsToKeep[region.back().getTerminator()],
- resultsToKeep, argsToKeep, ®ion);
- }
-
- // Recompute the results and arguments that need to be kept.
- recomputeResultsAndArgsToKeep(
- resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
- }
- };
-
- // Scenario 2.
- // At this point, we know that every non-forwarded operand of `regionBranchOp`
- // is live.
-
- // Stores the results of `regionBranchOp` that we want to keep.
- BitVector resultsToKeep;
- // Stores the mapping from regions of `regionBranchOp` to their arguments that
- // we want to keep.
- DenseMap<Region *, BitVector> argsToKeep;
- // Stores the operands of `regionBranchOp` that we want to keep.
- BitVector operandsToKeep;
- // Stores the mapping from region terminators in `regionBranchOp` to their
- // operands that we want to keep.
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
- // Initializing the above variables...
-
- // The live results of `regionBranchOp` definitely need to be kept.
- markLiveResults(resultsToKeep);
- // Similarly, the live arguments of the regions in `regionBranchOp` definitely
- // need to be kept.
- markLiveArgs(argsToKeep);
- // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
- // A live forwarded operand can be removed but no non-forwarded operand can be
- // removed since it "controls" the flow of data in this control flow op.
- markNonForwardedOperands(operandsToKeep);
- // Similarly, the non-forwarded terminator operands of the regions in
- // `regionBranchOp` definitely need to be kept.
- markNonForwardedReturnValues(terminatorOperandsToKeep);
-
- // Mark the values (results, arguments, operands, and terminator operands)
- // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep);
-
- // Do (1).
- cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
- // Do (2.a) and (2.b).
+ // Example:
+ //
+ // %0 = scf.while : () -> i32 {
+ // scf.condition(...) %forwarded_value : i32
+ // } do {
+ // ^bb0(%arg0: i32):
+ // scf.yield
+ // }
+ // // No uses of %0.
+ //
+ // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+ // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+ // ub.poison result.
+ //
+ // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+ //
+ helper(RegionBranchPoint::parent());
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- BitVector argsToRemove = argsToKeep[®ion].flip();
- cl.blocks.push_back({®ion.front(), argsToRemove});
- collectNonLiveValues(nonLiveSet, region.front().getArguments(),
- argsToRemove);
+ helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator())));
}
- // Do (2.c).
- for (Region ®ion : regionBranchOp->getRegions()) {
- if (region.empty())
+ DenseMap<Operation *, BitVector> deadOperandsPerOp;
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+ return valuesToKeep.contains(input);
+ });
+ if (anyAlive)
continue;
- Operation *terminator = region.front().getTerminator();
- cl.operands.push_back(
- {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+ // All successor inputs are dead: ub.poison can be passed as operand.
+ // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+ // no "dead" op operands) if it's the first time that we are seeing an op
+ // operand for this op. Otherwise, just take the existing bit vector from
+ // the map.
+ BitVector &deadOperands =
+ deadOperandsPerOp
+ .try_emplace(opOperand->getOwner(),
+ opOperand->getOwner()->getNumOperands(), false)
+ .first->second;
+ deadOperands.set(opOperand->getOperandNumber());
}
- // Do (3) and (4).
- BitVector resultsToRemove = ...
[truncated]
|
529a159 to
da738b4
Compare
0ba7790 to
19e6fab
Compare
9bef674 to
cd480a2
Compare
19e6fab to
ed836e5
Compare
ed836e5 to
03266d5
Compare
f4bd3fb to
7fef1b4
Compare
777c729 to
3b73a7e
Compare
3b73a7e to
bde143c
Compare
bde143c to
8c30a79
Compare
…174208) `remove-dead-values` performs various cleanups: 1. Erasing block arguments 2. Erasing successor operands 3. Erasing operations 4. Erasing function arguments / results 5. Erasing operands 6. Erasing results This commit moves Step 3 (erasing operations) to the end. While that does not fix any bugs by itself, it is potentially safer. If an operation is erased, we must be careful that the operation is not accessed in the following steps. That can no longer happen if IR is erased only in the final step and not before. This commit is prefetching a change from #173505 (to keep that PR shorter). With #173505, it will become necessary to erase IR in the final step.
05aa443 to
a6c7e40
Compare
|
I was able to simplify this PR a bit further. The changes are now mostly in |
a6c7e40 to
a56e68f
Compare
…lvm#174208) `remove-dead-values` performs various cleanups: 1. Erasing block arguments 2. Erasing successor operands 3. Erasing operations 4. Erasing function arguments / results 5. Erasing operands 6. Erasing results This commit moves Step 3 (erasing operations) to the end. While that does not fix any bugs by itself, it is potentially safer. If an operation is erased, we must be careful that the operation is not accessed in the following steps. That can no longer happen if IR is erased only in the final step and not before. This commit is prefetching a change from llvm#173505 (to keep that PR shorter). With llvm#173505, it will become necessary to erase IR in the final step.
|
Are there any other changes that you would like me to make? This PR is ready to be merged from my perspective. |
simple test working draft: do not erase IR, just replace uses
a56e68f to
d206fb1
Compare
…ops (#176712) Collect canonicalization patterns from the region branch ops (instead of populating all canonicalization patterns). Addresses a [comment](#173505 (comment)) on a merged PR.
…terns from ops (#176712) Collect canonicalization patterns from the region branch ops (instead of populating all canonicalization patterns). Addresses a [comment](llvm/llvm-project#173505 (comment)) on a merged PR.
…ops (llvm#176712) Collect canonicalization patterns from the region branch ops (instead of populating all canonicalization patterns). Addresses a [comment](llvm#173505 (comment)) on a merged PR.
This commit simplifies the
remove-dead-valuespass and fixes a bug in the handling ofRegionBranchOpInterfaceops. The pass used to produce invalid IR ("null value found") for the newly added test case.remove-dead-valuesis a pass for additional IR simplification that cannot be performed by the canonicalizer pass. Based on a liveness analysis, it erases dead values / IR. (The liveness analysis is a dataflow analysis that has more information about the IR than a canonicalization pattern, which can see only "local" information.)Region-based ops are difficult. The liveness analysis may determine that an SSA value is dead. However, that does not mean that the value can actually be removed. Doing so may violate an region data flow (as modeled by the
RegionBranchOpInterface). As an example, consider the case where a region branch terminator may dispatch to one of two region successor with the same forwarded values. A successor input (block argument) can be erased only if it is dead on both successors.Before this commit, there used to be complex logic to determine when it is safe to erase an SSA value. That logic was broken. The new implementation does not remove any block arguments or op results of region-based ops. Instead, operands of region-based ops and region branch terminators are replaced with
ub.poisonif all of their successor values are dead. This simplifies the IR good enough for the canonicalizer to perform the remaining region simplification (i.e., dropping block arguments etc.).RFC: https://discourse.llvm.org/t/rfc-delegate-simplification-of-region-based-ops-from-remove-dead-values-to-canonicalizer/89194