Merged
Changes from 9 commits
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2295,6 +2295,48 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp
//===----------------------------------------------------------------------===//

def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
"structured.flatten_elementwise",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Flattens elementwise linalg ops by collapsing all loop dimensions (and,
where applicable, the operands) into a single dimension.

Returns one handle:
- Flattened linalg operation.

#### Return modes:

Returns a definite failure if the target is not isolated from above.
Returns a silenceable failure if the pattern application failed.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
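
For reference, a sketch of the rewrite this transform performs, based on the
`@map` test added below (the collapse_shape reshapes and the flattened size
224 = 32 * 7 come from that test; this is illustrative, not part of the patch):

// Before flattening: a 2-D elementwise op.
linalg.map {arith.addf} ins(%arg0, %arg1 : memref<32x7xf32>, memref<32x7xf32>)
    outs(%arg2 : memref<32x7xf32>)

// After flattening: operands are collapsed to 1-D and the op is rewritten
// over the flat iteration space.
%0 = memref.collapse_shape %arg0 [[0, 1]] : memref<32x7xf32> into memref<224xf32>
%1 = memref.collapse_shape %arg1 [[0, 1]] : memref<32x7xf32> into memref<224xf32>
%2 = memref.collapse_shape %arg2 [[0, 1]] : memref<32x7xf32> into memref<224xf32>
linalg.map {arith.addf} ins(%0, %1 : memref<224xf32>, memref<224xf32>)
    outs(%2 : memref<224xf32>)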

//===----------------------------------------------------------------------===//
// Transpose Conv2D
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1074,16 +1074,20 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);

+struct CollapseResult {
+  SmallVector<Value> results;
+  LinalgOp collapsedOp;
+};

/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
/// to calling this method is that for each list in `foldedIterationDim`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>>
-collapseOpIterationDims(LinalgType op,
+FailureOr<CollapseResult>
+collapseOpIterationDims(LinalgOp op,
                         ArrayRef<ReassociationIndices> foldedIterationDims,
                         RewriterBase &rewriter);
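
A minimal caller sketch for the new result type (assumed context: `op` is a
`LinalgOp` and `rewriter` a `RewriterBase`; the folded dimensions are
illustrative):

// Check the documented precondition, then collapse dims {0, 1} into one.
SmallVector<ReassociationIndices> foldedDims = {{0, 1}};
if (areDimSequencesPreserved(op.getIndexingMapsArray(), foldedDims)) {
  FailureOr<CollapseResult> collapsed =
      collapseOpIterationDims(op, foldedDims, rewriter);
  if (succeeded(collapsed))
    rewriter.replaceOp(op, collapsed->results);
}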

25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3243,6 +3243,31 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp.
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);
  // Nothing to flatten for rank <= 1; forward the original op so the result
  // handle still maps to one payload op per target.
  if (target.getNumLoops() <= 1) {
    results.push_back(target);
    return DiagnosedSilenceableFailure::success();
  }
  // Collapse all loop dimensions into a single one.
  ReassociationIndices reassociation(target.getNumLoops());
  std::iota(reassociation.begin(), reassociation.end(), 0);
  auto maybeFlattened =
      (isElementwise(target))
          ? collapseOpIterationDims(target, reassociation, rewriter)
          : FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
                target, "only elementwise flattening is supported"));
  if (failed(maybeFlattened))
    return emitDefaultSilenceableFailure(target);
  results.push_back((*maybeFlattened).collapsedOp);
  rewriter.replaceOp(target, (*maybeFlattened).results);
  return DiagnosedSilenceableFailure::success();
}
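
As context for the failure path above: a non-elementwise target falls into the
`FailureOr` branch and ends in `emitDefaultSilenceableFailure`, with the
"only elementwise flattening is supported" note recorded as a match-failure
diagnostic. An illustrative payload that would trigger it (hypothetical, not
one of the tests in this patch):

func.func @matmul(%a: memref<4x8xf32>, %b: memref<8x16xf32>, %c: memref<4x16xf32>) {
  linalg.matmul ins(%a, %b : memref<4x8xf32>, memref<8x16xf32>)
                outs(%c : memref<4x16xf32>)
  return
}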

//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//
86 changes: 39 additions & 47 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1449,12 +1449,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
}

-template <typename LinalgType>
-Operation *createCollapsedOp(LinalgType op,
-                             const CollapsingInfo &collapsingInfo,
-                             RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to create");
+LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
+                           RewriterBase &rewriter) {
Location loc = op->getLoc();

// Get the input operands.
@@ -1479,40 +1475,40 @@ Operation *createCollapsedOp(LinalgType op,
resultTypes.push_back(newOutput.getType());
}

-  if (isa<linalg::CopyOp>(op)) {
-    return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
-                                           outputOperands[0]);
-  }
+  Operation *collapsedOp = clone(
Contributor: Why is the operation being cloned?

Contributor: Follow up.. I see why you are cloning it.... This is an interesting way of doing it. Does it work?

Contributor Author: It does.

+      rewriter, op, resultTypes,
+      llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));

-  // Get the iterator types for the operand.
-  SmallVector<utils::IteratorType> iteratorTypes =
-      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
+  if (op->hasAttr("indexing_maps")) {
Contributor: This seems iffy to me. This is looking into internal implementation details of the LinalgOp representation in C++ and is hard to maintain. Let me give you an example: I have no idea what the indexing maps or iterator types are stored as. I just use the utility methods on ops to get this information. Side-stepping that can introduce silent bugs if something changes in the implementation of the operation.

I understand you are trying to generalize these things. One way to do that would be to define a method

template <typename OpTy>
Operation *cloneWithIndexingMapsIteratorTypesAndOperands(
    RewriterBase &rewriter, OpTy origOp, TypeRange resultTypes,
    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorType,
    ValueRange inputOperands, ValueRange outputOperands) {
  return nullptr;
}

template <>
Operation *cloneWithIndexingMapsIteratorTypesAndOperands<LinalgOp>(
    RewriterBase &rewriter, LinalgOp origOp, TypeRange resultTypes,
    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorType,
    ValueRange inputOperands, ValueRange outputOperands) {
  return clone(rewriter, origOp, resultTypes, inputOperands, outputOperands);
}

template <>
Operation *cloneWithIndexingMapsIteratorTypesAndOperands<GenericOp>(
    RewriterBase &rewriter, GenericOp origOp, TypeRange resultTypes,
    ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorType,
    ValueRange inputOperands, ValueRange outputOperands) {
  SmallVector<utils::IteratorType> iteratorTypes =
      getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);

  // Get the indexing maps.
  auto indexingMaps =
      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
        return getCollapsedOpIndexingMap(map, collapsingInfo);
      });

  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
  Block *origOpBlock = &op->getRegion(0).front();
  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
                       collapsedOpBlock->getArguments());
}

Then you can

if (auto genericOp = dyn_cast<GenericOp>(op)) {
  cloneWithIndexingMapsIteratorTypesAndOperands(rewriter, genericOp, ....)
} else {
  cloneWithIndexingMapsIteratorTypesAndOperands(rewriter, cast<LinalgOp>(op), ...)
}

I think that gives you what you want and doesn't leak internal implementation details of the op.

Contributor Author: yah i totally agree this recent change is garbage. shouldn't have even pushed it :).

i'm going to go with your suggestion. however, just so i understand correctly, this is functionally equivalent to what i had before (checking if generic), but you are suggesting this for cleanliness and maintainability?

Contributor: Yup. We don't need to look at whether an attribute is present or not.. So a bit cleaner?

Contributor Author (srcarroll, Feb 27, 2024): agreed. do you think these cloneWithIndexingMapsIteratorTypesAndOperands should be static functions specific to this file? or as part of some lib?

Contributor Author: i pushed a modified version of your suggestion. First, i templated the input and output op to enforce that the cloning should be same named op. Second, i didn't want to have both collapsed operands and collapsingInfo as arguments to a function because that leaves room for inconsistency. And the collapsingInfo is needed for the generic's indexing map. So I did more refactoring with that in mind.
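
(For concreteness, a hedged sketch of the templated direction described in the
last two comments; names and signatures here are hypothetical, not necessarily
what was pushed:)

// Hypothetical: clone as the same named op, specializing only where
// attributes must be recomputed from the collapsing info.
template <typename OpTy>
static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
                               const CollapsingInfo &collapsingInfo,
                               TypeRange resultTypes, ValueRange operands);

// Named ops: a plain clone over the collapsed operands suffices, since their
// indexing maps and iterator types are implied by the op definition.
template <>
LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp,
                                      const CollapsingInfo &collapsingInfo,
                                      TypeRange resultTypes,
                                      ValueRange operands) {
  return cast<LinalgOp>(clone(rewriter, origOp, resultTypes, operands));
}

// linalg.generic would get its own specialization that rebuilds the collapsed
// indexing maps and iterator types via getCollapsedOpIndexingMap /
// getCollapsedOpIteratorTypes and merges the original region into the new op.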

-  // Get the indexing maps.
-  auto indexingMaps =
-      llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
-        return getCollapsedOpIndexingMap(map, collapsingInfo);
-      });
+    // Get the indexing maps.
+    auto indexingMaps =
+        llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
+          return getCollapsedOpIndexingMap(map, collapsingInfo);
+        });
+    collapsedOp->setAttr("indexing_maps",
+                         rewriter.getAffineMapArrayAttr(indexingMaps));
+  }

-  Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
-      loc, resultTypes, inputOperands, outputOperands, indexingMaps,
-      iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
-  Block *origOpBlock = &op->getRegion(0).front();
-  Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
-  rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
-                       collapsedOpBlock->getArguments());
+  if (op->hasAttr("iterator_types")) {
+    // Get the iterator types for the operand.
+    SmallVector<Attribute> iteratorTypes = llvm::map_to_vector(
+        getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo),
+        [&](utils::IteratorType itTy) {
+          return cast<Attribute>(
+              IteratorTypeAttr::get(rewriter.getContext(), itTy));
+        });
+    collapsedOp->setAttr("iterator_types",
+                         rewriter.getArrayAttr(iteratorTypes));
+  }

-  return collapsedOp;
+  return cast<LinalgOp>(collapsedOp);
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
-template <typename LinalgType>
-FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
-    LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
+FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
+    LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
-  static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
-                "unsupported linalg op type to collapse");

// Bail on trivial no-op cases.
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
@@ -1541,8 +1537,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
}

// Bail on non-canonical ranges.
-  SmallVector<Range> loopRanges =
-      cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
+  SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
@@ -1558,8 +1553,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
op, "expected all loop ranges to have zero start and unit stride");
}

-  LinalgType collapsedOp = cast<LinalgType>(
-      createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
+  LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

Location loc = op->getLoc();
if (collapsedOp.hasIndexSemantics()) {
@@ -1600,7 +1594,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
results.push_back(collapsedOpResult);
}
}
-  return results;
+  return CollapseResult{results, collapsedOp};
}

namespace {
@@ -1632,15 +1626,14 @@ class FoldWithProducerReshapeOpByCollapsing
continue;
}

-      std::optional<SmallVector<Value>> replacements =
-          collapseOpIterationDims<linalg::GenericOp>(
-              genericOp, collapsableIterationDims, rewriter);
-      if (!replacements) {
+      std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
+          genericOp, collapsableIterationDims, rewriter);
+      if (!collapseResult) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
}

-      rewriter.replaceOp(genericOp, *replacements);
+      rewriter.replaceOp(genericOp, (*collapseResult).results);
return success();
}
return failure();
@@ -1674,13 +1667,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
op, "specified dimensions cannot be collapsed");
}

-    std::optional<SmallVector<Value>> replacements =
-        collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
-                                            rewriter);
-    if (!replacements) {
+    std::optional<CollapseResult> collapseResult =
+        collapseOpIterationDims(op, collapsableIterationDims, rewriter);
+    if (!collapseResult) {
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
-    rewriter.replaceOp(op, *replacements);
+    rewriter.replaceOp(op, (*collapseResult).results);
return success();
}

99 changes: 99 additions & 0 deletions mlir/test/Dialect/Linalg/flatten-elementwise.mlir
@@ -0,0 +1,99 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @fill(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
  linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %flattened = transform.structured.flatten_elementwise %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @fill_tensor(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
  %0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
  return %0 : tensor<32x7xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %flattened = transform.structured.flatten_elementwise %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK-LABEL: func.func @map(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
  linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %flattened = transform.structured.flatten_elementwise %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
// CHECK-NEXT: linalg.yield %[[SUM]]
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
  linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %0 = arith.addf %a, %b : f32
    linalg.yield %0 : f32
  }
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
    %flattened = transform.structured.flatten_elementwise %0
      : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}