-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][mesh] Add null check for dyn_cast to prevent crash #149266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR adds a null check for dyn_cast to prevent crash, and use `isa` instead `dyn_cast` to make code clean.
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR adds a null check for dyn_cast result before use to prevent crash, and use Full diff: https://github.com/llvm/llvm-project/pull/149266.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
index 3f1041cb25103..243dbf081b999 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Simplifications.h
@@ -62,9 +62,11 @@ void populateAllReduceEndomorphismSimplificationPatterns(
auto isEndomorphismOp = [reduction](Operation *op,
std::optional<Operation *> referenceOp) {
auto allReduceOp = llvm::dyn_cast<AllReduceOp>(op);
+ if (!allReduceOp)
+ return false;
auto inType = cast<ShapedType>(allReduceOp.getInput().getType());
auto outType = cast<ShapedType>(allReduceOp.getResult().getType());
- if (!allReduceOp || inType.getElementType() != outType.getElementType() ||
+ if (inType.getElementType() != outType.getElementType() ||
allReduceOp.getReduction() != reduction) {
return false;
}
@@ -87,9 +89,7 @@ void populateAllReduceEndomorphismSimplificationPatterns(
return refAllReduceOp->getAttrs() == allReduceOp->getAttrs() &&
inType.getElementType() == refType.getElementType();
};
- auto isAlgebraicOp = [](Operation *op) {
- return static_cast<bool>(llvm::dyn_cast<AlgebraicOp>(op));
- };
+ auto isAlgebraicOp = [](Operation *op) { return isa<AlgebraicOp>(op); };
using ConcreteEndomorphismSimplification = EndomorphismSimplification<
std::decay_t<decltype(getEndomorphismOpOperand)>,
diff --git a/mlir/test/Dialect/Mesh/simplifications.mlir b/mlir/test/Dialect/Mesh/simplifications.mlir
index 2540fbf9510c4..e955f4c134259 100644
--- a/mlir/test/Dialect/Mesh/simplifications.mlir
+++ b/mlir/test/Dialect/Mesh/simplifications.mlir
@@ -165,3 +165,15 @@ func.func @all_reduce_arith_minsi_endomorphism(
// CHECK: return %[[ALL_REDUCE_RES]]
return %2 : tensor<5xi32>
}
+
+// Ensure this case without endomorphism op not crash.
+// CHECK-LABEL: func.func @no_endomorphism_op
+func.func @no_endomorphism_op(%arg0: tensor<2xi64>) -> i64 {
+ %c0 = arith.constant 0 : index
+ %c1_i64 = arith.constant 1 : i64
+ // CHECK: tensor.extract
+ %extracted = tensor.extract %arg0[%c0] : tensor<2xi64>
+ // CHECK: arith.maxsi
+ %0 = arith.maxsi %extracted, %c1_i64 : i64
+ return %0 : i64
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix. LGTM.
This PR adds a null check for dyn_cast result before use to prevent crash, and use `isa` instead `dyn_cast` to make code clean. Fixes llvm#148619.
This PR adds a null check for dyn_cast result before use to prevent crash, and use `isa` instead `dyn_cast` to make code clean. Fixes #148619.
This PR adds a null check for dyn_cast result before use to prevent crash, and use
isainsteaddyn_castto make code clean. Fixes #148619.