Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ struct FuseTilableForallConsumers final
}

tensor::ParallelInsertSliceOp producerSlice;
scf::ForallOp sliceOwner;
LoopLikeOpInterface sliceOwner;
Value fusionOperand;
for (auto operand : dpsOp.getDpsInputs()) {
auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
Expand Down Expand Up @@ -320,7 +320,7 @@ struct FuseTilableForallConsumers final
}

FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice);
scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice, {sliceOwner});
if (failed(fuseConsumerResults)) {
return failure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ static bool areAllStaticLoopBounds(scf::ForallOp forallOp) {

/// Find dimensions of the loop that are unit-trip count and drop them from the
/// distributed dimensions.
static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
scf::ForallOp forallOp) {
static FailureOr<scf::ForallOp>
dropUnitDistributedDims(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<OpFoldResult> mixedLbs = forallOp.getMixedLowerBound();
SmallVector<OpFoldResult> mixedUbs = forallOp.getMixedUpperBound();
SmallVector<OpFoldResult> mixedSteps = forallOp.getMixedStep();
Expand All @@ -261,7 +261,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
}
}
if (droppedLoops.empty()) {
return success();
return forallOp;
}

OpBuilder::InsertionGuard g(rewriter);
Expand Down Expand Up @@ -303,7 +303,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
rewriter.mergeBlocks(oldLoopBody, newLoopBody, argReplacements);

rewriter.replaceOp(forallOp, newForallOp.getResults());
return success();
return newForallOp;
}

//===---------------------------------------------------------------------===//
Expand All @@ -314,8 +314,9 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter,
// Returns a list of new `tensor.extract_slice` ops with new fusion
// opportunities. The surrounding loops are supplied through `loops` and handed
// to consumer fusion (which replaces the loop) rather than being returned.
static std::pair<std::queue<Operation *>, scf::ForallOp>
fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
static std::queue<Operation *>
fuseConsumers(RewriterBase &rewriter, Operation *tiledOp,
MutableArrayRef<LoopLikeOpInterface> loops) {
auto addCandidateSlices =
[](Operation *fusedOp,
std::queue<tensor::ParallelInsertSliceOp> &candidates) {
Expand All @@ -333,15 +334,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
addCandidateSlices(tiledOp, candidates);

std::queue<Operation *> newFusionOpportunities;
scf::ForallOp newLoop = tiledOp->getParentOfType<scf::ForallOp>();
while (!candidates.empty()) {

// Traverse the slices in BFS fashion.
tensor::ParallelInsertSliceOp candidateSliceOp = candidates.front();
candidates.pop();

FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp,
loops);
if (failed(fusedResult)) {
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
<< candidateSliceOp << "\n");
Expand Down Expand Up @@ -369,19 +370,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
}
}
}
// Store the new loop for follow up producer fusion.
newLoop = tiledOp->getParentOfType<scf::ForallOp>();
}
}
return std::make_pair(newFusionOpportunities, newLoop);
return newFusionOpportunities;
}

static void fuseProducersOfSlices(RewriterBase &rewriter,
std::queue<Operation *> &worklist,
scf::SCFTileAndFuseOptions &options,
scf::ForallOp forallOp) {
SmallVector<LoopLikeOpInterface> loops = {
cast<LoopLikeOpInterface>(&*forallOp)};
MutableArrayRef<LoopLikeOpInterface> loops) {
while (!worklist.empty()) {
auto candidateSlice = cast<tensor::ExtractSliceOp>(worklist.front());
worklist.pop();
Expand Down Expand Up @@ -532,7 +529,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {

// If the `tilableOp` is a `memref` op, then just tile the operation.
SmallVector<LoopLikeOpInterface> tilingLoops;
Operation *rootTiledOp = nullptr;
if (tilableOp->getNumResults() == 0) {
FailureOr<scf::SCFTilingResult> tilingResult =
scf::tileUsingSCF(rewriter, tilableOp, tilingOptions);
Expand All @@ -554,7 +550,16 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
rewriter.replaceAllUsesWith(origValue, replacement);
}
std::swap(tileAndFuseResult->loops, tilingLoops);
rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
auto newFusionOpportunities =
fuseConsumers(rewriter, rootTiledOp, tilingLoops);

// Because we restrict to at most a single tilable consumer for yielding
// a replacement, no new fusion opportunities will yield a replacement,
// meaning there is no need to run consumer fusion again afterwards.
// TODO: run producer and consumer fusion in one worklist.
fuseProducersOfSlices(rewriter, newFusionOpportunities, tileAndFuseOptions,
tilingLoops);
}
if (!tilingLoops.empty()) {
if (tilingLoops.size() != 1 || !isa<scf::ForallOp>(tilingLoops[0])) {
Expand All @@ -563,35 +568,24 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
return signalPassFailure();
}

auto forallOp = cast<scf::ForallOp>(tilingLoops[0]);
if (failed(dropUnitDistributedDims(rewriter, forallOp))) {
forallOp.emitOpError("failed to drop unit dimensions");
auto forallOp =
dropUnitDistributedDims(rewriter, cast<scf::ForallOp>(tilingLoops[0]));
if (failed(forallOp)) {
tilingLoops[0]->emitOpError("failed to drop unit dimensions");
return signalPassFailure();
}

if (rootTiledOp) {
auto [newFusionOpportunities, newLoop] =
fuseConsumers(rewriter, rootTiledOp);

// Because we restrict to at most a single tilable consumer for yielding
// a replacement, no new fusion opportunities will yield a replacement,
// meaning there is no need to run consumer fusion again afterwards.
// TODO: run producer and consumer fusion in one worklist.
fuseProducersOfSlices(rewriter, newFusionOpportunities,
tileAndFuseOptions, newLoop);
forallOp = newLoop;
}

// Reorder the workgroups if the strategy is set to `transpose`.
// This just transposes the first two dimensions of the workgroup i.e., the
// #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y.
// Only reorders if the loop bounds are static.
if (transposeWorkgroup) {
SmallVector<Attribute> mappingAttrs(forallOp.getMappingAttr().getValue());
SmallVector<Attribute> mappingAttrs(
forallOp->getMappingAttr().getValue());
int64_t mappingSize = mappingAttrs.size();
if (areAllStaticLoopBounds(forallOp) && mappingSize >= 2) {
if (areAllStaticLoopBounds(*forallOp) && mappingSize >= 2) {
std::swap(mappingAttrs[mappingSize - 1], mappingAttrs[mappingSize - 2]);
forallOp.setMappingAttr(ArrayAttr::get(context, mappingAttrs));
forallOp->setMappingAttr(ArrayAttr::get(context, mappingAttrs));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ template <typename Range>
static LogicalResult
applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
Range &&payloadOps,
MutableArrayRef<LoopLikeOpInterface> loops,
transform::TransformResults &transformResults) {
SmallVector<Operation *> originalConsumerOps;
SmallVector<Operation *> fusedConsumerOps;
Expand All @@ -1201,7 +1202,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
rewriter.setInsertionPoint(target);

FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumerOfSlice(rewriter, target);
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);

if (failed(fuseConsumerResults))
return failure();
Expand All @@ -1222,9 +1223,18 @@ DiagnosedSilenceableFailure transform_dialect::FuseConsumerOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &transformResults,
transform::TransformState &state) {
LogicalResult result =
applyFuseConsumer(rewriter, getOperation(),
state.getPayloadOps(getTarget()), transformResults);
SmallVector<LoopLikeOpInterface> loops;
for (auto op : getLoops()) {
auto loopOp =
dyn_cast<LoopLikeOpInterface>(*state.getPayloadOps(op).begin());
if (!loopOp) {
return DiagnosedSilenceableFailure::definiteFailure();
}
loops.push_back(loopOp);
}
LogicalResult result = applyFuseConsumer(rewriter, getOperation(),
state.getPayloadOps(getTarget()),
loops, transformResults);
return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
: DiagnosedSilenceableFailure::success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -694,13 +694,14 @@ def FuseConsumerOp : Op<Transform_Dialect, "iree.fuse_consumer",
}];
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";

let arguments =
(ins TransformHandleTypeInterface:$target);
let arguments =(ins
TransformHandleTypeInterface:$target,
Variadic<TransformHandleTypeInterface>:$loops);
let results = (outs TransformHandleTypeInterface:$consumer,
TransformHandleTypeInterface:$fused_consumer);

let assemblyFormat = [{
$target attr-dict `:` functional-type(operands, results)
$target `in` `(` $loops `)` attr-dict `:` functional-type(operands, results)
}];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ static void collectTiledAndFusedOps(Operation *rootOp,
/// Tile the root operation and fuse the producers of the root operation.
/// If `onlyFuseProducerInputOperands` is set, only fuse producer input
/// operands. Returns the tile-and-fuse result, whose tiled ops and loops are
/// used for fusing consumers.
FailureOr<Operation *>
static FailureOr<scf::SCFTileAndFuseResult>
tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {
Expand Down Expand Up @@ -136,10 +136,11 @@ tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
}
}

return tiledResults->tiledAndFusedOps.front();
return tiledResults;
}

static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp,
MutableArrayRef<LoopLikeOpInterface> loops) {

// Typically, the consumers of the tiled operation are slices of the
// results of the tiled operation. These are expressed in IR using
Expand Down Expand Up @@ -169,7 +170,8 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
candidates.pop();

FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp,
loops);
if (failed(fusedResult)) {
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
<< candidateSliceOp << "\n");
Expand All @@ -196,14 +198,17 @@ static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
int64_t tilingLevel,
bool onlyFuseProducerInputOperands) {

FailureOr<Operation *> tiledOp = tileRootAndFuseProducers(
rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands);
FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
tileRootAndFuseProducers(rewriter, rootOp, tilingLevel,
onlyFuseProducerInputOperands);

if (failed(tiledOp))
if (failed(tileAndFuseResult))
return failure();

if (!onlyFuseProducerInputOperands)
fuseConsumers(rewriter, tiledOp.value());
if (!onlyFuseProducerInputOperands) {
fuseConsumers(rewriter, tileAndFuseResult->tiledAndFusedOps.front(),
tileAndFuseResult->loops);
}

return success();
}
Expand Down
Loading