Skip to content

Commit e07149c

Browse files
[mlir][linalg] Add option to generate rank-reducing slices in DropUnitDims
This change extends the `ReplaceUnitExtents` pattern so that users can choose between of two strategies for generating rank reductions: * CollapseShapeOp / ExpandShapeOp (was already implemented but code was cleaned up; default strategy) * rank-reducing ExtractSliceOp / InsertSliceOp Also add helper functions to the memref dialect that we already have on the tensor dialect: `getMixedSizes`, `createCanonicalRankReducingSubViewOp`, `rankReduceIfNeeded`. We are using ReassociationIndices instead of ReassoicationExprs in many other places and this makes the code easier to read. Also adding a new test case (that also passed before). Differential Revision: https://reviews.llvm.org/D139947
1 parent befd167 commit e07149c

File tree

7 files changed

+283
-143
lines changed

7 files changed

+283
-143
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
3131
Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
3232
/*default=*/"false",
3333
"Only folds the one-trip loops from Linalg ops on tensors "
34-
"(for testing purposes only)">
34+
"(for testing purposes only)">,
35+
Option<"useRankReducingSlices", "use-rank-reducing-slices", "bool",
36+
/*default=*/"false",
37+
"Generate rank-reducing slices instead of reassociative reshapes">
3538
];
3639
let dependentDialects = [
3740
"linalg::LinalgDialect", "AffineDialect", "memref::MemRefDialect"

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,12 @@ void populateFuseTensorPadWithProducerLinalgOpPatterns(
129129
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
130130

131131
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
132-
/// tensors.
133-
void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
132+
/// tensors via reassociative reshape ops.
133+
void populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns);
134+
135+
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
136+
/// tensors via rank-reducing slices.
137+
void populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns);
134138

135139
/// Patterns that are used to inline constant operands into linalg generic ops.
136140
void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ Type getTensorTypeFromMemRefType(Type type);
5454
/// single deallocate if it exists or nullptr.
5555
Optional<Operation *> findDealloc(Value allocValue);
5656

57+
/// Return the dimensions of the given memref value.
58+
SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
59+
Value value);
60+
61+
/// Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and
62+
/// appropriate sizes (i.e. `memref.getSizes()`) to reduce the rank of `memref`
63+
/// to that of `targetShape`.
64+
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc,
65+
Value memref,
66+
ArrayRef<int64_t> targetShape);
5767
} // namespace memref
5868
} // namespace mlir
5969

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,15 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
19541954
/// Return the dimensions of the source type that are dropped when
19551955
/// the result is rank-reduced.
19561956
llvm::SmallBitVector getDroppedDims();
1957+
1958+
/// Given a `value`, asserted to be of MemRefType, build a SubViewOp that
1959+
/// results in a rank reduction to the desired memref shape and return the
1960+
/// new value created.
1961+
/// If the shape of `value` is already the `desiredShape`, just return
1962+
/// `value`.
1963+
/// If the shape of `value` cannot be rank-reduced to `desiredShape`, fail.
1964+
static FailureOr<Value> rankReduceIfNeeded(
1965+
OpBuilder &b, Location loc, Value value, ArrayRef<int64_t> desiredShape);
19571966
}];
19581967

19591968
let hasCanonicalizer = 1;

0 commit comments

Comments
 (0)