-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] reverse int8 type's printing logic #69361
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Chengji Yao (yaochengji) ChangesSpecializing for 8-bit integers to ensure values are printed as integers in a generic way will cause a bug., see #69310 Full diff: https://github.com/llvm/llvm-project/pull/69361.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index d761743a82bf86b..867c98078ae5171 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let parameters = (ins
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
- ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes,
- OptionalArrayRefParameter<"int8_t">:$partial_axes,
+ ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes,
+ OptionalArrayRefParameter<"int64_t">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 8ca4b6653104221..a8aa0a694bee29f 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -70,7 +70,7 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
}];
let arguments = (ins
SymbolNameAttr:$sym_name,
- I8Attr:$rank,
+ I64Attr:$rank,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$dim_sizes
);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 379392ace46961a..f1fabf95a68b7ad 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -350,8 +350,7 @@ template <typename AsmPrinterT, typename T,
!std::is_convertible<T &, Attribute &>::value &&
!std::is_convertible<T &, ValueRange>::value &&
!std::is_convertible<T &, APFloat &>::value &&
- !llvm::is_one_of<T, bool, int8_t, uint8_t, float,
- double>::value,
+ !llvm::is_one_of<T, bool, float, double>::value,
T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
@@ -367,17 +366,6 @@ operator<<(AsmPrinterT &p, bool value) {
return p << (value ? StringRef("true") : "false");
}
-/// Specialization for 8-bit integers to ensure values are printed as integers
-// and not characters.
-template <
- typename AsmPrinterT, typename T,
- std::enable_if_t<llvm::is_one_of<T, int8_t, uint8_t>::value, T> * = nullptr>
-inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
- AsmPrinterT &>
-operator<<(AsmPrinterT &p, T value) {
- return p << static_cast<int16_t>(value);
-}
-
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
AsmPrinterT &>
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b2a47102528758c..e8dc14cf0fa9c04 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -47,7 +47,7 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
LogicalResult ClusterOp::verify() {
ArrayRef<int64_t> dimSizes = getDimSizes();
- uint8_t rank = getRank();
+ uint64_t rank = getRank();
if (rank == 0)
return emitOpError("rank of cluster is expected to be a positive integer");
@@ -71,15 +71,15 @@ LogicalResult ClusterOp::verify() {
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<DenseI8ArrayAttr> splitAxes,
- ArrayRef<int8_t> partialAxes, Partial) {
+ SymbolRefAttr, ArrayRef<DenseI64ArrayAttr> splitAxes,
+ ArrayRef<int64_t> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
- llvm::SmallSet<int8_t, 4> visitedAxes;
+ llvm::SmallSet<int64_t, 4> visitedAxes;
- auto checkMeshAxis = [&](ArrayRef<int8_t> axesArray) -> LogicalResult {
- for (int8_t axis : axesArray) {
+ auto checkMeshAxis = [&](ArrayRef<int64_t> axesArray) -> LogicalResult {
+ for (int64_t axis : axesArray) {
if (axis < 0)
return emitError() << "mesh axis is expected to be non-negative";
if (!visitedAxes.insert(axis).second)
@@ -88,8 +88,8 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
};
- for (DenseI8ArrayAttr subAxes : splitAxes) {
- ArrayRef<int8_t> subAxesArray = subAxes.asArrayRef();
+ for (DenseI64ArrayAttr subAxes : splitAxes) {
+ ArrayRef<int64_t> subAxesArray = subAxes.asArrayRef();
if (failed(checkMeshAxis(subAxesArray)))
return failure();
}
|
LGTM |
@@ -58,8 +58,8 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { | |||
|
|||
let parameters = (ins | |||
AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster, | |||
ArrayRefParameter<"::mlir::DenseI8ArrayAttr">:$split_axes, | |||
OptionalArrayRefParameter<"int8_t">:$partial_axes, | |||
ArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$split_axes, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about DenseI32ArrayAttr?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified, both split_axes
and partial_axes
to int32
FWIW, I've tested this patch on both |
@yaochengji you most recent commit is authored from a |
Oh, Thanks! It is now corrected on my dev machine. |
Specializing for 8-bit integers to ensure values are printed as integers in a generic way will cause a bug, see #69310