Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,14 @@ scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
scf::ForallOp forallOp);

/// Check if the provided loops are perfectly nested for-loops. Perfect nesting
/// means:
/// 1. All loops are scf.for operations
/// 2. Each outer loop's region iter args match the inner loop's init args
/// 3. Each outer loop's yields match the inner loop's results
/// 4. Each region iter arg and result has exactly one use
bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);

} // namespace mlir

#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
59 changes: 1 addition & 58 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1916,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
return failure();
}

/// Check that the loop is perfectly nested.
/// The loops are expected to be ordered from outer most to inner most.
/// For example:
/// ```
/// %0 = scf.for()
/// %1 = scf.for()
/// %2 = scf.for()
/// %3 = ...
/// yield %3
/// yield %2
/// yield %1
/// ```
/// Here loops should be [%0, %1].
static bool
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loop nest");
if (loops.size() == 1) {
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
}
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
if (!outerFor || !innerFor) {
return false;
}
auto outerBBArgs = outerFor.getRegionIterArgs();
auto innerIterArgs = innerFor.getInitArgs();
if (outerBBArgs.size() != innerIterArgs.size()) {
return false;
}

for (auto [outerBBArg, innerIterArg] :
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
innerIterArg != outerBBArg) {
return false;
}
}

ValueRange outerYields =
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
ValueRange innerResults = innerFor.getResults();
if (outerYields.size() != innerResults.size()) {
return false;
}
for (auto [outerYield, innerResult] :
llvm::zip_equal(outerYields, innerResults)) {
if (!llvm::hasSingleElement(innerResult.getUses()) ||
outerYield != innerResult) {
return false;
}
}
}
return true;
}

/// Fetch the untiled consumer of the outermost scf.for's result which is
/// yielded by a tensor.insert_slice from the innermost scf.for. This function
/// makes the following assumptions :
Expand All @@ -1993,7 +1936,7 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter,
}

// 2. Check that the loop is perfectly nested.
if (!isPerfectlyNestedForLoops(loops)) {
if (!mlir::isPerfectlyNestedForLoops(loops)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "expected passed loops to be perfectly nested.");
}
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Dialect/SCF/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1512,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
rewriter.replaceOp(forallOp, normalizedForallOp);
return normalizedForallOp;
}

bool mlir::isPerfectlyNestedForLoops(
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "unexpected empty loop nest");
if (loops.size() == 1)
return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation());
auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation());
if (!outerFor || !innerFor)
return false;
auto outerBBArgs = outerFor.getRegionIterArgs();
auto innerIterArgs = innerFor.getInitArgs();
if (outerBBArgs.size() != innerIterArgs.size())
return false;

for (auto [outerBBArg, innerIterArg] :
llvm::zip_equal(outerBBArgs, innerIterArgs)) {
if (!llvm::hasSingleElement(outerBBArg.getUses()) ||
innerIterArg != outerBBArg)
return false;
}

ValueRange outerYields =
cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands();
ValueRange innerResults = innerFor.getResults();
if (outerYields.size() != innerResults.size())
return false;
for (auto [outerYield, innerResult] :
llvm::zip_equal(outerYields, innerResults)) {
if (!llvm::hasSingleElement(innerResult.getUses()) ||
outerYield != innerResult)
return false;
}
}
return true;
}
Loading