-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[mlir][transform] Implement FlattenElementwiseLinalgOp transform op
#81431
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
6e05d6a
aff79ba
cd0ebb1
780394c
db62df3
27fb208
6ec2d19
9676452
f6e1eca
a350306
bdd04e2
12444bc
9b33d45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1449,12 +1449,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, | |
| } | ||
| } | ||
|
|
||
| template <typename LinalgType> | ||
| Operation *createCollapsedOp(LinalgType op, | ||
| const CollapsingInfo &collapsingInfo, | ||
| RewriterBase &rewriter) { | ||
| static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, | ||
| "unsupported linalg op type to create"); | ||
| LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, | ||
| RewriterBase &rewriter) { | ||
| Location loc = op->getLoc(); | ||
|
|
||
| // Get the input operands. | ||
|
|
@@ -1479,40 +1475,40 @@ Operation *createCollapsedOp(LinalgType op, | |
| resultTypes.push_back(newOutput.getType()); | ||
| } | ||
|
|
||
| if (isa<linalg::CopyOp>(op)) { | ||
| return rewriter.create<linalg::CopyOp>(loc, inputOperands[0], | ||
| outputOperands[0]); | ||
| } | ||
| Operation *collapsedOp = clone( | ||
| rewriter, op, resultTypes, | ||
| llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))); | ||
|
|
||
| // Get the iterator types for the operand. | ||
| SmallVector<utils::IteratorType> iteratorTypes = | ||
| getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo); | ||
| if (op->hasAttr("indexing_maps")) { | ||
|
||
| // Get the indexing maps. | ||
| auto indexingMaps = | ||
| llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { | ||
| return getCollapsedOpIndexingMap(map, collapsingInfo); | ||
| }); | ||
|
|
||
| // Get the indexing maps. | ||
| auto indexingMaps = | ||
| llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { | ||
| return getCollapsedOpIndexingMap(map, collapsingInfo); | ||
| }); | ||
| collapsedOp->setAttr("indexing_maps", | ||
| rewriter.getAffineMapArrayAttr(indexingMaps)); | ||
| } | ||
|
|
||
| Operation *collapsedOp = rewriter.create<linalg::GenericOp>( | ||
| loc, resultTypes, inputOperands, outputOperands, indexingMaps, | ||
| iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); | ||
| Block *origOpBlock = &op->getRegion(0).front(); | ||
| Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); | ||
| rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, | ||
| collapsedOpBlock->getArguments()); | ||
| if (op->hasAttr("iterator_types")) { | ||
| // Get the iterator types for the operand. | ||
| SmallVector<Attribute> iteratorTypes = llvm::map_to_vector( | ||
| getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), | ||
| [&](utils::IteratorType itTy) { | ||
| return cast<Attribute>( | ||
| IteratorTypeAttr::get(rewriter.getContext(), itTy)); | ||
| }); | ||
| collapsedOp->setAttr("iterator_types", | ||
| rewriter.getArrayAttr(iteratorTypes)); | ||
| } | ||
|
|
||
| return collapsedOp; | ||
| return cast<LinalgOp>(collapsedOp); | ||
| } | ||
|
|
||
| /// Implementation of fusion with reshape operation by collapsing dimensions. | ||
| template <typename LinalgType> | ||
| FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( | ||
| LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims, | ||
| FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims( | ||
| LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims, | ||
| RewriterBase &rewriter) { | ||
| static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value, | ||
| "unsupported linalg op type to collapse"); | ||
|
|
||
| // Bail on trivial no-op cases. | ||
| if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || | ||
| llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { | ||
|
|
@@ -1541,8 +1537,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( | |
| } | ||
|
|
||
| // Bail on non-canonical ranges. | ||
| SmallVector<Range> loopRanges = | ||
| cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc()); | ||
| SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc()); | ||
| auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { | ||
| if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) | ||
| return cast<IntegerAttr>(attr).getInt() == value; | ||
|
|
@@ -1558,8 +1553,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( | |
| op, "expected all loop ranges to have zero start and unit stride"); | ||
| } | ||
|
|
||
| LinalgType collapsedOp = cast<LinalgType>( | ||
| createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter)); | ||
| LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter); | ||
|
|
||
| Location loc = op->getLoc(); | ||
| if (collapsedOp.hasIndexSemantics()) { | ||
|
|
@@ -1600,7 +1594,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims( | |
| results.push_back(collapsedOpResult); | ||
| } | ||
| } | ||
| return results; | ||
| return CollapseResult{results, collapsedOp}; | ||
| } | ||
|
|
||
| namespace { | ||
|
|
@@ -1632,15 +1626,14 @@ class FoldWithProducerReshapeOpByCollapsing | |
| continue; | ||
| } | ||
|
|
||
| std::optional<SmallVector<Value>> replacements = | ||
| collapseOpIterationDims<linalg::GenericOp>( | ||
| genericOp, collapsableIterationDims, rewriter); | ||
| if (!replacements) { | ||
| std::optional<CollapseResult> collapseResult = collapseOpIterationDims( | ||
| genericOp, collapsableIterationDims, rewriter); | ||
| if (!collapseResult) { | ||
| return rewriter.notifyMatchFailure( | ||
| genericOp, "failed to do the fusion by collapsing transformation"); | ||
| } | ||
|
|
||
| rewriter.replaceOp(genericOp, *replacements); | ||
| rewriter.replaceOp(genericOp, (*collapseResult).results); | ||
| return success(); | ||
| } | ||
| return failure(); | ||
|
|
@@ -1674,13 +1667,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { | |
| op, "specified dimensions cannot be collapsed"); | ||
| } | ||
|
|
||
| std::optional<SmallVector<Value>> replacements = | ||
| collapseOpIterationDims<LinalgType>(op, collapsableIterationDims, | ||
| rewriter); | ||
| if (!replacements) { | ||
| std::optional<CollapseResult> collapseResult = | ||
| collapseOpIterationDims(op, collapsableIterationDims, rewriter); | ||
| if (!collapseResult) { | ||
| return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); | ||
| } | ||
| rewriter.replaceOp(op, *replacements); | ||
| rewriter.replaceOp(op, (*collapseResult).results); | ||
srcarroll marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return success(); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| // RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: func.func @fill( | ||
| // CHECK-SAME: %[[ARG0:.*]]: f32, | ||
| // CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32> | ||
| // CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>) | ||
| func.func @fill(%cst: f32, %arg: memref<32x7xf32>) { | ||
| linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>) | ||
| return | ||
| } | ||
|
|
||
| module attributes {transform.with_named_sequence} { | ||
| transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
| %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op | ||
| %flattened = transform.structured.flatten_elementwise %0 | ||
| : (!transform.any_op) -> !transform.any_op | ||
| transform.yield | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @fill_tensor( | ||
| // CHECK-SAME: %[[ARG0:.*]]: f32, | ||
| // CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32> | ||
| // CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>) | ||
| // CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] | ||
| func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> { | ||
| %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32> | ||
| return %0 : tensor<32x7xf32> | ||
| } | ||
|
|
||
| module attributes {transform.with_named_sequence} { | ||
| transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
| %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op | ||
| %flattened = transform.structured.flatten_elementwise %0 | ||
| : (!transform.any_op) -> !transform.any_op | ||
| transform.yield | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK-LABEL: func.func @map( | ||
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>) | ||
| func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) { | ||
| linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) | ||
| return | ||
| } | ||
|
|
||
| module attributes {transform.with_named_sequence} { | ||
| transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
| %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op | ||
| %flattened = transform.structured.flatten_elementwise %0 | ||
| : (!transform.any_op) -> !transform.any_op | ||
| transform.yield | ||
| } | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> | ||
| // CHECK-LABEL: func.func @generic | ||
| // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32> | ||
| // CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]] | ||
| // CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>) | ||
| // CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32) | ||
| // CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]] | ||
| // CHECK-NEXT: linalg.yield %[[SUM]] | ||
| #map = affine_map<(d0, d1) -> (d0, d1)> | ||
| func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) { | ||
| linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) { | ||
| ^bb0(%a: f32, %b: f32, %c: f32): | ||
| %0 = arith.addf %a, %b : f32 | ||
| linalg.yield %0 : f32 | ||
| } | ||
| return | ||
| } | ||
|
|
||
| module attributes {transform.with_named_sequence} { | ||
| transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { | ||
| %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op | ||
| %flattened = transform.structured.flatten_elementwise %0 | ||
| : (!transform.any_op) -> !transform.any_op | ||
| transform.yield | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is the operation being cloned?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Follow-up: I see why you are cloning it. This is an interesting way of doing it. Does it work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does.