From 520c184356f736ea54884c3edc2007c648639c6f Mon Sep 17 00:00:00 2001 From: Ruizhe Zhao Date: Mon, 25 Oct 2021 17:42:53 +0100 Subject: [PATCH] [FoldIf] look into regions --- lib/mlir/Transforms/FoldIf.cc | 41 ++++++++++++++++++- polymer | 2 +- .../FoldIfPass/fold-if-with-blocks.mlir | 23 +++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) create mode 100644 test/mlir/Transforms/FoldIfPass/fold-if-with-blocks.mlir diff --git a/lib/mlir/Transforms/FoldIf.cc b/lib/mlir/Transforms/FoldIf.cc index d23d1c2eb57..60084d6150d 100644 --- a/lib/mlir/Transforms/FoldIf.cc +++ b/lib/mlir/Transforms/FoldIf.cc @@ -97,13 +97,46 @@ static LogicalResult process(mlir::AffineStoreOp storeOp, Value cond, Value orig = b.create(loc, memref, affMap, mapOperands); Value toStore = b.create( - loc, cond, vmap.lookup(storeOp.getValueToStore()), orig); + loc, cond, vmap.lookupOrDefault(storeOp.getValueToStore()), orig); b.create(loc, toStore, memref, affMap, mapOperands); return success(); } +/// Work within the regions of the provided op. Find the AffineStoreOp, and +/// replace it with the select-based version. +/// TODO: can we have a rather unified implementation? +static LogicalResult replaceWithinRegion(Operation *parentOp, Value cond, + BlockAndValueMapping &vmap, + OpBuilder &b) { + for (Region ®ion : parentOp->getRegions()) { + for (Block &block : region.getBlocks()) { + /// TODO: is there a better way to cache the operations? + SmallVector ops; + for (Operation &op : block.getOperations()) + ops.push_back(&op); + + for (Operation *op : ops) { + if (auto storeOp = dyn_cast(op)) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(storeOp); + + if (failed(process(storeOp, cond, vmap, b))) + return failure(); + + op->erase(); + } else if (op->getNumRegions() >= 1) { + if (failed(replaceWithinRegion(op, cond, vmap, b))) + return failure(); + } + } + } + } + + return success(); +} + /// TODO: filter invalid operations. /// TODO: affine.load might load from invalid address. static LogicalResult process(mlir::AffineIfOp ifOp, OpBuilder &b) { @@ -125,7 +158,11 @@ static LogicalResult process(mlir::AffineIfOp ifOp, OpBuilder &b) { if (failed(process(storeOp, cond, vmap, b))) return failure(); } else { - b.clone(op, vmap); + Operation *cloned = b.clone(op, vmap); + if (cloned->getNumRegions() >= 1) { + if (failed(replaceWithinRegion(cloned, cond, vmap, b))) + return failure(); + } } } diff --git a/polymer b/polymer index a8f8ccdf66e..b6dd56b79b2 160000 --- a/polymer +++ b/polymer @@ -1 +1 @@ -Subproject commit a8f8ccdf66ea54b719559a49c9ef2f91fa1874ff +Subproject commit b6dd56b79b2cdb3350033275c141663ae789dc82 diff --git a/test/mlir/Transforms/FoldIfPass/fold-if-with-blocks.mlir b/test/mlir/Transforms/FoldIfPass/fold-if-with-blocks.mlir new file mode 100644 index 00000000000..1003c47c878 --- /dev/null +++ b/test/mlir/Transforms/FoldIfPass/fold-if-with-blocks.mlir @@ -0,0 +1,23 @@ +// RUN: phism-opt %s -fold-if | FileCheck %s + +#set = affine_set<(d0) : (d0 - 5 == 0)> + +func @foo(%A: memref, %i: index, %a: f32) { + affine.if #set(%i) { + affine.store %a, %A[%i] : memref + affine.for %j = 9 to 10 { + affine.store %a, %A[%j] : memref + } + } + return +} + + +// CHECK: func @foo +// CHECK: affine.load +// CHECK-NEXT: %[[v0:.*]] = select +// CHECK-NEXT: affine.store %[[v0]] +// CHECK: affine.for +// CHECK: affine.load +// CHECK-NEXT: %[[v0:.*]] = select +// CHECK-NEXT: affine.store %[[v0]]