[mlir] Extend SCF loopUnrollByFactor to return the result loops#114573
Merged
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir
Author: Hongtao Yu (htyu)
Changes
There is a need to access the resulting epilogue loop from the SCF loop unroller. It would be clean and convenient to get it directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027. I'm changing the result type of `loopUnrollByFactor` for that.
Full diff: https://github.com/llvm/llvm-project/pull/114573.diff
3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 4001ba3fc84c9d..eda64ea69f81d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -111,11 +111,13 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
-/// Unrolls this for operation by the specified unroll factor. Returns failure
-/// if the loop cannot be unrolled either due to restrictions or due to invalid
-/// unroll factors. Requires positive loop bounds and step. If specified,
-/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
-LogicalResult loopUnrollByFactor(
+/// Unrolls this for operation by the specified unroll factor. Returns the
+/// unrolled main loop and the eplilog loop in sequence, if the loop is
+/// unrolled. Otherwise returns an empty vector if the loop cannot be unrolled
+/// either due to restrictions or due to invalid unroll factors. Requires
+/// positive loop bounds and step. If specified, annotates the Ops in each
+/// unrolled iteration by applying `annotateFn`.
+SmallVector<scf::ForOp> loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index 551411bb147653..c84cb13f8b6bb2 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -353,8 +353,10 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LogicalResult result(failure());
- if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
- result = loopUnrollByFactor(scfFor, getFactor());
+ if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
+ auto resultLoops = loopUnrollByFactor(scfFor, getFactor());
+ result = resultLoops.empty() ? failure() : success();
+ }
else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
result = loopUnrollByFactor(affineFor, getFactor());
else
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 43fcc595af0f7e..8394ac47888100 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -372,15 +372,17 @@ static void generateUnrolledLoop(
loopBodyBlock->getTerminator()->setOperands(lastYielded);
}
-/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
-LogicalResult mlir::loopUnrollByFactor(
+/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
+/// eplilog loop in sequence, if the loop is unrolled. Otherwise return an empty
+/// vector.
+SmallVector<scf::ForOp> mlir::loopUnrollByFactor(
scf::ForOp forOp, uint64_t unrollFactor,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
assert(unrollFactor > 0 && "expected positive unroll factor");
// Return if the loop body is empty.
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
- return success();
+ return {forOp};
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -401,8 +403,8 @@ LogicalResult mlir::loopUnrollByFactor(
if (unrollFactor == 1) {
if (*constTripCount == 1 &&
failed(forOp.promoteIfSingleIteration(rewriter)))
- return failure();
- return success();
+ return {};
+ return {forOp};
}
int64_t tripCountEvenMultiple =
@@ -450,6 +452,9 @@ LogicalResult mlir::loopUnrollByFactor(
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
}
+ SmallVector<scf::ForOp, 2> resultLoops;
+ resultLoops.push_back(forOp);
+
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
if (generateEpilogueLoop) {
OpBuilder epilogueBuilder(forOp->getContext());
@@ -468,6 +473,7 @@ LogicalResult mlir::loopUnrollByFactor(
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
epilogueForOp.getInitArgs().size(), results);
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
+ resultLoops.push_back(epilogueForOp);
}
// Create unrolled loop.
@@ -490,7 +496,7 @@ LogicalResult mlir::loopUnrollByFactor(
annotateFn, iterArgs, yieldedValues);
// Promote the loop body up if this has turned into a single iteration loop.
(void)forOp.promoteIfSingleIteration(rewriter);
- return success();
+ return resultLoops;
}
/// Check if bounds of all inner loops are defined outside of `forOp`
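The bounds arithmetic that the diff comments refer to (`tripCount`, `tripCountEvenMultiple`, `upperBoundUnrolled`, `stepUnrolled`) can be illustrated for the constant-bounds case with a small standalone model. This is only a sketch under stated assumptions: `ToyLoop` and `planUnroll` are hypothetical names that do not exist in MLIR, the formulas are inferred from the variable names and comments in the diff rather than copied from `Utils.cpp`, and the vector it returns merely imitates the "main loop, then epilogue loop" ordering introduced by this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for a loop with constant bounds; this is NOT mlir::scf::ForOp.
struct ToyLoop {
  int64_t lowerBound;
  int64_t upperBound;
  int64_t step;
};

// Ceiling division for a non-negative numerator and a positive denominator.
inline int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical helper: computes the bounds of the unrolled main loop and, if
// iterations are left over, of the epilogue loop. Returns an empty vector for
// bounds this toy model does not handle, mirroring the PR's convention that
// an empty result means "not unrolled".
inline std::vector<ToyLoop> planUnroll(ToyLoop loop, int64_t unrollFactor) {
  assert(unrollFactor > 0 && "expected positive unroll factor");
  if (loop.step <= 0 || loop.upperBound < loop.lowerBound)
    return {};

  // tripCount = ceilDiv(upperBound - lowerBound, step), as in the diff comment.
  int64_t tripCount = ceilDiv(loop.upperBound - loop.lowerBound, loop.step);
  // Largest multiple of unrollFactor that does not exceed tripCount (assumed).
  int64_t tripCountEvenMultiple = tripCount - tripCount % unrollFactor;
  int64_t upperBoundUnrolled =
      loop.lowerBound + tripCountEvenMultiple * loop.step;
  int64_t stepUnrolled = loop.step * unrollFactor;

  std::vector<ToyLoop> resultLoops;
  // Main loop: same lower bound, shortened upper bound, widened step.
  resultLoops.push_back({loop.lowerBound, upperBoundUnrolled, stepUnrolled});
  // Epilogue loop: starts at upperBoundUnrolled and keeps the original step.
  if (upperBoundUnrolled < loop.upperBound)
    resultLoops.push_back({upperBoundUnrolled, loop.upperBound, loop.step});
  return resultLoops;
}
```

For example, under these assumptions `planUnroll({0, 10, 1}, 4)` yields a main loop over `[0, 8)` with step 4 and an epilogue over `[8, 10)` with step 1, while `planUnroll({0, 8, 1}, 4)` yields only the main loop because the trip count is already a multiple of the factor.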
✅ With the latest revision this PR passed the C/C++ code formatter. |
ThomasRaoux reviewed Nov 1, 2024
ThomasRaoux reviewed Nov 3, 2024
ThomasRaoux approved these changes Nov 3, 2024
ThomasRaoux (Contributor) left a comment:
LGTM, although I don't think you have addressed Mahesh's comment. Please make sure to address it before merging.
htyu added a commit to triton-lang/triton that referenced this pull request Nov 5, 2024
#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
PhilippRados pushed a commit to PhilippRados/llvm-project that referenced this pull request Nov 6, 2024
…#114573) There is a need to access the resulting epilogue loop from the SCF loop unroller. It would be clean and convenient to get it directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027. I'm changing the result type of `loopUnrollByFactor` for that.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request Nov 14, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work.
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 4, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 5, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 6, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 11, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 12, 2024
triton-lang#5064) Bumping llvm to include a loop unroller fix: llvm/llvm-project#114573. This is needed for subsequent loop unroller upstreaming work. (cherry picked from commit 3c296ab)
There is a need to access the resulting epilogue loop from the SCF loop unroller. It would be clean and convenient to get it directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027. I'm changing the result type of `loopUnrollByFactor` for that.
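To make the new return convention concrete, here is a rough caller-side sketch. It is not MLIR code: `LoopHandle`, `Status`, `toStatus`, and `epilogueOf` are hypothetical stand-ins invented for illustration, and a plain `std::vector` stands in for `SmallVector<scf::ForOp>`. The only facts taken from the diff are that an empty vector signals that the loop was not unrolled, that the main loop comes first, and that the epilogue loop, when one is generated, comes second.

```cpp
#include <vector>

// Hypothetical stand-in for scf::ForOp; only an id so loops can be told apart.
struct LoopHandle {
  int id;
};

// Hypothetical stand-in for mlir::LogicalResult.
enum class Status { Failure, Success };

// Mirrors `resultLoops.empty() ? failure() : success()` from the
// SCFTransformOps.cpp hunk: an empty vector means the unroll did not happen.
inline Status toStatus(const std::vector<LoopHandle> &resultLoops) {
  return resultLoops.empty() ? Status::Failure : Status::Success;
}

// Returns the epilogue loop if one was produced, or nullptr when the result
// holds only the main loop (or nothing at all). The caller no longer needs to
// rescan the enclosing function to find it.
inline const LoopHandle *epilogueOf(const std::vector<LoopHandle> &resultLoops) {
  return resultLoops.size() > 1 ? &resultLoops[1] : nullptr;
}
```

A caller that previously checked `succeeded(...)` would, under this sketch, check for a non-empty result instead and then pick up the second element when it is present.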