-
Notifications
You must be signed in to change notification settings - Fork 16k
[MLIR] Fix crash in RemoveDeadValues when terminator lacks RegionBranchTerminatorOpInterface #175300
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…chTerminatorOpInterface The RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes llvm#174502
|
@llvm/pr-subscribers-mlir-core Author: None (nataliakokoromyti) ChangesThe RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes #174502 Full diff: https://github.com/llvm/llvm-project/pull/175300.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fc2c2acf8afd3..4d20abb415229 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -477,8 +477,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
@@ -498,11 +500,17 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
+ auto terminatorIface =
terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
+ ? dyn_cast<RegionBranchTerminatorOpInterface>(terminator)
+ : nullptr;
+ // If terminator doesn't implement RegionBranchTerminatorOpInterface,
+ // we can't analyze it, so skip.
+ if (terminator && !terminatorIface)
+ return;
+ RegionBranchPoint point =
+ terminatorIface ? RegionBranchPoint(terminatorIface)
+ : RegionBranchPoint::parent();
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
@@ -566,8 +574,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
|
|
@llvm/pr-subscribers-mlir Author: None (nataliakokoromyti) ChangesThe RemoveDeadValues pass was using cast<RegionBranchTerminatorOpInterface> which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes. This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface. Fixes #174502 Full diff: https://github.com/llvm/llvm-project/pull/175300.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index fc2c2acf8afd3..4d20abb415229 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -477,8 +477,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
if (region.empty())
continue;
// TODO: this isn't correct in face of multiple terminators.
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
nonForwardedRets[terminator] =
BitVector(terminator->getNumOperands(), true);
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
@@ -498,11 +500,17 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
Operation *terminator =
region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
+ auto terminatorIface =
terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
+ ? dyn_cast<RegionBranchTerminatorOpInterface>(terminator)
+ : nullptr;
+ // If terminator doesn't implement RegionBranchTerminatorOpInterface,
+ // we can't analyze it, so skip.
+ if (terminator && !terminatorIface)
+ return;
+ RegionBranchPoint point =
+ terminatorIface ? RegionBranchPoint(terminatorIface)
+ : RegionBranchPoint::parent();
for (const RegionSuccessor &successor : getSuccessors(point)) {
Region *successorRegion = successor.getSuccessor();
@@ -566,8 +574,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
for (Region ®ion : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- auto terminator = cast<RegionBranchTerminatorOpInterface>(
+ auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(
region.front().getTerminator());
+ if (!terminator)
+ continue;
for (const RegionSuccessor &successor : getSuccessors(terminator)) {
Region *successorRegion = successor.getSuccessor();
for (auto [opOperand, input] :
|
The RemoveDeadValues pass was using cast which asserts that the terminator implements this interface. However, some dialects (like CIR) have region terminators that don't implement this interface, causing crashes.
This patch changes cast to dyn_cast and skips processing for terminators that don't implement RegionBranchTerminatorOpInterface.
Fixes #174502