From b2cc13219e4de83f7bf58311a5c5b1ce68b7f733 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 19 Mar 2025 11:52:44 -0700 Subject: [PATCH 1/2] Pick up https://github.com/llvm/llvm-project/pull/132082 Signed-off-by: MaheshRavishankar --- third_party/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/llvm-project b/third_party/llvm-project index 857a04cd7670..ad8238edb800 160000 --- a/third_party/llvm-project +++ b/third_party/llvm-project @@ -1 +1 @@ -Subproject commit 857a04cd7670b629b560ba7e67c758a0c15e0841 +Subproject commit ad8238edb8008a45e8438929aef8a8d9774784a9 From 08201f554fb322030d8ef9cc91ed0541fba48f56 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Wed, 19 Mar 2025 11:53:10 -0700 Subject: [PATCH 2/2] Fixes for https://github.com/llvm/llvm-project/pull/132082 Signed-off-by: MaheshRavishankar --- .../GPU/GPUFuseAndHoistParallelLoops.cpp | 4 +- .../Common/TileDispatchUsingForall.cpp | 64 +++++++++---------- .../TransformExtensions/CommonExtensions.cpp | 18 ++++-- .../CommonExtensionsOps.td | 7 +- ...LLVMCPUTileRootAndFuseProducerConsumer.cpp | 23 ++++--- 5 files changed, 63 insertions(+), 53 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp index 769866f49807..b02ec1e24c2c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp @@ -282,7 +282,7 @@ struct FuseTilableForallConsumers final } tensor::ParallelInsertSliceOp producerSlice; - scf::ForallOp sliceOwner; + LoopLikeOpInterface sliceOwner; Value fusionOperand; for (auto operand : dpsOp.getDpsInputs()) { auto forallProducer = operand.getDefiningOp(); @@ -320,7 +320,7 @@ struct FuseTilableForallConsumers final } FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice); + scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice, {sliceOwner}); if (failed(fuseConsumerResults)) { return failure(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp index 159759db519f..e68725e47d01 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp @@ -237,8 +237,8 @@ static bool areAllStaticLoopBounds(scf::ForallOp forallOp) { /// Find dimensions of the loop that are unit-trip count and drop them from the /// distributed dimensions. -static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, - scf::ForallOp forallOp) { +static FailureOr +dropUnitDistributedDims(RewriterBase &rewriter, scf::ForallOp forallOp) { SmallVector mixedLbs = forallOp.getMixedLowerBound(); SmallVector mixedUbs = forallOp.getMixedUpperBound(); SmallVector mixedSteps = forallOp.getMixedStep(); @@ -261,7 +261,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, } } if (droppedLoops.empty()) { - return success(); + return forallOp; } OpBuilder::InsertionGuard g(rewriter); @@ -303,7 +303,7 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, rewriter.mergeBlocks(oldLoopBody, newLoopBody, argReplacements); rewriter.replaceOp(forallOp, newForallOp.getResults()); - return success(); + return newForallOp; } //===---------------------------------------------------------------------===// @@ -314,8 +314,9 @@ static LogicalResult dropUnitDistributedDims(RewriterBase &rewriter, // Returns a list of new `tensor.extract_slice` ops with new fusion // opportunities, as well as the new surrounding `scf.forall` (because consumer // fusion replaces the loop). -static std::pair, scf::ForallOp> -fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { +static std::queue +fuseConsumers(RewriterBase &rewriter, Operation *tiledOp, + MutableArrayRef loops) { auto addCandidateSlices = [](Operation *fusedOp, std::queue &candidates) { @@ -333,7 +334,6 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { addCandidateSlices(tiledOp, candidates); std::queue newFusionOpportunities; - scf::ForallOp newLoop = tiledOp->getParentOfType(); while (!candidates.empty()) { // Traverse the slices in BFS fashion. @@ -341,7 +341,8 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { candidates.pop(); FailureOr fusedResult = - mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp); + mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp, + loops); if (failed(fusedResult)) { LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: " << candidateSliceOp << "\n"); @@ -369,19 +370,15 @@ fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { } } } - // Store the new loop for follow up producer fusion. - newLoop = tiledOp->getParentOfType(); } } - return std::make_pair(newFusionOpportunities, newLoop); + return newFusionOpportunities; } static void fuseProducersOfSlices(RewriterBase &rewriter, std::queue &worklist, scf::SCFTileAndFuseOptions &options, - scf::ForallOp forallOp) { - SmallVector loops = { - cast(&*forallOp)}; + MutableArrayRef loops) { while (!worklist.empty()) { auto candidateSlice = cast(worklist.front()); worklist.pop(); @@ -532,7 +529,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() { // If the `tilableOp` is a `memref` op, then just tile the operation. SmallVector tilingLoops; - Operation *rootTiledOp = nullptr; if (tilableOp->getNumResults() == 0) { FailureOr tilingResult = scf::tileUsingSCF(rewriter, tilableOp, tilingOptions); @@ -554,7 +550,16 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() { rewriter.replaceAllUsesWith(origValue, replacement); } std::swap(tileAndFuseResult->loops, tilingLoops); - rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front(); + Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front(); + auto newFusionOpportunities = + fuseConsumers(rewriter, rootTiledOp, tilingLoops); + + // Because we restrict to at most a single tilable consumer for yielding + // a replacement, no new fusion opportunities will yield a replacement, + // meaning there is no need to run consumer fusion again afterwards. + // TODO: run producer and consumer fusion in one worklist. + fuseProducersOfSlices(rewriter, newFusionOpportunities, tileAndFuseOptions, + tilingLoops); } if (!tilingLoops.empty()) { if (tilingLoops.size() != 1 || !isa(tilingLoops[0])) { @@ -563,35 +568,24 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() { return signalPassFailure(); } - auto forallOp = cast(tilingLoops[0]); - if (failed(dropUnitDistributedDims(rewriter, forallOp))) { - forallOp.emitOpError("failed to drop unit dimensions"); + auto forallOp = + dropUnitDistributedDims(rewriter, cast(tilingLoops[0])); + if (failed(forallOp)) { + tilingLoops[0]->emitOpError("failed to drop unit dimensions"); return signalPassFailure(); } - if (rootTiledOp) { - auto [newFusionOpportunities, newLoop] = - fuseConsumers(rewriter, rootTiledOp); - - // Because we restrict to at most a single tilable consumer for yielding - // a replacement, no new fusion opportunities will yield a replacement, - // meaning there is no need to run consumer fusion again afterwards. - // TODO: run producer and consumer fusion in one worklist. - fuseProducersOfSlices(rewriter, newFusionOpportunities, - tileAndFuseOptions, newLoop); - forallOp = newLoop; - } - // Reorder the workgroups if the strategy is set to `transpose`. // This just transposes the first two dimensions of the workgroup i.e., the // #iree.codegen.workgroup_id_x and #iree.codegen.workgroup_id_y. // Only reorders if the loop bounds are static. if (transposeWorkgroup) { - SmallVector mappingAttrs(forallOp.getMappingAttr().getValue()); + SmallVector mappingAttrs( + forallOp->getMappingAttr().getValue()); int64_t mappingSize = mappingAttrs.size(); - if (areAllStaticLoopBounds(forallOp) && mappingSize >= 2) { + if (areAllStaticLoopBounds(*forallOp) && mappingSize >= 2) { std::swap(mappingAttrs[mappingSize - 1], mappingAttrs[mappingSize - 2]); - forallOp.setMappingAttr(ArrayAttr::get(context, mappingAttrs)); + forallOp->setMappingAttr(ArrayAttr::get(context, mappingAttrs)); } } } diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index 86c0a107d847..961c23fb2698 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -1193,6 +1193,7 @@ template static LogicalResult applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + MutableArrayRef loops, transform::TransformResults &transformResults) { SmallVector originalConsumerOps; SmallVector fusedConsumerOps; @@ -1201,7 +1202,7 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp, rewriter.setInsertionPoint(target); FailureOr fuseConsumerResults = - scf::tileAndFuseConsumerOfSlice(rewriter, target); + scf::tileAndFuseConsumerOfSlice(rewriter, target, loops); if (failed(fuseConsumerResults)) return failure(); @@ -1222,9 +1223,18 @@ DiagnosedSilenceableFailure transform_dialect::FuseConsumerOp::apply( transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { - LogicalResult result = - applyFuseConsumer(rewriter, getOperation(), - state.getPayloadOps(getTarget()), transformResults); + SmallVector loops; + for (auto op : getLoops()) { + auto loopOp = + dyn_cast(*state.getPayloadOps(op).begin()); + if (!loopOp) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + loops.push_back(loopOp); + } + LogicalResult result = applyFuseConsumer(rewriter, getOperation(), + state.getPayloadOps(getTarget()), + loops, transformResults); return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 0c05178043c8..daf2e463d63e 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td @@ -694,13 +694,14 @@ def FuseConsumerOp : Op:$loops); let results = (outs TransformHandleTypeInterface:$consumer, TransformHandleTypeInterface:$fused_consumer); let assemblyFormat = [{ - $target attr-dict `:` functional-type(operands, results) + $target `in` `(` $loops `)` attr-dict `:` functional-type(operands, results) }]; } diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp index fd554c74321d..b30d8b7df856 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUTileRootAndFuseProducerConsumer.cpp @@ -56,7 +56,7 @@ static void collectTiledAndFusedOps(Operation *rootOp, /// Tile the root operation and fuse the producers of the root operation. /// If `onlyFuseProducerInputOperands` is set, only fuse producer input /// operands. Returns the tiled operation to be used for fusing consumers. -FailureOr +static FailureOr tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp, int64_t tilingLevel, bool onlyFuseProducerInputOperands) { @@ -136,10 +136,11 @@ tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp, } } - return tiledResults->tiledAndFusedOps.front(); + return tiledResults; } -static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { +static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp, + MutableArrayRef loops) { // Typically, the consumers of the tiled operation are slices of the // results of the tiled operation. These are expressed in IR using @@ -169,7 +170,8 @@ static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) { candidates.pop(); FailureOr fusedResult = - mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp); + mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp, + loops); if (failed(fusedResult)) { LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: " << candidateSliceOp << "\n"); @@ -196,14 +198,17 @@ static LogicalResult tileRootAndFuse(IRRewriter &rewriter, int64_t tilingLevel, bool onlyFuseProducerInputOperands) { - FailureOr tiledOp = tileRootAndFuseProducers( - rewriter, rootOp, tilingLevel, onlyFuseProducerInputOperands); + FailureOr tileAndFuseResult = + tileRootAndFuseProducers(rewriter, rootOp, tilingLevel, + onlyFuseProducerInputOperands); - if (failed(tiledOp)) + if (failed(tileAndFuseResult)) return failure(); - if (!onlyFuseProducerInputOperands) - fuseConsumers(rewriter, tiledOp.value()); + if (!onlyFuseProducerInputOperands) { + fuseConsumers(rewriter, tileAndFuseResult->tiledAndFusedOps.front(), + tileAndFuseResult->loops); + } return success(); }