-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Fix incorrect memref::DimOp canonicalization, add tensor::DimOp canonicalization #84225
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write If you have received no comments on your PR for a week, you can request a review If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-memref Author: Sayan Saha (sahas3) ChangesThe current canonicalization of
results in:
Properly fixing this issue requires a dominator analysis which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to Full diff: https://github.com/llvm/llvm-project/pull/84225.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index c71517666b609c..2333c92fd7b12c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -629,7 +629,6 @@ def MemRef_DimOp : MemRef_Op<"dim", [
Speculation::Speculatability getSpeculatability();
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index b0a4de2da1e869..e1cb5b477debbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -317,7 +317,6 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
MLIRContext *context = enclosingOp->getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context);
- memref::DimOp::getCanonicalizationPatterns(patterns, context);
tensor::DimOp::getCanonicalizationPatterns(patterns, context);
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 248193481acfc6..00b7fa122a6c96 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1069,39 +1069,6 @@ OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
return {};
}
-namespace {
-/// Fold dim of a memref reshape operation to a load into the reshape's shape
-/// operand.
-struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
- using OpRewritePattern<DimOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(DimOp dim,
- PatternRewriter &rewriter) const override {
- auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
-
- if (!reshape)
- return failure();
-
- // Place the load directly after the reshape to ensure that the shape memref
- // was not mutated.
- rewriter.setInsertionPointAfter(reshape);
- Location loc = dim.getLoc();
- Value load =
- rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex());
- if (load.getType() != dim.getType())
- load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
- rewriter.replaceOp(dim, load);
- return success();
- }
-};
-
-} // namespace
-
-void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<DimOfMemRefReshape>(context);
-}
-
// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index fe2f250e6b9290..ce9792f813cbb3 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -824,11 +824,36 @@ struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
return success();
}
};
+
+/// Fold dim of a tensor reshape operation to a extract into the reshape's shape
+/// operand.
+struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
+ using OpRewritePattern<DimOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(DimOp dim,
+ PatternRewriter &rewriter) const override {
+ auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
+
+ if (!reshape)
+ return failure();
+
+ // Since tensors are immutable we don't need to worry about where to place
+ // the load call
+ rewriter.setInsertionPointAfter(dim);
+ Location loc = dim.getLoc();
+ Value load =
+ rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
+ if (load.getType() != dim.getType())
+ load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load);
+ rewriter.replaceOp(dim, load);
+ return success();
+ }
+};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<DimOfCastOp, DimOfDestStyleOp>(context);
+ results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index a772a25da57382..0054a8ac785a89 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -242,48 +242,6 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
// -----
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
- // Update the shape to test that he load ends up in the right place.
- memref.store %c3, %arg1[%c3] : memref<?xindex>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
- -> index {
- %c3 = arith.constant 3 : index
- %0 = memref.reshape %arg0(%arg1)
- : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
- %1 = memref.dim %0, %c3 : memref<*xf32>
- return %1 : index
-}
-
-// -----
-
// CHECK-LABEL: func @alloc_const_fold
func.func @alloc_const_fold() -> memref<?xf32> {
// CHECK-NEXT: memref.alloc() : memref<4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index d17c23adfb14d8..45d37c553a0025 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2250,3 +2250,83 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t
// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles
// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]]
// CHECK: return %[[SRC]]
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> memref.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: tensor<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: tensor<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = arith.constant 3
+// CHECK-NEXT: %[[DIM:.*]] = tensor.extract %[[SHP]][%[[IDX]]]
+// CHECK-NOT: tensor.store
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[DIM]] : index
+func.func @dim_of_reshape(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ // Update the shape to test that the load ends up in the right place.
+ tensor.insert %c3 into %arg1[%c3] : tensor<?xindex>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: Folding of tensor.dim(tensor.reshape %v %shp, %idx) -> tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_i32(
+// CHECK: tensor.extract
+// CHECK-NEXT: %[[CAST:.*]] = arith.index_cast
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+// CHECK: return %[[CAST]] : index
+func.func @dim_of_reshape_i32(%arg0: tensor<*xf32>, %arg1: tensor<?xi32>)
+ -> index {
+ %c3 = arith.constant 3 : index
+ %0 = tensor.reshape %arg0(%arg1)
+ : (tensor<*xf32>, tensor<?xi32>) -> tensor<*xf32>
+ %1 = tensor.dim %0, %c3 : tensor<*xf32>
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_for(
+// CHECK: scf.for
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_for( %arg0: tensor<*xf32>, %arg1: tensor<?xindex>) -> index {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %0 = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+ %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
+ %2 = tensor.dim %0, %arg2 : tensor<*xf32>
+ %3 = arith.muli %arg3, %2 : index
+ scf.yield %3 : index
+ }
+ return %1 : index
+}
+
+// -----
+
+// Test case: tensor.dim(tensor.reshape %v %shp, %idx) is not folded into tensor.extract %shp[%idx]
+// CHECK-LABEL: func @dim_of_reshape_undominated(
+// CHECK: arith.muli
+// CHECK-NEXT: tensor.extract
+// CHECK-NOT: tensor.dim
+// CHECK-NOT: tensor.reshape
+func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xindex>, %arg2: index) -> index {
+ %c4 = arith.constant 4 : index
+ %reshape = tensor.reshape %arg0(%arg1) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
+ %0 = arith.muli %arg2, %c4 : index
+ %dim = tensor.dim %reshape, %0 : tensor<*xf32>
+ return %dim : index
+ }
|
LGTM overall, but should we also have the negative tests you have on the memref::DimOp? |
I am not sure that removing the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the memrefs canonicalization needs to be removed
Actually I understand why this is a problem. This op is just bonkers. The semantics are pretty poor. So at least it makes sense to remove the canonicalization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tensor dimOp canonicalization seemed valuable, are you bringing this another PR?
@joker-eph Are you suggesting it should be it's own PR? |
Nevermind: GitHub was showing me the diff since the last review (or just for the most recent push) so I was only seeing the Memref one. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
@MaheshRavishankar any concern left here? |
Nope. I dismissed my review here. Thanks for checking |
Thanks! Merging then |
@sahas3 Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested Please check whether problems have been caused by your change specifically, as How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
The current canonicalization of
memref.dim
operating on the result ofmemref.reshape
intomemref.load
is incorrect as it doesn't check whether theindex
operand ofmemref.dim
dominates the sourcememref.reshape
op. It always introducesmemref.load
right aftermemref.reshape
to ensure thememref
is not mutated before thememref.load
call. As a result, the following error is observed:results in:
Properly fixing this issue requires a dominator analysis which is expensive to run within a canonicalization pattern. So, this patch moves the canonicalization pattern to
tensor.dim
. Since tensors are immutable we don't need to worry about where to introduce thetensor.extract
call after canonicalization.