Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
getNumDynamicControlOperands() + getRank());
}

/// Returns the block argument of the body that is tied to `opResult`. The
/// argument is looked up at position `getRank() + result number` in the body
/// block. `opResult` must be a result of this operation (asserted).
BlockArgument getTiedBlockArgument(OpResult opResult) {
  assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
  // The tied argument sits `getRank()` positions after the result's own index.
  const auto tiedArgIndex = getRank() + opResult.getResultNumber();
  return getBody()->getArgument(tiedArgIndex);
}

/// Returns the `idx`-th entry of `getInductionVars()`. No bounds check is
/// performed here; `idx` is forwarded directly to the subscript operator.
::mlir::Value getInductionVar(int64_t idx) {
  auto inductionVars = getInductionVars();
  return inductionVars[idx];
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
/// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
/// is to take the consumer operation, and find the slices to use for fusion
/// by walking its operands to the `loops` and then into the body to get the
/// slices used for fusion.
struct SCFFuseConsumerOfSliceResult {
// Original untiled consumer operands.
SmallVector<OpOperand *> origConsumerOperands;
Expand All @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Fuse the `consumer` operation into the loop nest provided by `loops`.
/// The transformation looks for operands in the `consumer` that are defined
/// by the outermost loop of the loop nest in `loops`. The nested loop is
/// expected to have the structure of the loops generated through tiling.
///
/// For every such operand, the insert-slice-like operation that produces the
/// corresponding loop result is located inside the loop nest and used as the
/// candidate slice for fusion.
///
/// Returns failure when:
///   - `consumer` does not implement `TilingInterface`;
///   - `loops` is empty;
///   - no producing insert-slice-like operation can be found for one of the
///     fusable operands;
///   - the underlying slice-based fusion fails.
/// When no operand of `consumer` is defined by the outermost loop, nothing is
/// fused and a result with empty tiled operand/op lists is returned.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
FailureOr<SmallVector<scf::ForOp>>
Expand Down
216 changes: 167 additions & 49 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
// Again assume that all the outer loops are scf.for operations.
auto outerForLoop = cast<scf::ForOp>(outerLoop);
auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
auto outerLoopYield =
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
SmallVector<Value> newYields =
Expand Down Expand Up @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
return clonedSlices;
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlices(
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
if (candidateSlices.empty()) {
return rewriter.notifyMatchFailure(
rewriter.getUnknownLoc(),
"no candidate slices provided for consumer fusion");
}
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"cannot call tile and fuse consumer with an empty loop nest");
}
static FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
ArrayRef<OpOperand *> consumerOpOperands,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "expected loops to be not empty");

if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
llvm::all_of(candidateSlices,
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
// 1. Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"candidates slices need to be all `tensor.extract_slice`s or "
"`tensor.parallel_insert_slice`s");
}

// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
SmallVector<OpOperand *> consumerOpOperands;
Operation *consumerOp;
{
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSlices.front(),
"could not fetch consumer to fuse");
}
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
consumerOp = consumerOpOperands.front()->getOwner();
loops.front(), "the first user of loop should not dominate any define "
"of consumer operand(s)");
}

LoopLikeOpInterface outerMostLoop = loops.front();
LoopLikeOpInterface innerMostLoop = loops.back();

// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
return rewriter.notifyMatchFailure(
outerMostLoop, "the first user of loop should not dominate any define "
"of consumer operand(s)");
}

OpBuilder::InsertionGuard g(rewriter);

// 2. Check consumer is not using scf loop's output as init.
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
if (!dstOp)
Expand Down Expand Up @@ -2428,11 +2391,166 @@ mlir::scf::tileAndFuseConsumerOfSlices(
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
return scf::SCFFuseConsumerOfSliceResult{
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
std::move(tileAndFuseResult->tiledOps)};
}

/// Fuses the consumer of the given `candidateSlices` into the loop nest
/// `loops` by computing the slice of the consumer in-place in the scf loop.
///
/// The untiled consumer is discovered from the slices via
/// `getUntiledConsumerOperandsFromSlices`; the actual rewrite is delegated to
/// `tileAndFuseConsumerOfSlicesImpl`.
///
/// Returns a match failure when:
///   - `candidateSlices` is empty;
///   - `loops` is empty (the caller is expected to handle that case);
///   - the slices are not uniformly `tensor.insert_slice` or uniformly
///     `tensor.parallel_insert_slice` operations;
///   - no consumer can be fetched from the slices.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlices(
    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  if (candidateSlices.empty()) {
    return rewriter.notifyMatchFailure(
        rewriter.getUnknownLoc(),
        "no candidate slices provided for consumer fusion");
  }
  // An empty loop nest is reported as an error for now; the caller is expected
  // to handle this case.
  if (loops.empty()) {
    return rewriter.notifyMatchFailure(
        candidateSlices.front(),
        "cannot call tile and fuse consumer with an empty loop nest");
  }

  // The slices must be homogeneous: either all `tensor.insert_slice` or all
  // `tensor.parallel_insert_slice`.
  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
        llvm::all_of(candidateSlices,
                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
    return rewriter.notifyMatchFailure(
        candidateSlices.front(),
        "candidates slices need to be all `tensor.insert_slice`s or "
        "`tensor.parallel_insert_slice`s");
  }

  // Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperands =
      getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
  if (failed(maybeConsumerOpOperands)) {
    return rewriter.notifyMatchFailure(candidateSlices.front(),
                                       "could not fetch consumer to fuse");
  }
  // All collected operands are taken to belong to the same consumer; use the
  // owner of the first one.
  Operation *consumerOp = maybeConsumerOpOperands->front()->getOwner();

  return tileAndFuseConsumerOfSlicesImpl(rewriter, consumerOp,
                                         maybeConsumerOpOperands.value(),
                                         candidateSlices, loops);
}

/// Looks up the operation that builds `result` of `forallOp`: the single
/// combining op (e.g. `tensor.parallel_insert_slice`) attached to the block
/// argument tied to `result`.
///
/// Returns `std::nullopt` when `result` is not owned by `forallOp`, or when
/// the tied block argument does not have exactly one combining op.
static std::optional<Operation *>
getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
  if (result.getOwner() != forallOp)
    return std::nullopt;

  SmallVector<Operation *> combiners =
      forallOp.getCombiningOps(forallOp.getTiedBlockArgument(result));
  // Anything other than a single combining op is unexpected; bail out.
  if (combiners.size() == 1)
    return combiners.front();
  return std::nullopt;
}

/// Given a `result` of the outermost loop of a tiled loop nest `loops`, finds
/// the insert-slice-like operation that produces it and that can be used for
/// consumer fusion.
///
/// - For a single `scf.forall`, this is the combining op tied to `result`.
/// - Otherwise the nest is assumed to consist of nested `scf.for` loops
///   created through tiling: the yielded value is followed from each loop to
///   the next inner one, and the `tensor.insert_slice` yielded by the
///   innermost loop is returned.
///
/// Returns `std::nullopt` as soon as the expected structure is not found.
static std::optional<Operation *>
getProducingInsertSliceLikeOp(OpResult result,
                              ArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "Expected loops to be not empty");
  if (auto forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation())) {
    assert(loops.size() == 1 &&
           "expected only a single loop when tiling using scf.forall");
    return getProducingParallelInsertSlice(forallOp, result);
  }

  // Returns the value yielded by `loop` for `res`, or a null value when `res`
  // is not a result of `loop` or `loop` is not an `scf.for`.
  auto yieldedValueFor = [](OpResult res, LoopLikeOpInterface loop) -> Value {
    if (res.getOwner() != loop)
      return Value();
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      return Value();
    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
    return yieldOp.getOperand(res.getResultNumber());
  };

  // Walk from the outermost loop inwards: each enclosing loop must yield a
  // result of the next inner operation.
  for (LoopLikeOpInterface enclosingLoop : loops.drop_back()) {
    Value yielded = yieldedValueFor(result, enclosingLoop);
    if (!yielded)
      return std::nullopt;
    auto innerResult = dyn_cast<OpResult>(yielded);
    if (!innerResult)
      return std::nullopt;
    result = innerResult;
  }

  // The innermost loop must yield a `tensor.insert_slice`.
  Value innermostYielded = yieldedValueFor(result, loops.back());
  if (!innermostYielded)
    return std::nullopt;
  auto insertSliceOp = innermostYielded.getDefiningOp<tensor::InsertSliceOp>();
  if (!insertSliceOp)
    return std::nullopt;
  return insertSliceOp;
}

/// Fuses `consumer` into the loop nest `loops`. Operands of `consumer` that
/// are defined by the outermost loop are collected, the insert-slice-like op
/// producing each of them is located, and the fusion itself is delegated to
/// `tileAndFuseConsumerOfSlicesImpl`.
///
/// Fails when `consumer` does not implement `TilingInterface`, when `loops`
/// is empty, or when a producing insert-slice-like op cannot be found. When
/// no operand comes from the outermost loop, a result without any tiled
/// operands or ops is returned.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
                               MutableArrayRef<LoopLikeOpInterface> loops) {
  if (!isa<TilingInterface>(consumer)) {
    return rewriter.notifyMatchFailure(
        consumer, "unhandled consumer that does not implement TilingInterface");
  }

  // An empty loop nest is reported as an error for now; the caller is expected
  // to handle this case.
  if (loops.empty()) {
    return rewriter.notifyMatchFailure(
        consumer, "cannot call tile and fuse consumer with an empty loop nest");
  }

  // Gather every consumer operand whose value is produced by the outermost
  // loop of the nest.
  LoopLikeOpInterface rootLoop = loops.front();
  SmallVector<OpOperand *> fusableOperands;
  for (OpOperand &operand : consumer->getOpOperands()) {
    if (operand.get().getDefiningOp() != rootLoop)
      continue;
    fusableOperands.push_back(&operand);
  }

  // Nothing to fuse. Just return an empty set.
  if (fusableOperands.empty()) {
    return mlir::scf::SCFFuseConsumerOfSliceResult{fusableOperands,
                                                   SmallVector<OpOperand *>{},
                                                   SmallVector<Operation *>{}};
  }

  // Find the tensor.insert_slice/tensor.parallel_insert_slice that produces
  // each fusable operand.
  SmallVector<Operation *> slices;
  slices.reserve(fusableOperands.size());
  for (OpOperand *operand : fusableOperands) {
    auto loopResult = cast<OpResult>(operand->get());
    std::optional<Operation *> producer =
        getProducingInsertSliceLikeOp(loopResult, loops);
    if (!producer.has_value()) {
      return rewriter.notifyMatchFailure(
          consumer,
          "couldnt find producing insert-slice like operation for operand");
    }
    slices.push_back(*producer);
  }

  return tileAndFuseConsumerOfSlicesImpl(rewriter, consumer, fusableOperands,
                                         slices, loops);
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
transform.test.fuse_consumer %slice_op in (%forall_op)
transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
Expand Down Expand Up @@ -231,7 +231,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
// Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
// Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
// is not a qualified consumer operation. Forcing this will yield "could not fetch consumer
// to fuse" error.
transform.yield
Expand Down
Loading