diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 309573a562872..53ed31877c6f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2295,6 +2295,49 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp
+//===----------------------------------------------------------------------===//
+
+def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
+    "structured.flatten_elementwise",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Flattens the iteration space and (applicable) operands of elementwise
+    linalg ops to a single dimension.
+
+    Returns one handle:
+    - Flattened linalg operation.
+
+    #### Return modes:
+
+    Returns a definite failure if target is not isolated from above.
+    Returns a silenceable failure if the pattern application failed.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+      "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Transpose Conv2D
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index a848d12fbbb50..65cf19e7a4fcd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1074,6 +1074,11 @@ bool isDimSequencePreserved(AffineMap map,
                             ReassociationIndicesRef dimSequence);
 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
                               ArrayRef<ReassociationIndicesRef> dimSequences);
 
+struct CollapseResult {
+  SmallVector<Value> results;
+  LinalgOp collapsedOp;
+};
+
 /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
 /// to calling this method is that for each list in `foldedIterationDim`, the
 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of
@@ -1081,9 +1086,8 @@ bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
 /// When valid, the method also collapses the operands of the op. Returns
 /// replacement values of the results of the original `linalgOp` by inserting
 /// reshapes to get back values of compatible types.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>>
-collapseOpIterationDims(LinalgType op,
+FailureOr<CollapseResult>
+collapseOpIterationDims(LinalgOp op,
                         ArrayRef<ReassociationIndices> foldedIterationDims,
                         RewriterBase &rewriter);
 
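Below, a minimal sketch of how a caller consumes the new `CollapseResult`-returning API. The helper name `flattenAllLoops` is hypothetical; the call pattern mirrors the `applyToOne` implementation in the next file:

```c++
#include <numeric>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper (not part of this patch): collapse every loop of `op`
// into a single iteration dimension, replace the original op, and hand back
// the collapsed op for further transformation.
static FailureOr<linalg::LinalgOp> flattenAllLoops(RewriterBase &rewriter,
                                                   linalg::LinalgOp op) {
  // A single reassociation group [0, 1, ..., numLoops-1] folds all loops.
  ReassociationIndices reassociation(op.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  FailureOr<linalg::CollapseResult> collapsed =
      linalg::collapseOpIterationDims(op, reassociation, rewriter);
  if (failed(collapsed))
    return failure();
  // `results` already contain reshapes back to the original result types.
  rewriter.replaceOp(op, collapsed->results);
  return collapsed->collapsedOp;
}
```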
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 299965bcfc3ab..ef9cd5561665f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3244,6 +3244,35 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// FlattenElementwiseLinalgOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  // Nothing to flatten for rank <= 1; forward the target unchanged so the
+  // op still produces one result per target.
+  if (target.getNumLoops() <= 1) {
+    results.push_back(target);
+    return DiagnosedSilenceableFailure::success();
+  }
+  ReassociationIndices reassociation(target.getNumLoops());
+  std::iota(reassociation.begin(), reassociation.end(), 0);
+  auto maybeFlattened =
+      isElementwise(target)
+          ? collapseOpIterationDims(target, reassociation, rewriter)
+          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
+                target, "only elementwise flattening is supported"));
+  if (failed(maybeFlattened))
+    return emitDefaultSilenceableFailure(target);
+  results.push_back(maybeFlattened->collapsedOp);
+  rewriter.replaceOp(target, maybeFlattened->results);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeConv2DOp
 //===----------------------------------------------------------------------===//
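The new op can also be built from C++. A hedged sketch using only the default generated builder, with the result type passed explicitly (it does not depend on the custom builder declared in the .td); `b`, `loc`, and the transform-handle value `target` are assumed to exist in the caller:

```c++
// Sketch only: `b` is an OpBuilder, `loc` a Location, and `target` a value
// of transform handle type, e.g. produced by transform.structured.match.
auto flatten = b.create<transform::FlattenElementwiseLinalgOp>(
    loc, /*transformed=*/transform::AnyOpType::get(b.getContext()),
    /*target=*/target);
Value flattenedHandle = flatten.getTransformed();
```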
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 4977940cfbd79..4797bfb2267d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1446,24 +1446,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   }
 }
 
-template <typename LinalgType>
-Operation *createCollapsedOp(LinalgType op,
-                             const CollapsingInfo &collapsingInfo,
-                             RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to create");
+void collapseOperandsAndResults(LinalgOp op,
+                                const CollapsingInfo &collapsingInfo,
+                                RewriterBase &rewriter,
+                                SmallVectorImpl<Value> &inputOperands,
+                                SmallVectorImpl<Value> &outputOperands,
+                                SmallVectorImpl<Type> &resultTypes) {
   Location loc = op->getLoc();
-
-  // Get the input operands.
-  SmallVector<Value> inputOperands =
+  inputOperands =
       llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
         return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
                                      rewriter);
       });
 
   // Get the output operands and result types.
-  SmallVector<Type> resultTypes;
-  SmallVector<Value> outputOperands;
   resultTypes.reserve(op.getNumDpsInits());
   outputOperands.reserve(op.getNumDpsInits());
   for (OpOperand &output : op.getDpsInitsMutable()) {
@@ -1475,41 +1471,69 @@ Operation *createCollapsedOp(LinalgType op,
     if (!op.hasPureBufferSemantics())
       resultTypes.push_back(newOutput.getType());
   }
+}
 
-  if (isa<CopyOp>(op)) {
-    return rewriter.create<CopyOp>(loc, inputOperands[0],
-                                   outputOperands[0]);
-  }
+/// Clone a `LinalgOp` to a collapsed version of same name
+template <typename OpTy>
+OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+                        const CollapsingInfo &collapsingInfo) {
+  return nullptr;
+}
 
-  // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes =
-      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+/// Collapse any `LinalgOp` that does not require any specialization such as
+/// indexing_maps, iterator_types, etc.
+template <>
+LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
+                                      const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  return cast<LinalgOp>(clone(
+      rewriter, origOp, resultTypes,
+      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))));
+}
 
-  // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+/// Collapse a `GenericOp`
+template <>
+GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
+                                        GenericOp origOp,
+                                        const CollapsingInfo &collapsingInfo) {
+  SmallVector<Value> inputOperands, outputOperands;
+  SmallVector<Type> resultTypes;
+  collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands,
+                             outputOperands, resultTypes);
+  SmallVector<AffineMap> indexingMaps(
+      llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) {
         return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+      }));
+
+  SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes(
+      origOp.getIteratorTypesArray(), collapsingInfo));
 
-  Operation *collapsedOp = rewriter.create<GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
+  GenericOp collapsedOp = rewriter.create<GenericOp>(
+      origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps,
       iteratorTypes,
       [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &op->getRegion(0).front();
+  Block *origOpBlock = &origOp->getRegion(0).front();
   Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
   rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                        collapsedOpBlock->getArguments());
-
   return collapsedOp;
 }
 
+LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
+                           RewriterBase &rewriter) {
+  if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
+    return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
+  } else {
+    return cloneToCollapsedOp(rewriter, op, collapsingInfo);
+  }
+}
+
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
-    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
+FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
+    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
     RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to collapse");
-
   // Bail on trivial no-op cases.
   if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
       llvm::all_of(foldedIterationDims,
                    [](ReassociationIndicesRef foldedDims) {
@@ -1538,8 +1562,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
   }
 
   // Bail on non-canonical ranges.
-  SmallVector<Range> loopRanges =
-      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
+  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
   auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
     if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
       return cast<IntegerAttr>(attr).getInt() == value;
@@ -1555,8 +1578,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
         op, "expected all loop ranges to have zero start and unit stride");
   }
 
-  LinalgType collapsedOp = cast<LinalgType>(
-      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
+  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);
 
   Location loc = op->getLoc();
   if (collapsedOp.hasIndexSemantics()) {
@@ -1597,7 +1619,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
       results.push_back(collapsedOpResult);
     }
   }
-  return results;
+  return CollapseResult{results, collapsedOp};
 }
 
 namespace {
@@ -1629,15 +1651,14 @@ class FoldWithProducerReshapeOpByCollapsing
       continue;
     }
 
-    std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<GenericOp>(
-            genericOp, collapsableIterationDims, rewriter);
-    if (!replacements) {
+    std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
+        genericOp, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
       return rewriter.notifyMatchFailure(
           genericOp, "failed to do the fusion by collapsing transformation");
     }
 
-    rewriter.replaceOp(genericOp, *replacements);
+    rewriter.replaceOp(genericOp, collapseResult->results);
     return success();
   }
   return failure();
@@ -1671,13 +1692,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
           op, "specified dimensions cannot be collapsed");
     }
 
-    std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
-                                            rewriter);
-    if (!replacements) {
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
      return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
    }
-    rewriter.replaceOp(op, *replacements);
+    rewriter.replaceOp(op, collapseResult->results);
     return success();
   }
 
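The same collapsing remains reachable as a rewrite pattern. A sketch under the assumption that the pre-existing `populateCollapseDimensions` entry point and its `GetCollapsableDimensionsFn` control callback are left in place by this patch:

```c++
#include <numeric>

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h" // isElementwise
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collapse all loops of every multi-loop elementwise op under `root`. The
// patterns are instantiated upstream for linalg.generic and linalg.copy, so
// this is the pattern-based analogue of the new transform op for those kinds.
static LogicalResult flattenAllElementwise(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateCollapseDimensions(
      patterns, [](linalg::LinalgOp op) -> SmallVector<ReassociationIndices> {
        if (!linalg::isElementwise(op) || op.getNumLoops() <= 1)
          return {};
        ReassociationIndices allLoops(op.getNumLoops());
        std::iota(allLoops.begin(), allLoops.end(), 0);
        return {allLoops};
      });
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```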
diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
new file mode 100644
index 0000000000000..858c133dd536c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,99 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @fill(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
+func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
+  linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fill_tensor(
+// CHECK-SAME:      %[[ARG0:.*]]: f32,
+// CHECK-SAME:      %[[ARG1:.*]]: tensor<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
+// CHECK-NEXT:    %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
+func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
+  %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
+  return %0 : tensor<32x7xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @map(
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:      %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+  linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func.func @generic
+// CHECK-SAME:      %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:      %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-SAME:      %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
+// CHECK-NEXT:    %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
+// CHECK-NEXT:    %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
+// CHECK-NEXT:    linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
+// CHECK-NEXT:    ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
+// CHECK-NEXT:      %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
+// CHECK-NEXT:      linalg.yield %[[SUM]]
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @generic(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
+  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
+  ^bb0(%a: f32, %b: f32, %c: f32):
+    %0 = arith.addf %a, %b : f32
+    linalg.yield %0 : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %flattened = transform.structured.flatten_elementwise %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}