Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2295,6 +2295,48 @@ def ConvertConv2DToImg2ColOp : Op<Transform_Dialect,
}];
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp
//===----------------------------------------------------------------------===//

def FlattenElementwiseLinalgOp : Op<Transform_Dialect,
"structured.flatten_elementwise",
[FunctionalStyleTransformOpTrait,
MemoryEffectsOpInterface,
TransformOpInterface,
TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Flattens elementwise linalg ops.

Returns one handle:
- Flattened linalg operation.

#### Return modes:

Returns a definite failure if target is not isolated from above.
Returns a silenceable failure if the pattern application failed.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

//===----------------------------------------------------------------------===//
// Transpose Conv2D
//===----------------------------------------------------------------------===//
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1074,16 +1074,20 @@ bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence);
bool areDimSequencesPreserved(ArrayRef<AffineMap> maps,
ArrayRef<ReassociationIndices> dimSequences);

struct CollapseResult {
SmallVector<Value> results;
LinalgOp collapsedOp;
};

/// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition
/// to calling this method is that for each list in `foldedIterationDim`, the
/// sequence of dimensions is contiguous in domains of all `indexing_maps` of
/// the `linalgOp`. This can be checked using `areDimSequencePreserved` method.
/// When valid, the method also collapses the operands of the op. Returns
/// replacement values of the results of the original `linalgOp` by inserting
/// reshapes to get back values of compatible types.
template <typename LinalgType>
FailureOr<SmallVector<Value>>
collapseOpIterationDims(LinalgType op,
FailureOr<CollapseResult>
collapseOpIterationDims(LinalgOp op,
ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter);

Expand Down
24 changes: 24 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3243,6 +3243,30 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// FlattenElementwiseLinalgOp.
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::FlattenElementwiseLinalgOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
if (target.getNumLoops() <= 1)
return DiagnosedSilenceableFailure::success();
ReassociationIndices reassociation(target.getNumLoops());
std::iota(reassociation.begin(), reassociation.end(), 0);
auto maybeFlattened =
(isElementwise(target))
? collapseOpIterationDims(target, reassociation, rewriter)
: FailureOr<CollapseResult>(rewriter.notifyMatchFailure(
target, "only elementwise flattening is supported"));
if (failed(maybeFlattened))
return emitDefaultSilenceableFailure(target);
results.push_back((*maybeFlattened).collapsedOp);
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// TransposeConv2DOp
//===----------------------------------------------------------------------===//
Expand Down
88 changes: 40 additions & 48 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1449,12 +1449,8 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
}
}

template <typename LinalgType>
Operation *createCollapsedOp(LinalgType op,
const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
"unsupported linalg op type to create");
LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
RewriterBase &rewriter) {
Location loc = op->getLoc();

// Get the input operands.
Expand All @@ -1479,40 +1475,40 @@ Operation *createCollapsedOp(LinalgType op,
resultTypes.push_back(newOutput.getType());
}

if (isa<linalg::CopyOp>(op)) {
return rewriter.create<linalg::CopyOp>(loc, inputOperands[0],
outputOperands[0]);
}

// Get the iterator types for the operand.
SmallVector<utils::IteratorType> iteratorTypes =
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo);
Operation *collapsedOp = clone(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the operation being cloned?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow up.. I see why you are cloning it.... This is an interesting way of doing it. Does it work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does.

rewriter, op, resultTypes,
llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands)));

// TODO: Find a more general way to determine if op requires explicit
// indexing_maps and iterator_types
if (isa<linalg::GenericOp>(op)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think you need to do this. All LinalgOps have getIteratorTypesArray (or should have). Also if you dont clone the op you dont need to set the indexing maps etc. explicitly....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah i know all LinalgOps have these attrs and the corresponding getters. But most of the structured ops don't have an explicit attr for indexing maps and iterator types. so if i unconditionally try to set them, i will end up with named ops with additional attrs that aren't at all associated with the implicit ones. I'm sure there's a way to do this correctly. I just dont know.

I'm not sure how any of this would work without cloning the op (other than having a switch statement checking every single linalg op that exists). Of course the alternative is to just always convert to a linalg.generic op. But I dont like that. I much prefer keeping the named op and that's why I chose to clone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm. there is this kMemoizedIndexingMapsAttrName that's for named ops. so i could check if the op has this and then replace it. otherwise replace the expicit indexing maps. but i dont see a similar attr name for iterator types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it woudl be nice if there was a general setIndexingMaps and setIteratoryTypes for arbitrary LinalgOps

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative is that you could have a interface method to LinalgOp that allows you to clone with a given indexing_maps and iterator_types and just uses the region from the original operation. Each specific op could just implement its own version (including linalg.generic).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a bad idea. Id have to see what it looks like to understand the implications fully but I sense that it might be too demanding to require all ops to implement that. Would this work for the ops generated from core_named_ops.py (can't remember the exact file name right now)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To start with you can just have the method return a LogicalResult and the default implementation return failure()and only add the implementation for ops that you want to support. Those can be filled in over time.

On second thought maybe instead of going full interface method, just add a templated method in this file for cloning? The default can be for handling named ops, and the generic op can get its specialization

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback. I wont be able to get back to this until this weekend or next, but I'll take a look at your suggestion and keep you updated.

// Get the iterator types for the operand.
SmallVector<Attribute> iteratorTypes = llvm::map_to_vector(
getCollapsedOpIteratorTypes(op.getIteratorTypesArray(), collapsingInfo),
[&](utils::IteratorType itTy) {
return cast<Attribute>(
IteratorTypeAttr::get(rewriter.getContext(), itTy));
});

// Get the indexing maps.
auto indexingMaps =
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
});
// Get the indexing maps.
auto indexingMaps =
llvm::map_to_vector(op.getIndexingMapsArray(), [&](AffineMap map) {
return getCollapsedOpIndexingMap(map, collapsingInfo);
});

Operation *collapsedOp = rewriter.create<linalg::GenericOp>(
loc, resultTypes, inputOperands, outputOperands, indexingMaps,
iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {});
Block *origOpBlock = &op->getRegion(0).front();
Block *collapsedOpBlock = &collapsedOp->getRegion(0).front();
rewriter.mergeBlocks(origOpBlock, collapsedOpBlock,
collapsedOpBlock->getArguments());
collapsedOp->setAttr("indexing_maps",
rewriter.getAffineMapArrayAttr(indexingMaps));
collapsedOp->setAttr("iterator_types",
rewriter.getArrayAttr(iteratorTypes));
}

return collapsedOp;
return cast<LinalgOp>(collapsedOp);
}

/// Implementation of fusion with reshape operation by collapsing dimensions.
template <typename LinalgType>
FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
LinalgType op, ArrayRef<ReassociationIndices> foldedIterationDims,
FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims,
RewriterBase &rewriter) {
static_assert(llvm::is_one_of<LinalgType, GenericOp, CopyOp>::value,
"unsupported linalg op type to collapse");

// Bail on trivial no-op cases.
if (op.getNumLoops() <= 1 || foldedIterationDims.empty() ||
llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
Expand Down Expand Up @@ -1541,8 +1537,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
}

// Bail on non-canonical ranges.
SmallVector<Range> loopRanges =
cast<LinalgOp>(op.getOperation()).createLoopRanges(rewriter, op.getLoc());
SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc());
auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) {
if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr))
return cast<IntegerAttr>(attr).getInt() == value;
Expand All @@ -1558,8 +1553,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
op, "expected all loop ranges to have zero start and unit stride");
}

LinalgType collapsedOp = cast<LinalgType>(
createCollapsedOp<LinalgType>(op, collapsingInfo, rewriter));
LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter);

Location loc = op->getLoc();
if (collapsedOp.hasIndexSemantics()) {
Expand Down Expand Up @@ -1600,7 +1594,7 @@ FailureOr<SmallVector<Value>> mlir::linalg::collapseOpIterationDims(
results.push_back(collapsedOpResult);
}
}
return results;
return CollapseResult{results, collapsedOp};
}

namespace {
Expand Down Expand Up @@ -1632,15 +1626,14 @@ class FoldWithProducerReshapeOpByCollapsing
continue;
}

std::optional<SmallVector<Value>> replacements =
collapseOpIterationDims<linalg::GenericOp>(
genericOp, collapsableIterationDims, rewriter);
if (!replacements) {
std::optional<CollapseResult> collapseResult = collapseOpIterationDims(
genericOp, collapsableIterationDims, rewriter);
if (!collapseResult) {
return rewriter.notifyMatchFailure(
genericOp, "failed to do the fusion by collapsing transformation");
}

rewriter.replaceOp(genericOp, *replacements);
rewriter.replaceOp(genericOp, (*collapseResult).results);
return success();
}
return failure();
Expand Down Expand Up @@ -1674,13 +1667,12 @@ class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
op, "specified dimensions cannot be collapsed");
}

std::optional<SmallVector<Value>> replacements =
collapseOpIterationDims<LinalgType>(op, collapsableIterationDims,
rewriter);
if (!replacements) {
std::optional<CollapseResult> collapseResult =
collapseOpIterationDims(op, collapsableIterationDims, rewriter);
if (!collapseResult) {
return rewriter.notifyMatchFailure(op, "failed to collapse dimensions");
}
rewriter.replaceOp(op, *replacements);
rewriter.replaceOp(op, (*collapseResult).results);
return success();
}

Expand Down
99 changes: 99 additions & 0 deletions mlir/test/Dialect/Linalg/flatten-elementwise.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s

// CHECK-LABEL: func.func @fill(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : memref<224xf32>)
func.func @fill(%cst: f32, %arg: memref<32x7xf32>) {
linalg.fill ins(%cst: f32) outs(%arg: memref<32x7xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @fill_tensor(
// CHECK-SAME: %[[ARG0:.*]]: f32,
// CHECK-SAME: %[[ARG1:.*]]: tensor<32x7xf32>
// CHECK-NEXT: %[[FLATTENED:.*]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_RESULT:.*]] = linalg.fill ins(%[[ARG0]] : f32) outs(%[[FLATTENED]] : tensor<224xf32>)
// CHECK-NEXT: %[[RESULT:.*]] = tensor.expand_shape %[[FLATTENED_RESULT]] {{\[}}[0, 1]]
func.func @fill_tensor(%cst: f32, %arg: tensor<32x7xf32>) -> tensor<32x7xf32> {
%0 = linalg.fill ins(%cst: f32) outs(%arg: tensor<32x7xf32>) -> tensor<32x7xf32>
return %0 : tensor<32x7xf32>
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @map(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.map { arith.addf } ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
func.func @map(%arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
linalg.map {arith.addf} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>)
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @generic
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<32x7xf32>
// CHECK-NEXT: %[[FLATTENED_0:.*]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_1:.*]] = memref.collapse_shape %[[ARG1]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[FLATTENED_2:.*]] = memref.collapse_shape %[[ARG2]] {{\[}}[0, 1]]
// CHECK-NEXT: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[FLATTENED_0]], %[[FLATTENED_1]] : memref<224xf32>, memref<224xf32>) outs(%[[FLATTENED_2]] : memref<224xf32>)
// CHECK-NEXT: ^bb0(%[[A:.*]]: f32, %[[B:.*]]: f32, %[[C:.*]]: f32)
// CHECK-NEXT: %[[SUM:.*]] = arith.addf %[[A]], %[[B]]
// CHECK-NEXT: linalg.yield %[[SUM]]
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic( %arg0: memref<32x7xf32>, %arg1: memref<32x7xf32>, %arg2: memref<32x7xf32>) {
linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1: memref<32x7xf32>, memref<32x7xf32>) outs(%arg2: memref<32x7xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
%0 = arith.addf %a, %b : f32
linalg.yield %0 : f32
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%flattened = transform.structured.flatten_elementwise %0
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}