Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,29 +101,39 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
return true;
}

void getBackwardSliceIncludingUsesFromAbove(
    Operation *op, SetVector<Operation *> *backwardSlice,
    const BackwardSliceOptions &options) {
  // Copy the caller's options so we can substitute a filter that, in addition
  // to the caller's filtering, pulls in values captured from above by any
  // region of a visited op.
  BackwardSliceOptions wrappedOptions(options);
  wrappedOptions.filter = [&](Operation *op) {
    // Respect the caller-provided filter when one is set. A null filter (the
    // `BackwardSliceOptions` default) means "accept everything"; calling a
    // null std::function would throw, so guard it the same way
    // `getBackwardSlice` itself does.
    if (options.filter && !options.filter(op)) {
      return false;
    }
    // For every value used inside this op's regions but defined above them,
    // add the defining ops (and their transitive defs) to the slice.
    // `inclusive = true` ensures the defining op of the captured value itself
    // is included. Note `regionOptions` copies `wrappedOptions`, so nested
    // region captures recurse through this same wrapped filter.
    BackwardSliceOptions regionOptions(wrappedOptions);
    regionOptions.inclusive = true;
    mlir::visitUsedValuesDefinedAbove(
        op->getRegions(), [&](OpOperand *operand) {
          getBackwardSlice(operand->get(), backwardSlice, regionOptions);
        });
    return true;
  };

  getBackwardSlice(op, backwardSlice, wrappedOptions);
}

bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo,
Operation *seedOp) {
assert(dominanceInfo.properlyDominates(seedOp, op) &&
op->getParentRegion() == seedOp->getParentRegion());

BackwardSliceOptions options;
// Limit the slice to the seed to make sure the slice is small.
options.filter = [&](Operation *op) {
return !dominanceInfo.properlyDominates(op, seedOp);
};
llvm::SetVector<Operation *> slice;
getBackwardSlice(op, &slice, options);

// `getBackwardSlice` doesn't track uses from within an op's region, so make
// sure there are no values defined above.
for (Operation *sliceOp : slice) {
bool usesValuesFromAbove = false;
mlir::visitUsedValuesDefinedAbove(
sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; });
if (usesValuesFromAbove) {
return false;
}
}
getBackwardSliceIncludingUsesFromAbove(op, &slice, options);

return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
return slice.contains(groupedOp);
Expand Down
15 changes: 9 additions & 6 deletions compiler/src/iree/compiler/DispatchCreation/FusionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
const DominanceInfo &dominanceInfo, Operation *seedOp);

/// Wraps `filter` so that operations used in a region of an op also get
/// included in the backward slice. This breaks the topological sorting of the
/// original `getBackwardSlice`, but that ordering isn't necessary for the uses
/// here.
/// TODO: Upstream this as a part of `getBackwardSlice`.
void getBackwardSliceIncludingUsesFromAbove(
Operation *op, SetVector<Operation *> *backwardSlice,
const BackwardSliceOptions &options);

/// Moves the operands and transitive defs for each op in `operations` directly
/// after `insertionPoint`. Note: this does not check if it is legal to move the
/// operands.
Expand All @@ -44,16 +52,11 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
!ignoreOperationsSet.contains(op);
};
// Set inclusive to true because the slice is computed from the operand, and
// we want to include the defining op (which is the point here).
options.inclusive = true;

llvm::SetVector<Operation *> slice;
for (auto op : operations) {
assert(insertionPoint->getBlock() == op->getBlock());
for (auto operand : op->getOperands()) {
getBackwardSlice(operand, &slice, options);
}
getBackwardSliceIncludingUsesFromAbove(op, &slice, options);
}

mlir::topologicalSort(slice);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,70 @@ util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<
%8 = arith.addf %arg2, %cst : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
// expected-note @below {{prior use here}}
%collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%8 = arith.subf %arg2, %cst_0 : f32
%cst_2 = arith.constant 2.000000e+00 : f32
%8 = arith.subf %arg2, %cst_2 : f32
linalg.yield %8 : f32
} -> tensor<5x5xf32>
util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
}
// CHECK-LABEL: util.func public @fuse_by_moving_consumer
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
// The second generic's region reads %collapsed via tensor.extract — a value
// defined above the region and derived (through tensor.collapse_shape) from
// the first generic's result %0. Fusing the two generics would require moving
// the consumer before %collapsed, which its region depends on, so fusion must
// not happen here (checked by the two separate linalg.generic ops below).
util.func public @dont_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
  %cst = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 2.000000e+00 : f32
  %cst_1 = arith.constant 3.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %cst : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  // %collapsed depends on %0, the producer generic's result.
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %c2 = arith.constant 2 : index
    // Use-from-above: %collapsed is captured inside this region.
    %extracted = tensor.extract %collapsed[%c2] : tensor<25xf32>
    %2 = arith.addf %extracted, %extracted : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
}

// CHECK-LABEL: util.func public @dont_fuse_use_from_above
// CHECK: linalg.generic
// CHECK: linalg.generic


// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
// Positive counterpart to @dont_fuse_use_from_above: here the value captured
// from above inside the second generic's region is %arg0, a function argument
// with no dependence on the producer generic or on %collapsed, so fusing the
// two generics is legal (checked by CHECK-NOT below: only one generic remains).
util.func public @do_fuse_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
  %cst = arith.constant 1.000000e+00 : f32
  %cst_0 = arith.constant 2.000000e+00 : f32
  %cst_1 = arith.constant 3.000000e+00 : f32
  %0 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %2 = arith.addf %in, %cst : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  %collapsed = tensor.collapse_shape %0 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
  %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
  ^bb0(%in: f32, %out: f32):
    %c2 = arith.constant 2 : index
    // Use-from-above of a block-argument only — safe to fuse across.
    %extracted = tensor.extract %arg0[%c2, %c2] : tensor<5x5xf32>
    %2 = arith.addf %extracted, %extracted : f32
    linalg.yield %2 : f32
  } -> tensor<5x5xf32>
  util.return %1, %collapsed : tensor<5x5xf32>, tensor<25xf32>
}

// CHECK-LABEL: util.func public @do_fuse_use_from_above
// CHECK: linalg.generic
// CHECK-NOT: linalg.generic