Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions mlir/lib/Transforms/RemoveDeadValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,8 +477,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
auto terminator = cast<RegionBranchTerminatorOpInterface>(
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
if (!terminator)
continue;
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Expand All @@ -498,11 +500,17 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
RegionBranchPoint point =
auto terminatorIface =
terminator
? RegionBranchPoint(
cast<RegionBranchTerminatorOpInterface>(terminator))
: RegionBranchPoint::parent();
? dyn_cast<RegionBranchTerminatorOpInterface>(terminator)
: nullptr;
// If terminator doesn't implement RegionBranchTerminatorOpInterface,
// we can't analyze it, so skip.
if (terminator && !terminatorIface)
return;
RegionBranchPoint point =
terminatorIface ? RegionBranchPoint(terminatorIface)
: RegionBranchPoint::parent();

for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
Expand Down Expand Up @@ -566,8 +574,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
auto terminator = cast<RegionBranchTerminatorOpInterface>(
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
if (!terminator)
continue;
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
Expand Down