diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h index 5c0d5643c0198..fdf2570626980 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h @@ -79,6 +79,12 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns, /// loop bounds and loop steps are canonicalized. void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns); +/// Populate patterns to uplift `scf.while` ops to `scf.for`. +/// Uplifting expects a specific ops pattern: +/// * `before` block consisting of single arith.cmpi op +/// * `after` block containing arith.addi +void populateUpliftWhileToForPatterns(RewritePatternSet &patterns); + } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h index 690cd146c606e..220dcb35571d2 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h @@ -222,6 +222,12 @@ FailureOr wrapWhileLoopInZeroTripCheck(WhileOp whileOp, RewriterBase &rewriter, bool forceCreateCheck = false); +/// Try to uplift `scf.while` op to `scf.for`. 
+/// Uplifting expects a specific ops pattern:
+///  * `before` block consisting of single arith.cmpi op
+///  * `after` block containing arith.addi
+FailureOr<ForOp> upliftWhileToForLoop(RewriterBase &rewriter, WhileOp loop);
+
 } // namespace scf
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e5494205e086a..a2925aef17ca7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
   StructuralTypeConversions.cpp
   TileUsingInterface.cpp
   WrapInZeroTripCheck.cpp
+  UpliftWhileToFor.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 0000000000000..fea2f659535bb
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,243 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrite pattern that forwards every `scf.while` to `upliftWhileToForLoop`
+/// and reports its success or failure back to the pattern driver.
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    return upliftWhileToForLoop(rewriter, loop);
+  }
+};
+} // namespace
+
+/// Try to replace `loop` with an `scf.for`.
+///
+/// The rewrite only applies when all of the following hold:
+///  * the `before` block holds a single `arith.cmpi` (`slt` with the induction
+///    variable on the left, or `sgt` with it on the right) whose only use is
+///    the `scf.condition`, and forwards all of its block args unchanged;
+///  * the other `cmpi` operand (the upper bound) is defined outside the loop;
+///  * the `after` block yields, for the induction variable, an `arith.addi`
+///    of that variable and a value (the step) defined outside the loop.
+///
+/// On success the `scf.while` is replaced and the new `scf.for` is returned.
+/// Otherwise a match failure is reported before any IR has been created.
+FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
+                                                      scf::WhileOp loop) {
+  Block *beforeBody = loop.getBeforeBody();
+  if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+    return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+  auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+  if (!cmp)
+    return rewriter.notifyMatchFailure(loop,
+                                       "Loop body must have single cmp op");
+
+  scf::ConditionOp beforeTerm = loop.getConditionOp();
+  if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Expected single condition use: " << *cmp;
+    });
+
+  // All `before` block args must be directly forwarded to ConditionOp.
+  // They will be converted to `scf.for` `iter_args` except induction var.
+  if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+    return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+  using Pred = arith::CmpIPredicate;
+  Pred predicate = cmp.getPredicate();
+  if (predicate != Pred::slt && predicate != Pred::sgt)
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+    });
+
+  BlockArgument inductionVar;
+  Value ub;
+  DominanceInfo dom;
+
+  // Check if cmp has a suitable form. One of the arguments must be a `before`
+  // block arg, other must be defined outside `scf.while` and will be treated
+  // as upper bound.
+  for (bool reverse : {false, true}) {
+    auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+    if (cmp.getPredicate() != expectedPred)
+      continue;
+
+    auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+    auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+    auto blockArg = dyn_cast<BlockArgument>(arg1);
+    if (!blockArg || blockArg.getOwner() != beforeBody)
+      continue;
+
+    if (!dom.properlyDominates(arg2, loop))
+      continue;
+
+    inductionVar = blockArg;
+    ub = arg2;
+    break;
+  }
+
+  if (!inductionVar)
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Unrecognized cmp form: " << *cmp;
+    });
+
+  // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
+  // arg.
+  if (!llvm::hasNItems(inductionVar.getUses(), 2))
+    return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+      diag << "Unrecognized induction var: " << inductionVar;
+    });
+
+  Block *afterBody = loop.getAfterBody();
+  scf::YieldOp afterTerm = loop.getYieldOp();
+  auto argNumber = inductionVar.getArgNumber();
+  auto afterTermIndArg = afterTerm.getResults()[argNumber];
+
+  auto inductionVarAfter = afterBody->getArgument(argNumber);
+
+  Value step;
+
+  // Find suitable `addi` op inside `after` block, one of the args must be an
+  // induction var passed from `before` block and second arg must be defined
+  // outside of the loop and will be considered step value.
+  // TODO: Add `subi` support?
+  for (auto &use : inductionVarAfter.getUses()) {
+    auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
+    if (!owner)
+      continue;
+
+    auto other =
+        (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
+    if (!dom.properlyDominates(other, loop))
+      continue;
+
+    if (afterTermIndArg != owner.getResult())
+      continue;
+
+    step = other;
+    break;
+  }
+
+  if (!step)
+    return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+
+  auto lb = loop.getInits()[argNumber];
+
+  assert(lb.getType().isIntOrIndex());
+  assert(lb.getType() == ub.getType());
+  assert(lb.getType() == step.getType());
+
+  llvm::SmallVector<Value> newArgs;
+
+  // Populate inits for new `scf.for`, skip induction var.
+  newArgs.reserve(loop.getInits().size());
+  for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+    if (i == argNumber)
+      continue;
+
+    newArgs.emplace_back(init);
+  }
+
+  Location loc = loop.getLoc();
+
+  // With `builder == nullptr`, ForOp::build will try to insert terminator at
+  // the end of newly created block and we don't want it. Provide empty
+  // dummy builder instead.
+  auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+  auto newLoop =
+      rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
+
+  Block *newBody = newLoop.getBody();
+
+  // Populate block args for `scf.for` body, move induction var to the front.
+  newArgs.clear();
+  ValueRange newBodyArgs = newBody->getArguments();
+  for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
+    if (i < argNumber) {
+      newArgs.emplace_back(newBodyArgs[i + 1]);
+    } else if (i == argNumber) {
+      newArgs.emplace_back(newBodyArgs.front());
+    } else {
+      newArgs.emplace_back(newBodyArgs[i]);
+    }
+  }
+
+  rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+                             newArgs);
+
+  auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+  // Populate new yield args, skipping the induction var.
+  newArgs.clear();
+  for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+    if (i == argNumber)
+      continue;
+
+    newArgs.emplace_back(arg);
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(term);
+  rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
+
+  // Compute induction var value after loop execution. `scf.while` returns the
+  // value forwarded by `scf.condition` when the comparison first fails, i.e.
+  // `lb + tripCount * step` with `tripCount = ceildiv(ub - lb, step)`, or
+  // plain `lb` when the loop body never runs.
+  // NOTE(review): this assumes `step` is positive, which is not checked here.
+  rewriter.setInsertionPointAfter(newLoop);
+  Value one;
+  if (isa<IndexType>(step.getType())) {
+    one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+  } else {
+    one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+  }
+
+  Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+  Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+  len = rewriter.create<arith::DivSIOp>(loc, len, step);
+  Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+  res = rewriter.create<arith::AddIOp>(loc, lb, res);
+  // Clamp for the zero-trip case (`lb >= ub`), where the product above is
+  // not positive and the original loop returns `lb` unchanged.
+  res = rewriter.create<arith::MaxSIOp>(loc, res, lb);
+
+  // Reconstruct `scf.while` results, inserting final induction var value
+  // into proper place.
+  newArgs.clear();
+  llvm::append_range(newArgs, newLoop.getResults());
+  newArgs.insert(newArgs.begin() + argNumber, res);
+  rewriter.replaceOp(loop, newArgs);
+  return newLoop;
+}
+
+/// Add the `scf.while` to `scf.for` uplifting pattern to `patterns`.
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+  patterns.add<UpliftWhileOp>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/SCF/uplift-while.mlir b/mlir/test/Dialect/SCF/uplift-while.mlir
new file mode 100644
index 0000000000000..25ea6142a332d
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.muli %[[R4]], %[[STEP]] : index
+//       CHECK:     %[[R6:.*]] = arith.addi %[[BEGIN]], %[[R5]] : index
+//       CHECK:     %[[R7:.*]] = arith.maxsi %[[R6]], %[[BEGIN]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi sgt, %arg1, %arg3 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg3, %arg2 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.muli %[[R4]], %[[STEP]] : index
+//       CHECK:     %[[R6:.*]] = arith.addi %[[BEGIN]], %[[R5]] : index
+//       CHECK:     %[[R7:.*]] = arith.maxsi %[[R6]], %[[BEGIN]] : index
+//       CHECK:     return %[[R7]] : index
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %0 = scf.while (%arg3 = %arg0) : (index) -> (index) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg3 : index
+  } do {
+  ^bb0(%arg3: index):
+    "test.test1"(%arg3) : (index) -> ()
+    %added = arith.addi %arg2, %arg3 : index
+    "test.test2"(%added) : (index) -> ()
+    scf.yield %added : index
+  }
+  return %0 : index
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> index
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : index
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] {
+//       CHECK:     "test.test1"(%[[I]]) : (index) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[STEP]], %[[I]] : index
+//       CHECK:     "test.test2"(%[[INC]]) : (index) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : index
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : index
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : index
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : index
+//       CHECK:     %[[R5:.*]] = arith.muli %[[R4]], %[[STEP]] : index
+//       CHECK:     %[[R6:.*]] = arith.addi %[[BEGIN]], %[[R5]] : index
+//       CHECK:     %[[R7:.*]] = arith.maxsi %[[R6]], %[[BEGIN]] : index
+//       CHECK:     return %[[R7]] : index
+
+
+// -----
+
+func.func @uplift_while(%arg0: index, %arg1: index, %arg2: index) -> (i32, f32) {
+  %c1 = arith.constant 1 : i32
+  %c2 = arith.constant 2.0 : f32
+  %0:3 = scf.while (%arg4 = %c1, %arg3 = %arg0, %arg5 = %c2) : (i32, index, f32) -> (i32, index, f32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : index
+    scf.condition(%1) %arg4, %arg3, %arg5 : i32, index, f32
+  } do {
+  ^bb0(%arg4: i32, %arg3: index, %arg5: f32):
+    %1 = "test.test1"(%arg4) : (i32) -> i32
+    %added = arith.addi %arg3, %arg2 : index
+    %2 = "test.test2"(%arg5) : (f32) -> f32
+    scf.yield %1, %added, %2 : i32, index, f32
+  }
+  return %0#0, %0#2 : i32, f32
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: index, %[[END:.*]]: index, %[[STEP:.*]]: index) -> (i32, f32)
+//   CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : i32
+//   CHECK-DAG:     %[[C2:.*]] = arith.constant 2.000000e+00 : f32
+//       CHECK:     %[[RES:.*]]:2 = scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]]
+//  CHECK-SAME:     iter_args(%[[ARG1:.*]] = %[[C1]], %[[ARG2:.*]] = %[[C2]]) -> (i32, f32) {
+//       CHECK:     %[[T1:.*]] = "test.test1"(%[[ARG1]]) : (i32) -> i32
+//       CHECK:     %[[T2:.*]] = "test.test2"(%[[ARG2]]) : (f32) -> f32
+//       CHECK:     scf.yield %[[T1]], %[[T2]] : i32, f32
+//       CHECK:     return %[[RES]]#0, %[[RES]]#1 : i32, f32
+
+// -----
+
+func.func @uplift_while(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 {
+  %0 = scf.while (%arg3 = %arg0) : (i64) -> (i64) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i64
+    scf.condition(%1) %arg3 : i64
+  } do {
+  ^bb0(%arg3: i64):
+    "test.test1"(%arg3) : (i64) -> ()
+    %added = arith.addi %arg3, %arg2 : i64
+    "test.test2"(%added) : (i64) -> ()
+    scf.yield %added : i64
+  }
+  return %0 : i64
+}
+
+// CHECK-LABEL: func @uplift_while
+//  CHECK-SAME:     (%[[BEGIN:.*]]: i64, %[[END:.*]]: i64, %[[STEP:.*]]: i64) -> i64
+//       CHECK:     %[[C1:.*]] = arith.constant 1 : i64
+//       CHECK:     scf.for %[[I:.*]] = %[[BEGIN]] to %[[END]] step %[[STEP]] : i64 {
+//       CHECK:     "test.test1"(%[[I]]) : (i64) -> ()
+//       CHECK:     %[[INC:.*]] = arith.addi %[[I]], %[[STEP]] : i64
+//       CHECK:     "test.test2"(%[[INC]]) : (i64) -> ()
+//       CHECK:     %[[R1:.*]] = arith.subi %[[STEP]], %[[C1]] : i64
+//       CHECK:     %[[R2:.*]] = arith.subi %[[END]], %[[BEGIN]] : i64
+//       CHECK:     %[[R3:.*]] = arith.addi %[[R2]], %[[R1]] : i64
+//       CHECK:     %[[R4:.*]] = arith.divsi %[[R3]], %[[STEP]] : i64
+//       CHECK:     %[[R5:.*]] = arith.muli %[[R4]], %[[STEP]] : i64
+//       CHECK:     %[[R6:.*]] = arith.addi %[[BEGIN]], %[[R5]] : i64
+//       CHECK:     %[[R7:.*]] = arith.maxsi %[[R6]], %[[BEGIN]] : i64
+//       CHECK:     return %[[R7]] : i64
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index d93bd55915182..792430cc84b65 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopUnrolling.cpp
   TestSCFUtils.cpp
   TestSCFWrapInZeroTripCheck.cpp
+  TestUpliftWhileToFor.cpp
   TestWhileOpBuilder.cpp
 
   EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
new file mode 100644
index 0000000000000..468bc0ca78489
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
@@ -0,0 +1,50 @@
+//===- TestUpliftWhileToFor.cpp - while to for loop uplifting test pass ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Pass to test transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct TestSCFUpliftWhileToFor
+    : public PassWrapper<TestSCFUpliftWhileToFor, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFUpliftWhileToFor)
+
+  StringRef getArgument() const final { return "test-scf-uplift-while-to-for"; }
+
+  StringRef getDescription() const final {
+    return "test scf while to for uplifting";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    scf::populateUpliftWhileToForPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestSCFUpliftWhileToFor() {
+  PassRegistration<TestSCFUpliftWhileToFor>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 6ce9f3041d6f4..237ebeb166dc9 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -130,6 +130,7 @@ void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
 void registerTestRecursiveTypesPass();
+void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
 void registerTestSCFWhileOpBuilderPass();
 void registerTestSCFWrapInZeroTripCheckPasses();
@@ -258,6 +259,7 @@ void registerTestPasses() {
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestRecursiveTypesPass();
+  mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();
   mlir::test::registerTestSCFWhileOpBuilderPass();
   mlir::test::registerTestSCFWrapInZeroTripCheckPasses();