From 6e05d6a3ed218797ae264fc88f8998a0a4b945dc Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 11 Feb 2024 02:33:16 -0600 Subject: [PATCH 01/11] Implement FlattenElementwiseLinalgOp transform --- .../Linalg/TransformOps/LinalgTransformOps.td | 42 +++++++++ .../TransformOps/LinalgTransformOps.cpp | 87 +++++++++++++++++++ 2 files changed, 129 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 309573a562872..d8d864d14ea69 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2295,6 +2295,48 @@ def ConvertConv2DToImg2ColOp : Op { + let description = [{ + Flattens elementwise linalg ops. + + Returns one handle: + - Flattened linalg operation. + + #### Return modes: + + Returns a definite failure if target is not isolated from above. + Returns a silenceable failure if the pattern application failed. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::linalg::LinalgOp target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // Transpose Conv2D //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 585fd14b40d76..57fce5e7a749f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3243,6 +3243,93 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// FlattenElementwiseLinalgOp. +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + auto flatten = [&](linalg::LinalgOp op) -> FailureOr { + if (!isElementwise(target)) { + return rewriter.notifyMatchFailure( + target, "only elementwise flattening is supported"); + } + if (!llvm::all_of(target.getIndexingMapsArray(), + [](auto map) { return map.isMinorIdentity(); })) { + return rewriter.notifyMatchFailure( + target, "only minor identity indexing maps is supported"); + } + ShapedType nonEmptyShapeType = nullptr; + for (const auto &resultVal : target.getDpsInitsMutable()) { + auto resultType = resultVal.get().getType(); + if (ShapedType resultShapedType = dyn_cast(resultType)) { + if (resultShapedType.getShape().empty()) + continue; + if (nonEmptyShapeType == nullptr) { + nonEmptyShapeType = resultShapedType; + } else if (resultShapedType != nonEmptyShapeType) { + return rewriter.notifyMatchFailure( + target, "all operands (except rank 0) must have same types"); + } + } + } + if (target.hasPureBufferSemantics()) { + if (!llvm::all_of(target->getOperands(), [](Value operand) { + if (auto memRefTy = dyn_cast(operand.getType())) + return memRefTy.getLayout().isIdentity(); + return true; + })) { + return rewriter.notifyMatchFailure( + target, "only memrefs with identity layout is supported"); + } + } + ReassociationIndices reassociation(nonEmptyShapeType.getRank()); + std::iota(reassociation.begin(), reassociation.end(), 0); + auto flattenOperand = [&](const Value &operand) { + return (!isa(operand.getType())) + ? operand + : rewriter + .create(target.getLoc(), + operand, reassociation) + .getResult(); + }; + SmallVector flattenedInputs( + llvm::map_range(target.getDpsInputs(), [&](const Value &operand) { + return flattenOperand(operand); + })); + SmallVector flattenedInits( + llvm::map_range(target.getDpsInits(), [&](const Value &operand) { + return flattenOperand(operand); + })); + + SmallVector flattenedMaps(llvm::map_range( + llvm::concat(flattenedInputs, flattenedInits), + [&](const Value &val) { + if (auto memRefTy = dyn_cast(val.getType())) + return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(), + target.getContext()); + return AffineMap::getMinorIdentityMap(1, 0, target.getContext()); + })); + + auto flattenedLinalgOp = rewriter.create( + target.getLoc(), TypeRange(), flattenedInputs, flattenedInits, + flattenedMaps, + SmallVector{utils::IteratorType::parallel}); + flattenedLinalgOp.getRegion().takeBody(target->getRegion(0)); + return flattenedLinalgOp; + return success(); + }; + auto maybeFlattened = flatten(target); + if (failed(maybeFlattened)) + return emitDefaultSilenceableFailure(target); + results.push_back(*maybeFlattened); + rewriter.eraseOp(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // TransposeConv2DOp //===----------------------------------------------------------------------===// From aff79baad62b53f8f10f733d5ff3c0068556549d Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 11 Feb 2024 14:57:07 -0600 Subject: [PATCH 02/11] Add a couple regression tests --- .../TransformOps/LinalgTransformOps.cpp | 50 +++++++----- .../Dialect/Linalg/flatten-elementwise.mlir | 77 +++++++++++++++++++ 2 files changed, 106 insertions(+), 21 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/flatten-elementwise.mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 57fce5e7a749f..15f7f82e24f3a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3252,19 +3252,22 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( transform::ApplyToEachResultList &results, transform::TransformState &state) { rewriter.setInsertionPoint(target); - auto flatten = [&](linalg::LinalgOp op) -> FailureOr { + if (target.getNumLoops() <= 1) + return DiagnosedSilenceableFailure::success(); + auto flatten = [&](linalg::LinalgOp &op) -> FailureOr { if (!isElementwise(target)) { return rewriter.notifyMatchFailure( target, "only elementwise flattening is supported"); } + // TODO: Support broadcasting and permutations if (!llvm::all_of(target.getIndexingMapsArray(), [](auto map) { return map.isMinorIdentity(); })) { return rewriter.notifyMatchFailure( target, "only minor identity indexing maps is supported"); } ShapedType nonEmptyShapeType = nullptr; - for (const auto &resultVal : target.getDpsInitsMutable()) { - auto resultType = resultVal.get().getType(); + for (const auto &resultVal : target->getOperands()) { + auto resultType = resultVal.getType(); if (ShapedType resultShapedType = dyn_cast(resultType)) { if (resultShapedType.getShape().empty()) continue; @@ -3277,6 +3280,7 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( } } if (target.hasPureBufferSemantics()) { + // TODO: Relax restrictions on layout if (!llvm::all_of(target->getOperands(), [](Value operand) { if (auto memRefTy = dyn_cast(operand.getType())) return memRefTy.getLayout().isIdentity(); @@ -3285,8 +3289,11 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( return rewriter.notifyMatchFailure( target, "only memrefs with identity layout is supported"); } + } else { + // TODO: Support tensors + return rewriter.notifyMatchFailure(target, "tensors are not supported"); } - ReassociationIndices reassociation(nonEmptyShapeType.getRank()); + ReassociationIndices reassociation(target.getNumLoops()); std::iota(reassociation.begin(), reassociation.end(), 0); auto flattenOperand = [&](const Value &operand) { return (!isa(operand.getType())) @@ -3296,37 +3303,38 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( operand, reassociation) .getResult(); }; - SmallVector flattenedInputs( - llvm::map_range(target.getDpsInputs(), [&](const Value &operand) { - return flattenOperand(operand); - })); - SmallVector flattenedInits( - llvm::map_range(target.getDpsInits(), [&](const Value &operand) { + SmallVector flattenedOperands( + llvm::map_range(target->getOperands(), [&](const Value &operand) { return flattenOperand(operand); })); - SmallVector flattenedMaps(llvm::map_range( - llvm::concat(flattenedInputs, flattenedInits), - [&](const Value &val) { + SmallVector flattenedMaps( + llvm::map_range(flattenedOperands, [&](const Value &val) { if (auto memRefTy = dyn_cast(val.getType())) return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(), target.getContext()); return AffineMap::getMinorIdentityMap(1, 0, target.getContext()); })); - auto flattenedLinalgOp = rewriter.create( - target.getLoc(), TypeRange(), flattenedInputs, flattenedInits, - flattenedMaps, - SmallVector{utils::IteratorType::parallel}); - flattenedLinalgOp.getRegion().takeBody(target->getRegion(0)); - return flattenedLinalgOp; - return success(); + rewriter.modifyOpInPlace(op, [&]() { + op->setOperands(flattenedOperands); + // TODO: Find a more general way to determine if op requires explicit + // indexing_maps and iterator_types + if (isa(op)) { + op->setAttr("indexing_maps", + rewriter.getAffineMapArrayAttr(flattenedMaps)); + op->setAttr( + "iterator_types", + rewriter.getArrayAttr({IteratorTypeAttr::get( + rewriter.getContext(), utils::IteratorType::parallel)})); + } + }); + return op; }; auto maybeFlattened = flatten(target); if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); results.push_back(*maybeFlattened); - rewriter.eraseOp(target); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir new file mode 100644 index 0000000000000..e360fc3ff5178 --- /dev/null +++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir @@ -0,0 +1,77 @@ +// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @fill( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32> +// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] +// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>) +func.func @fill(%cst: f32, %arg: memref<32x7xf32>) { + linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @map( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]] +// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>) +func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) { + linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: func.func @generic +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32> +// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]] +// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>) +// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32) +// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]] +// CHECK-NEXT: linalg.yield %[[SUM]] +#map = affine_map<(d0, d1) -> (d0, d1)> +func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) { + linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) { + ^bb0(%a: f32, %b: f32, %c: f32): + %0 = arith.addf %a, %b : f32 + linalg.yield %0 : f32 + } + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} \ No newline at end of file From cd0ebb1051264dbffd4c0fb1a386150a05ff6ef2 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 13 Feb 2024 22:27:00 -0600 Subject: [PATCH 03/11] Refactor `collapseOpIterationDims` to work for all linalg ops --- .../Dialect/Linalg/Transforms/Transforms.h | 3 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 60 ++++++++----------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a848d12fbbb50..a566745185ad9 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1081,9 +1081,8 @@ bool areDimSequencesPreserved(ArrayRef maps, /// When valid, the method also collapses the operands of the op. Returns /// replacement values of the results of the original `linalgOp` by inserting /// reshapes to get back values of compatible types. -template FailureOr> -collapseOpIterationDims(LinalgType op, +collapseOpIterationDims(LinalgOp op, ArrayRef foldedIterationDims, RewriterBase &rewriter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 286b07669a47f..ce58caa6c39aa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1449,12 +1449,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, } } -template -Operation *createCollapsedOp(LinalgType op, - const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter) { - static_assert(llvm::is_one_of::value, - "unsupported linalg op type to create"); +LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter) { Location loc = op->getLoc(); // Get the input operands. @@ -1479,14 +1475,17 @@ Operation *createCollapsedOp(LinalgType op, resultTypes.push_back(newOutput.getType()); } - if (isa(op)) { - return rewriter.create(loc, inputOperands[0], - outputOperands[0]); - } + Operation *collapsedOp = clone( + rewriter, op, resultTypes, + llvm::to_vector(llvm::concat(inputOperands, outputOperands))); // Get the iterator types for the operand. - SmallVector iteratorTypes = - getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo); + SmallVector iteratorTypes = llvm::map_to_vector( + getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), + [&](utils::IteratorType itTy) { + return cast( + IteratorTypeAttr::get(rewriter.getContext(), itTy)); + }); // Get the indexing maps. auto indexingMaps = @@ -1494,25 +1493,22 @@ Operation *createCollapsedOp(LinalgType op, return getCollapsedOpIndexingMap(map, collapsingInfo); }); - Operation *collapsedOp = rewriter.create( - loc, resultTypes, inputOperands, outputOperands, indexingMaps, - iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); - Block *origOpBlock = &op->getRegion(0).front(); - Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); - rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, - collapsedOpBlock->getArguments()); + // TODO: Find a more general way to determine if op requires explicit + // indexing_maps and iterator_types + if (isa(op)) { + collapsedOp->setAttr("indexing_maps", + rewriter.getAffineMapArrayAttr(indexingMaps)); + collapsedOp->setAttr("iterator_types", + rewriter.getArrayAttr(iteratorTypes)); + } - return collapsedOp; + return cast(collapsedOp); } /// Implementation of fusion with reshape operation by collapsing dimensions. -template FailureOr> mlir::linalg::collapseOpIterationDims( - LinalgType op, ArrayRef foldedIterationDims, + LinalgOp op, ArrayRef foldedIterationDims, RewriterBase &rewriter) { - static_assert(llvm::is_one_of::value, - "unsupported linalg op type to collapse"); - // Bail on trivial no-op cases. if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { @@ -1541,8 +1537,7 @@ FailureOr> mlir::linalg::collapseOpIterationDims( } // Bail on non-canonical ranges. - SmallVector loopRanges = - cast(op.getOperation()).createLoopRanges(rewriter, op.getLoc()); + SmallVector loopRanges = op.createLoopRanges(rewriter, op.getLoc()); auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { if (auto attr = llvm::dyn_cast_if_present(ofr)) return cast(attr).getInt() == value; @@ -1558,8 +1553,7 @@ FailureOr> mlir::linalg::collapseOpIterationDims( op, "expected all loop ranges to have zero start and unit stride"); } - LinalgType collapsedOp = cast( - createCollapsedOp(op, collapsingInfo, rewriter)); + LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter); Location loc = op->getLoc(); if (collapsedOp.hasIndexSemantics()) { @@ -1632,9 +1626,8 @@ class FoldWithProducerReshapeOpByCollapsing continue; } - std::optional> replacements = - collapseOpIterationDims( - genericOp, collapsableIterationDims, rewriter); + std::optional> replacements = collapseOpIterationDims( + genericOp, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); @@ -1675,8 +1668,7 @@ class CollapseLinalgDimensions : public OpRewritePattern { } std::optional> replacements = - collapseOpIterationDims(op, collapsableIterationDims, - rewriter); + collapseOpIterationDims(op, collapsableIterationDims, rewriter); if (!replacements) { return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); } From 780394c14974a2aed9d9e7bbaa86a9584939dbda Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 13 Feb 2024 23:19:12 -0600 Subject: [PATCH 04/11] Refactor `FlattenElementwiseLinalgOp` to use `collapseOpIterationDims` --- .../Dialect/Linalg/Transforms/Transforms.h | 7 ++- .../TransformOps/LinalgTransformOps.cpp | 40 ++--------------- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 44 +++++++++---------- 3 files changed, 31 insertions(+), 60 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a566745185ad9..65cf19e7a4fcd 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1074,6 +1074,11 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); bool areDimSequencesPreserved(ArrayRef maps, ArrayRef dimSequences); +struct CollapseResult { + SmallVector results; + LinalgOp collapsedOp; +}; + /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition /// to calling this method is that for each list in `foldedIterationDim`, the /// sequence of dimensions is contiguous in domains of all `indexing_maps` of @@ -1081,7 +1086,7 @@ bool areDimSequencesPreserved(ArrayRef maps, /// When valid, the method also collapses the operands of the op. Returns /// replacement values of the results of the original `linalgOp` by inserting /// reshapes to get back values of compatible types. -FailureOr> +FailureOr collapseOpIterationDims(LinalgOp op, ArrayRef foldedIterationDims, RewriterBase &rewriter); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 15f7f82e24f3a..25e72ab273833 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3254,7 +3254,7 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( rewriter.setInsertionPoint(target); if (target.getNumLoops() <= 1) return DiagnosedSilenceableFailure::success(); - auto flatten = [&](linalg::LinalgOp &op) -> FailureOr { + auto flatten = [&](linalg::LinalgOp &op) -> FailureOr { if (!isElementwise(target)) { return rewriter.notifyMatchFailure( target, "only elementwise flattening is supported"); @@ -3295,46 +3295,12 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( } ReassociationIndices reassociation(target.getNumLoops()); std::iota(reassociation.begin(), reassociation.end(), 0); - auto flattenOperand = [&](const Value &operand) { - return (!isa(operand.getType())) - ? operand - : rewriter - .create(target.getLoc(), - operand, reassociation) - .getResult(); - }; - SmallVector flattenedOperands( - llvm::map_range(target->getOperands(), [&](const Value &operand) { - return flattenOperand(operand); - })); - - SmallVector flattenedMaps( - llvm::map_range(flattenedOperands, [&](const Value &val) { - if (auto memRefTy = dyn_cast(val.getType())) - return AffineMap::getMinorIdentityMap(1, memRefTy.getRank(), - target.getContext()); - return AffineMap::getMinorIdentityMap(1, 0, target.getContext()); - })); - - rewriter.modifyOpInPlace(op, [&]() { - op->setOperands(flattenedOperands); - // TODO: Find a more general way to determine if op requires explicit - // indexing_maps and iterator_types - if (isa(op)) { - op->setAttr("indexing_maps", - rewriter.getAffineMapArrayAttr(flattenedMaps)); - op->setAttr( - "iterator_types", - rewriter.getArrayAttr({IteratorTypeAttr::get( - rewriter.getContext(), utils::IteratorType::parallel)})); - } - }); - return op; + return collapseOpIterationDims(op, reassociation, rewriter); }; auto maybeFlattened = flatten(target); if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); - results.push_back(*maybeFlattened); + results.push_back((*maybeFlattened).collapsedOp); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index ce58caa6c39aa..013a31ee1d959 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1479,23 +1479,23 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, rewriter, op, resultTypes, llvm::to_vector(llvm::concat(inputOperands, outputOperands))); - // Get the iterator types for the operand. - SmallVector iteratorTypes = llvm::map_to_vector( - getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), - [&](utils::IteratorType itTy) { - return cast( - IteratorTypeAttr::get(rewriter.getContext(), itTy)); - }); - - // Get the indexing maps. - auto indexingMaps = - llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { - return getCollapsedOpIndexingMap(map, collapsingInfo); - }); - // TODO: Find a more general way to determine if op requires explicit // indexing_maps and iterator_types if (isa(op)) { + // Get the iterator types for the operand. + SmallVector iteratorTypes = llvm::map_to_vector( + getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), + [&](utils::IteratorType itTy) { + return cast( + IteratorTypeAttr::get(rewriter.getContext(), itTy)); + }); + + // Get the indexing maps. + auto indexingMaps = + llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { + return getCollapsedOpIndexingMap(map, collapsingInfo); + }); + collapsedOp->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(indexingMaps)); collapsedOp->setAttr("iterator_types", @@ -1506,7 +1506,7 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, } /// Implementation of fusion with reshape operation by collapsing dimensions. -FailureOr> mlir::linalg::collapseOpIterationDims( +FailureOr mlir::linalg::collapseOpIterationDims( LinalgOp op, ArrayRef foldedIterationDims, RewriterBase &rewriter) { // Bail on trivial no-op cases. @@ -1594,7 +1594,7 @@ FailureOr> mlir::linalg::collapseOpIterationDims( results.push_back(collapsedOpResult); } } - return results; + return CollapseResult{.results = results, .collapsedOp = collapsedOp}; } namespace { @@ -1626,14 +1626,14 @@ class FoldWithProducerReshapeOpByCollapsing continue; } - std::optional> replacements = collapseOpIterationDims( + std::optional collapseResult = collapseOpIterationDims( genericOp, collapsableIterationDims, rewriter); - if (!replacements) { + if (!collapseResult) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); } - rewriter.replaceOp(genericOp, *replacements); + rewriter.replaceOp(genericOp, (*collapseResult).results); return success(); } return failure(); @@ -1667,12 +1667,12 @@ class CollapseLinalgDimensions : public OpRewritePattern { op, "specified dimensions cannot be collapsed"); } - std::optional> replacements = + std::optional collapseResult = collapseOpIterationDims(op, collapsableIterationDims, rewriter); - if (!replacements) { + if (!collapseResult) { return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); } - rewriter.replaceOp(op, *replacements); + rewriter.replaceOp(op, (*collapseResult).results); return success(); } From db62df3da264838cd4d5675a8ade7c929c076123 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 13 Feb 2024 23:30:29 -0600 Subject: [PATCH 05/11] Add EOL --- mlir/test/Dialect/Linalg/flatten-elementwise.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir index e360fc3ff5178..147759e13aa48 100644 --- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir @@ -74,4 +74,4 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> !transform.any_op transform.yield } -} \ No newline at end of file +} From 27fb2083c0411206920a715342dddb39ec01344f Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 13 Feb 2024 23:50:26 -0600 Subject: [PATCH 06/11] Remove restrictions and add tensor test --- .../TransformOps/LinalgTransformOps.cpp | 51 +++---------------- .../Dialect/Linalg/flatten-elementwise.mlir | 22 ++++++++ 2 files changed, 29 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 25e72ab273833..1be7b261995fd 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3254,50 +3254,13 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( rewriter.setInsertionPoint(target); if (target.getNumLoops() <= 1) return DiagnosedSilenceableFailure::success(); - auto flatten = [&](linalg::LinalgOp &op) -> FailureOr { - if (!isElementwise(target)) { - return rewriter.notifyMatchFailure( - target, "only elementwise flattening is supported"); - } - // TODO: Support broadcasting and permutations - if (!llvm::all_of(target.getIndexingMapsArray(), - [](auto map) { return map.isMinorIdentity(); })) { - return rewriter.notifyMatchFailure( - target, "only minor identity indexing maps is supported"); - } - ShapedType nonEmptyShapeType = nullptr; - for (const auto &resultVal : target->getOperands()) { - auto resultType = resultVal.getType(); - if (ShapedType resultShapedType = dyn_cast(resultType)) { - if (resultShapedType.getShape().empty()) - continue; - if (nonEmptyShapeType == nullptr) { - nonEmptyShapeType = resultShapedType; - } else if (resultShapedType != nonEmptyShapeType) { - return rewriter.notifyMatchFailure( - target, "all operands (except rank 0) must have same types"); - } - } - } - if (target.hasPureBufferSemantics()) { - // TODO: Relax restrictions on layout - if (!llvm::all_of(target->getOperands(), [](Value operand) { - if (auto memRefTy = dyn_cast(operand.getType())) - return memRefTy.getLayout().isIdentity(); - return true; - })) { - return rewriter.notifyMatchFailure( - target, "only memrefs with identity layout is supported"); - } - } else { - // TODO: Support tensors - return rewriter.notifyMatchFailure(target, "tensors are not supported"); - } - ReassociationIndices reassociation(target.getNumLoops()); - std::iota(reassociation.begin(), reassociation.end(), 0); - return collapseOpIterationDims(op, reassociation, rewriter); - }; - auto maybeFlattened = flatten(target); + ReassociationIndices reassociation(target.getNumLoops()); + std::iota(reassociation.begin(), reassociation.end(), 0); + auto maybeFlattened = + (isElementwise(target)) + ? collapseOpIterationDims(target, reassociation, rewriter) + : FailureOr(rewriter.notifyMatchFailure( + target, "only elementwise flattening is supported")); if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); results.push_back((*maybeFlattened).collapsedOp); diff --git a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir index 147759e13aa48..858c133dd536c 100644 --- a/mlir/test/Dialect/Linalg/flatten-elementwise.mlir +++ b/mlir/test/Dialect/Linalg/flatten-elementwise.mlir @@ -21,6 +21,28 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func.func @fill_tensor( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32> +// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]] +// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>) +// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]] +func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> { + %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32> + return %0 : tensor<32x7xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op + %flattened = transform.structured.flatten_elementwise %0 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: func.func @map( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32> From 6ec2d19a7d012094952aad0f4acaa259f66fa780 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 14 Feb 2024 10:43:04 -0600 Subject: [PATCH 07/11] Fix designated initializers warning --- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 013a31ee1d959..11b786261c619 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1594,7 +1594,7 @@ FailureOr mlir::linalg::collapseOpIterationDims( results.push_back(collapsedOpResult); } } - return CollapseResult{.results = results, .collapsedOp = collapsedOp}; + return CollapseResult{results, collapsedOp}; } namespace { From f6e1eca9b83fce8df27b29bd8515e62c82aaad15 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 24 Feb 2024 19:40:34 -0600 Subject: [PATCH 08/11] Fix bug --- .../TransformOps/LinalgTransformOps.cpp | 1 + .../Linalg/Transforms/ElementwiseOpFusion.cpp | 22 +++++++++---------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 1be7b261995fd..6a18d9742f2a7 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3264,6 +3264,7 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); results.push_back((*maybeFlattened).collapsedOp); + rewriter.replaceOp(target, (*maybeFlattened).results); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 11b786261c619..ef8beb979deff 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1479,17 +1479,7 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, rewriter, op, resultTypes, llvm::to_vector(llvm::concat(inputOperands, outputOperands))); - // TODO: Find a more general way to determine if op requires explicit - // indexing_maps and iterator_types - if (isa(op)) { - // Get the iterator types for the operand. - SmallVector iteratorTypes = llvm::map_to_vector( - getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), - [&](utils::IteratorType itTy) { - return cast( - IteratorTypeAttr::get(rewriter.getContext(), itTy)); - }); - + if (op->hasAttr("indexing_maps")) { // Get the indexing maps. auto indexingMaps = llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { @@ -1498,6 +1488,16 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, collapsedOp->setAttr("indexing_maps", rewriter.getAffineMapArrayAttr(indexingMaps)); + } + + if (op->hasAttr("iterator_types")) { + // Get the iterator types for the operand. + SmallVector iteratorTypes = llvm::map_to_vector( + getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), + [&](utils::IteratorType itTy) { + return cast( + IteratorTypeAttr::get(rewriter.getContext(), itTy)); + }); collapsedOp->setAttr("iterator_types", rewriter.getArrayAttr(iteratorTypes)); } From a3503068a8a7594114cbcfc820047b77051cddce Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 27 Feb 2024 16:33:31 -0600 Subject: [PATCH 09/11] apply reviewer suggestion --- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 91 ++++++++++++------- 1 file changed, 60 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index ef8beb979deff..c0f3512e400ad 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1449,20 +1449,20 @@ void generateCollapsedIndexingRegion(Location loc, Block *block, } } -LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter) { +void collapseOperandsAndResults(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter, + SmallVectorImpl &inputOperands, + SmallVectorImpl &outputOperands, + SmallVectorImpl &resultTypes) { Location loc = op->getLoc(); - - // Get the input operands. - SmallVector inputOperands = + inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo, rewriter); }); // Get the output operands and result types. - SmallVector resultTypes; - SmallVector outputOperands; resultTypes.reserve(op.getNumDpsInits()); outputOperands.reserve(op.getNumDpsInits()); for (OpOperand &output : op.getDpsInitsMutable()) { @@ -1474,35 +1474,64 @@ LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, if (!op.hasPureBufferSemantics()) resultTypes.push_back(newOutput.getType()); } +} - Operation *collapsedOp = clone( - rewriter, op, resultTypes, - llvm::to_vector(llvm::concat(inputOperands, outputOperands))); +/// Clone a `LinalgOp` to a collapsed version of same name +template +OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, + const CollapsingInfo &collapsingInfo) { + return nullptr; +} - if (op->hasAttr("indexing_maps")) { - // Get the indexing maps. - auto indexingMaps = - llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) { - return getCollapsedOpIndexingMap(map, collapsingInfo); - }); +/// Collapse any `LinalgOp` that does not require any specialization such as +/// indexing_maps, iterator_types, etc. +template <> +LinalgOp cloneToCollapsedOp(RewriterBase &rewriter, LinalgOp origOp, + const CollapsingInfo &collapsingInfo) { + SmallVector inputOperands, outputOperands; + SmallVector resultTypes; + collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands, + outputOperands, resultTypes); + return cast(clone( + rewriter, origOp, resultTypes, + llvm::to_vector(llvm::concat(inputOperands, outputOperands)))); +} - collapsedOp->setAttr("indexing_maps", - rewriter.getAffineMapArrayAttr(indexingMaps)); - } +/// Collapse a `GenericOp` +template <> +GenericOp cloneToCollapsedOp(RewriterBase &rewriter, + GenericOp origOp, + const CollapsingInfo &collapsingInfo) { + SmallVector inputOperands, outputOperands; + SmallVector resultTypes; + collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands, + outputOperands, resultTypes); + SmallVector indexingMaps( + llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) { + return getCollapsedOpIndexingMap(map, collapsingInfo); + })); - if (op->hasAttr("iterator_types")) { - // Get the iterator types for the operand. - SmallVector iteratorTypes = llvm::map_to_vector( - getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo), - [&](utils::IteratorType itTy) { - return cast( - IteratorTypeAttr::get(rewriter.getContext(), itTy)); - }); - collapsedOp->setAttr("iterator_types", - rewriter.getArrayAttr(iteratorTypes)); - } + SmallVector iteratorTypes(getCollapsedOpIteratorTypes( + origOp.getIteratorTypesArray(), collapsingInfo)); + + GenericOp collapsedOp = rewriter.create( + origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps, + iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); + Block *origOpBlock = &origOp->getRegion(0).front(); + Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); + rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, + collapsedOpBlock->getArguments()); + return collapsedOp; +} - return cast(collapsedOp); +LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter) { + if (GenericOp genericOp = dyn_cast(op.getOperation())) { + return cast( + cloneToCollapsedOp(rewriter, genericOp, collapsingInfo).getOperation()); + } else { + return cloneToCollapsedOp(rewriter, op, collapsingInfo); + } } /// Implementation of fusion with reshape operation by collapsing dimensions. From bdd04e234a7182ef885a69a7ef1bc47772887dd8 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 27 Feb 2024 16:40:44 -0600 Subject: [PATCH 10/11] remove unnecessary cast --- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index c0f3512e400ad..a5b0aa3baeecb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1527,8 +1527,7 @@ GenericOp cloneToCollapsedOp(RewriterBase &rewriter, LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, RewriterBase &rewriter) { if (GenericOp genericOp = dyn_cast(op.getOperation())) { - return cast( - cloneToCollapsedOp(rewriter, genericOp, collapsingInfo).getOperation()); + return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo); } else { return cloneToCollapsedOp(rewriter, op, collapsingInfo); } From 12444bc7f3d2b7e14cb347cbb57adb78e8bf9330 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 28 Feb 2024 01:37:54 -0600 Subject: [PATCH 11/11] Address comment and elaborate summary --- .../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 3 ++- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 4 ++-- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index d8d864d14ea69..53ed31877c6f2 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2307,7 +2307,8 @@ def FlattenElementwiseLinalgOp : Op { let description = [{ - Flattens elementwise linalg ops. + Flattens the iteration space and (applicable) operands of elementwise + linalg ops to a single dimension. Returns one handle: - Flattened linalg operation. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 6a18d9742f2a7..51f793475ea1b 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3263,8 +3263,8 @@ DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne( target, "only elementwise flattening is supported")); if (failed(maybeFlattened)) return emitDefaultSilenceableFailure(target); - results.push_back((*maybeFlattened).collapsedOp); - rewriter.replaceOp(target, (*maybeFlattened).results); + results.push_back(maybeFlattened->collapsedOp); + rewriter.replaceOp(target, maybeFlattened->results); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index a5b0aa3baeecb..59c55e504e0e3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1661,7 +1661,7 @@ class FoldWithProducerReshapeOpByCollapsing genericOp, "failed to do the fusion by collapsing transformation"); } - rewriter.replaceOp(genericOp, (*collapseResult).results); + rewriter.replaceOp(genericOp, collapseResult->results); return success(); } return failure(); @@ -1700,7 +1700,7 @@ class CollapseLinalgDimensions : public OpRewritePattern { if (!collapseResult) { return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); } - rewriter.replaceOp(op, (*collapseResult).results); + rewriter.replaceOp(op, collapseResult->results); return success(); }