-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Use DenseI64ArrayAttr
for constant_mask dim sizes
#100997
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC. Full diff: https://github.com/llvm/llvm-project/pull/100997.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 39ad03c801140..3cdbd21874567 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
def Vector_ConstantMaskOp :
Vector_Op<"constant_mask", [Pure]>,
- Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
+ Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a constant vector mask";
let description = [{
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index d297c40760cd8..669ae586e5786 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
// Inspect constant mask index. If the index exceeds the
// dimension size, all bits are set. If the index is zero
// or less, no bits are set.
- ArrayAttr masks = m.getMaskDimSizes();
+ ArrayRef<int64_t> masks = m.getMaskDimSizes();
auto shape = m.getType().getShape();
bool allTrue = true;
bool allFalse = true;
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
- int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
- if (i < dimSize)
+ if (maskIdx < dimSize)
allTrue = false;
- if (i > 0)
+ if (maskIdx > 0)
allFalse = false;
}
if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
if (extractStridedSliceOp.hasNonUnitStrides())
return failure();
// Gather constant mask dimension sizes.
- SmallVector<int64_t, 4> maskDimSizes;
- populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
SmallVector<int64_t, 4> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
// region.
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
- vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
+ sliceMaskDimSizes);
return success();
}
};
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
if (constantMaskOp) {
- auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+ auto maskDimSizes = constantMaskOp.getMaskDimSizes();
auto numMaskOperands = maskDimSizes.size();
// Check every mask dim size to see whether it can be dropped
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
--i) {
- if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+ if (maskDimSizes[i] != 1)
return failure();
}
auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
- ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
-
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
- newMaskOperandsAttr);
+ newMaskOperands);
return success();
}
@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
// ConstantMaskOp case.
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
- SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
- applyPermutationToVector(newMaskDimSizes, permutation);
+ auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
- transpOp, transpOp.getResultVectorType(),
- ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
+ transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
return success();
}
};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
if (resultType.getRank() == 0) {
if (getMaskDimSizes().size() != 1)
return emitError("array attr must have length 1 for 0-D vectors");
- auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
+ auto dim = getMaskDimSizes()[0];
if (dim != 0 && dim != 1)
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
return success();
@@ -5846,9 +5840,8 @@ LogicalResult ConstantMaskOp::verify() {
// result dimension size.
auto resultShape = resultType.getShape();
auto resultScalableDims = resultType.getScalableDims();
- SmallVector<int64_t, 4> maskDimSizes;
- for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
+ ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
+ for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
if (maskDimSize < 0 || maskDimSize > resultShape[index])
return emitOpError(
"array attr of size out of bounds of vector result dimension size");
@@ -5856,7 +5849,6 @@ LogicalResult ConstantMaskOp::verify() {
maskDimSize != resultShape[index])
return emitOpError(
"only supports 'none set' or 'all set' scalable dimensions");
- maskDimSizes.push_back(maskDimSize);
}
// Verify that if one mask dim size is zero, they all should be zero (because
// the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
// Check the corner case of 0-D vectors first.
if (resultType.getRank() == 0) {
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
- return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
+ return getMaskDimSizes()[0] == 1;
}
- for (const auto [resultSize, intAttr] :
+ for (const auto [resultSize, maskDimSize] :
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
- int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
if (maskDimSize < resultSize)
return false;
}
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
}
// Replace 'createMaskOp' with ConstantMaskOp.
- rewriter.replaceOpWithNewOp<ConstantMaskOp>(
- createMaskOp, retTy,
- vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
+ rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
+ maskDimSizes);
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index dfeb7bc53adad..bfc05c71f5340 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -111,7 +111,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
if (rank == 0) {
assert(dimSizes.size() == 1 &&
"Expected exactly one dim size for a 0-D vector");
- bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
+ bool value = dimSizes.front() == 1;
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, dstType,
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
@@ -119,7 +119,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
return success();
}
- int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
+ int64_t trueDimSize = dimSizes.front();
if (rank == 1) {
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
- loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
+ loc, lowType, dimSizes.drop_front());
Value result = rewriter.create<arith::ConstantOp>(
loc, dstType, rewriter.getZeroAttr(dstType));
for (int64_t d = 0; d < trueDimSize; d++)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 7ed3dea42b771..3d74502951404 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -551,8 +551,8 @@ struct CastAwayConstantMaskLeadingOneDim
int64_t dropDim = oldType.getRank() - newType.getRank();
SmallVector<int64_t> dimSizes;
- for (auto attr : mask.getMaskDimSizes())
- dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+ for (int64_t size : mask.getMaskDimSizes())
+ dimSizes.push_back(size);
// If any of the dropped unit dims has a size of `0`, the entire mask is a
// zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +563,7 @@ struct CastAwayConstantMaskLeadingOneDim
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
auto newMask = rewriter.create<vector::ConstantMaskOp>(
- mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
+ mask.getLoc(), newType, newDimSizes);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
return success();
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index ac2a4d3abcc68..d3296ee38c249 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
newMaskOperands);
} else if (constantMaskOp) {
- ArrayRef<Attribute> maskDimSizes =
- constantMaskOp.getMaskDimSizes().getValue();
+ ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
size_t numMaskOperands = maskDimSizes.size();
- auto origIndex =
- cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
- IntegerAttr maskIndexAttr =
- rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
- SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
- newMaskDimSizes.push_back(maskIndexAttr);
- newMask = rewriter.create<vector::ConstantMaskOp>(
- loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
+ int64_t origIndex = maskDimSizes[numMaskOperands - 1];
+ int64_t maskIndex = (origIndex + scale - 1) / scale;
+ SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
+ newMaskDimSizes.push_back(maskIndex);
+ newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
+ newMaskDimSizes);
}
while (!extractOps.empty()) {
|
This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
If anyone is in the mood for some light refactoring, I think the following ops could get the same treatment:
:) |
Follow on from llvm#100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC).
Follow on from #100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC).
…lvm#100997) This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.
Follow on from llvm#100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC).
This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.