Roll forward the elimination of common Multinomial ops with the fix:
Conservatively use the result types of the original tf.If op, as the result types of the then/else branches might not match each other.

PiperOrigin-RevId: 544935570
cky9301 authored and tensorflower-gardener committed Jul 1, 2023
1 parent eded6e3 commit 2a5fbc5
Showing 5 changed files with 159 additions and 29 deletions.
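
The heart of the fix is visible in the DeduplicateIfOps hunk below: when the pass rebuilds a tf.If, it now takes result types from the original op instead of from the then branch's function type. The following sketch (hypothetical IR and symbol names, mirroring the new @unmatched_then_else_type test added in this commit) shows why that matters: the two branches may carry differently refined types, and only the original op's declared types are guaranteed consistent with all existing users.

func.func private @then_sketch(%x: tensor<*xi32>) -> tensor<*xi32> {
  func.return %x : tensor<*xi32>
}

func.func private @else_sketch(%x: tensor<i32>) -> tensor<i32> {
  func.return %x : tensor<i32>
}

// tf.If only requires branch and op types to be cast-compatible, so the op's
// declared result type stays tensor<*xi32> even though @else_sketch returns
// the more refined tensor<i32>. Rebuilding the op from a branch's result
// types could therefore change the op's observable type; reusing the original
// op's result types cannot.
func.func @sketch(%cond: tensor<i1>, %x: tensor<*xi32>) -> tensor<*xi32> {
  %r = "tf.If"(%cond, %x) {then_branch = @then_sketch, else_branch = @else_sketch, is_stateless = true} : (tensor<i1>, tensor<*xi32>) -> tensor<*xi32>
  func.return %r : tensor<*xi32>
}
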
34 changes: 29 additions & 5 deletions tensorflow/compiler/mlir/tfrt/tests/deduplicate_if_results.mlir
@@ -1,4 +1,4 @@
// RUN: tf-tfrt-opt -tfrt-deduplicate-if-result %s | FileCheck %s -dump-input=fail
// RUN: tf-tfrt-opt -split-input-file -tfrt-deduplicate-if-result %s | FileCheck %s -dump-input=fail

func.func private @then(%x: tensor<i32>, %y: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
func.return %x, %x : tensor<i32>, tensor<i32>
@@ -14,10 +14,34 @@ func.func private @else(%x: tensor<i32>, %y: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
// CHECK-LABEL: else/tfrt_dedup_results
// CHECK: return {{%.*}} : tensor<i32>

// CHECK-LABEL: @main
func.func @main(%cond: tensor<i1>, %x: tensor<i32>, %y: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
// CHECK: [[r:%.*]] = "tf.If"
// CHECK: return [[r]], [[r]] : tensor<i32>, tensor<i32>
// CHECK-LABEL: @basic
func.func @basic(%cond: tensor<i1>, %x: tensor<i32>, %y: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
// CHECK-NEXT: [[r:%.*]] = "tf.If"
// CHECK-NEXT: return [[r]], [[r]] : tensor<i32>, tensor<i32>
%0, %1 = "tf.If"(%cond, %x, %y) {else_branch = @else, then_branch = @then, is_stateless = true} : (tensor<i1>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
return %0, %1 : tensor<i32>, tensor<i32>
}

// -----

func.func private @unmatched_then(%x: tensor<*xi32>, %y: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
func.return %x, %x : tensor<*xi32>, tensor<*xi32>
}

func.func private @unmatched_else(%x: tensor<i32>, %y: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
func.return %y, %y : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: unmatched_then/tfrt_dedup_results
// CHECK: return {{%.*}} : tensor<*xi32>

// CHECK-LABEL: unmatched_else/tfrt_dedup_results
// CHECK: return {{%.*}} : tensor<i32>

// CHECK-LABEL: @unmatched_then_else_type
func.func @unmatched_then_else_type(%cond: tensor<i1>, %x: tensor<*xi32>, %y: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) {
// CHECK-NEXT: [[r:%.*]] = "tf.If"
// CHECK-NEXT: return [[r]], [[r]] : tensor<*xi32>, tensor<*xi32>
%0, %1 = "tf.If"(%cond, %x, %y) {else_branch = @unmatched_else, then_branch = @unmatched_then, is_stateless = true} : (tensor<i1>, tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>)
return %0, %1 : tensor<*xi32>, tensor<*xi32>
}
29 changes: 29 additions & 0 deletions tensorflow/compiler/mlir/tfrt/tests/optimize.mlir
@@ -18,3 +18,32 @@ func.func @not_fold_device_index() -> tensor<i32> {
%0 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
func.return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: @eliminate_multinomial
func.func @eliminate_multinomial(%0: tensor<*xf32>, %1: tensor<*xi32>) -> (tensor<*xi64>, tensor<*xi64>) {
// CHECK-NEXT: tf.Multinomial
// CHECK-NEXT: return
%2 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
%3 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
func.return %2, %3 : tensor<*xi64>, tensor<*xi64>
}

// -----

// CHECK-LABEL: @not_eliminate_multinomial
func.func @not_eliminate_multinomial(%0: tensor<*xf32>, %1: tensor<*xi32>) -> (tensor<*xi64>, tensor<*xi64>) {
// CHECK-NEXT: tf.Multinomial
// CHECK-SAME: seed = 0
// CHECK-NEXT: tf.Multinomial
// CHECK-SAME: seed = 1
// CHECK-NEXT: tf.Multinomial
// CHECK-SAME: seed = 0
// CHECK-NEXT: return
%2 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
%3 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 1 : i64, seed2 = 1 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
%4 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
%5 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
func.return %2, %3 : tensor<*xi64>, tensor<*xi64>
}
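
For reference, the @eliminate_multinomial case above reduces under this pass to roughly the following (a hand-written illustration consistent with the CHECK-NEXT lines, not actual FileCheck output): the two identical ops collapse into one whose output feeds both results. As @not_eliminate_multinomial shows, the elimination is deliberately conservative: ops are merged only if they are equivalent (same operands, attributes, and seeds) and directly back-to-back.

func.func @eliminate_multinomial(%0: tensor<*xf32>, %1: tensor<*xi32>) -> (tensor<*xi64>, tensor<*xi64>) {
  %2 = "tf.Multinomial"(%0, %1) {device = "/job:localhost/replica:0/task:0/device:CPU:0", seed = 0 : i64, seed2 = 0 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xi64>
  func.return %2, %2 : tensor<*xi64>, tensor<*xi64>
}
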
@@ -158,12 +158,17 @@ void DeduplicateIfOps(mlir::ModuleOp module) {
auto new_else_func = get_or_create(else_branch, else_mapping);

mlir::OpBuilder::InsertionGuard guard(builder);

builder.setInsertionPoint(op);

llvm::SmallVector<mlir::Type> new_result_types;
for (int i : then_mapping.new_to_old) {
new_result_types.push_back(op->getResult(i).getType());
}

auto new_if_op = builder.create<mlir::TF::IfOp>(
op.getLoc(), new_then_func.getFunctionType().getResults(),
op.getCond(), op.getInput(), new_then_func.getSymName(),
new_else_func.getSymName(), op.getIsStateless());
op.getLoc(), new_result_types, op.getCond(), op.getInput(),
new_then_func.getSymName(), new_else_func.getSymName(),
op.getIsStateless());

DCHECK_EQ(then_mapping.old_to_new.size(), op.getNumResults());
for (int i = 0; i < then_mapping.old_to_new.size(); ++i) {
44 changes: 26 additions & 18 deletions tensorflow/compiler/mlir/tfrt/transforms/merge_tf_if_ops.cc
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iterator>

#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
@@ -79,35 +82,40 @@ class MergeTfIfOpsPass
}

void runOnOperation() override {
constexpr int kMaxIter = 10;
constexpr int kMaxIter = 20;
auto module = getOperation();

bool changed = true;

for (int i = 0; i < kMaxIter && changed; ++i) {
changed = false;
for (auto func_op :
llvm::make_early_inc_range(module.getOps<mlir::func::FuncOp>())) {
changed |= ProcessFunction(func_op, i);
}

if (changed) {
// Run optimization passes to expose more merge opportunities among the
// then-branch functions and the else-branch functions that are now
// respectively merged, for the next iteration.
mlir::OpPassManager pm(module.getOperationName());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(CreateDeduplicateIfResultPass());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createCSEPass());
if (mlir::failed(runPipeline(pm, module))) {
module.emitWarning(
absl::StrCat("could not run inliner pass within the "
"tfrt-merge-tf-if-ops pass iteration ",
i));
break;
}
mlir::OperationFingerPrint fingerprint(module);

// Run optimization passes to expose more merge opportunities among the
// then-branch functions and the else-branch functions that are now
// respectively merged, for the next iteration.
mlir::OpPassManager pm(module.getOperationName());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createCSEPass());
pm.addPass(CreateDeduplicateIfResultPass());
pm.addPass(mlir::createInlinerPass());
pm.addPass(mlir::createCSEPass());
pm.addNestedPass<mlir::func::FuncOp>(
tfrt_compiler::CreateOptimizeTfForTfrtPass());
if (mlir::failed(runPipeline(pm, module))) {
module.emitWarning(
absl::StrCat("could not run inliner pass within the "
"tfrt-merge-tf-if-ops pass iteration ",
i));
break;
}

changed |= fingerprint != mlir::OperationFingerPrint(module);
}
}

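Two things change in this driver. First, the cleanup pipeline now runs on every iteration rather than only when ProcessFunction reported a change, and it gains a nested OptimizeTfForTfrtPass so the new Multinomial elimination can expose further merge opportunities. Second, progress is now detected with mlir::OperationFingerPrint: the module is fingerprinted before the pipeline runs, and the loop continues while either ProcessFunction or the pipeline itself changed the IR; kMaxIter is raised from 10 to 20 accordingly. For orientation, the merge this loop drives to a fixed point is, schematically (hypothetical IR and symbol names; the merge logic itself is untouched by this commit):

// Before: two stateless tf.If ops guarded by the same predicate.
%0 = "tf.If"(%cond, %x) {then_branch = @t0, else_branch = @e0, is_stateless = true} : (tensor<i1>, tensor<i32>) -> tensor<i32>
%1 = "tf.If"(%cond, %x) {then_branch = @t1, else_branch = @e1, is_stateless = true} : (tensor<i1>, tensor<i32>) -> tensor<i32>

// After: one tf.If whose branch functions concatenate the results of the
// original branch functions.
%0:2 = "tf.If"(%cond, %x) {then_branch = @t0_t1, else_branch = @e0_e1, is_stateless = true} : (tensor<i1>, tensor<i32>) -> (tensor<i32>, tensor<i32>)
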
68 changes: 66 additions & 2 deletions tensorflow/compiler/mlir/tfrt/transforms/optimize.cc
@@ -28,7 +28,7 @@ class FoldDeviceIndex : public mlir::OpRewritePattern<mlir::TF::DeviceIndexOp> {

mlir::LogicalResult matchAndRewrite(
mlir::TF::DeviceIndexOp op,
mlir::PatternRewriter& rewriter) const override {
mlir::PatternRewriter &rewriter) const override {
auto device = op->getAttrOfType<mlir::StringAttr>("device");
if (!device) return mlir::failure();

@@ -55,6 +55,67 @@ class FoldDeviceIndex : public mlir::OpRewritePattern<mlir::TF::DeviceIndexOp> {
}
};

// Custom hash and equality functions for identifying equivalent ops, used as
// the DenseMap key info below.
struct SimpleOperationInfo : public llvm::DenseMapInfo<mlir::Operation *> {
  static unsigned getHashValue(const mlir::Operation *opC) {
    return mlir::OperationEquivalence::computeHash(
        const_cast<mlir::Operation *>(opC),
        /*hashOperands=*/mlir::OperationEquivalence::directHashValue,
        /*hashResults=*/mlir::OperationEquivalence::ignoreHashValue,
        mlir::OperationEquivalence::IgnoreLocations);
  }
  static bool isEqual(const mlir::Operation *lhsC,
                      const mlir::Operation *rhsC) {
    auto *lhs = const_cast<mlir::Operation *>(lhsC);
    auto *rhs = const_cast<mlir::Operation *>(rhsC);
    if (lhs == rhs) return true;
    if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
        rhs == getTombstoneKey() || rhs == getEmptyKey())
      return false;
    return mlir::OperationEquivalence::isEquivalentTo(
        const_cast<mlir::Operation *>(lhsC),
        const_cast<mlir::Operation *>(rhsC),
        mlir::OperationEquivalence::IgnoreLocations);
  }
};

void EliminateCommonMultinomialOps(mlir::Block &block) {
  llvm::SmallDenseMap<mlir::Operation *,
                      llvm::SmallVector<mlir::TF::MultinomialOp>, 2,
                      SimpleOperationInfo>
      multinomial_to_eliminate;

  auto eliminate = [&]() {
    auto &list = multinomial_to_eliminate.begin()->second;
    auto first = list.front();
    for (auto op : llvm::drop_begin(list)) {
      op.getOutput().replaceAllUsesWith(first.getOutput());
      op->erase();
    }
    multinomial_to_eliminate.clear();
  };

  for (auto &op : block) {
    auto multinomial_op = llvm::dyn_cast<mlir::TF::MultinomialOp>(&op);
    // Conservatively, we only eliminate back-to-back tf.Multinomial ops.
    if (multinomial_op) {
      if (multinomial_to_eliminate.find(multinomial_op) ==
              multinomial_to_eliminate.end() &&
          !multinomial_to_eliminate.empty()) {
        // The current op is a tf.Multinomial but differs from the previously
        // found tf.Multinomial ops, so eliminate the previously found group
        // first.
        eliminate();
      }
      multinomial_to_eliminate[multinomial_op].push_back(multinomial_op);
    } else if (!multinomial_to_eliminate.empty()) {
      // The current op is not a tf.Multinomial, so eliminate the previously
      // found tf.Multinomial ops.
      eliminate();
    }
  }
}

// Optimization pass for TFRT-specific rewrite patterns.
class OptimizeTfForTfrt
: public mlir::PassWrapper<OptimizeTfForTfrt,
@@ -68,7 +129,7 @@ class OptimizeTfForTfrt
return "optimize TF MLIR for TFRT workflow.";
}

mlir::LogicalResult initialize(mlir::MLIRContext* context) override {
mlir::LogicalResult initialize(mlir::MLIRContext *context) override {
mlir::RewritePatternSet pattern_list(context);
pattern_list.add<FoldDeviceIndex>(context);
patterns_ = std::move(pattern_list);
@@ -77,6 +138,9 @@

  void runOnOperation() override {
    auto func = getOperation();

    EliminateCommonMultinomialOps(func.getBody().front());

    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(func, patterns_)))
      signalPassFailure();
  }