Skip to content

Commit

Permalink
[FoldIf] look into regions
Browse files Browse the repository at this point in the history
  • Loading branch information
kumasento committed Oct 25, 2021
1 parent 5a813e6 commit 66525ba
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
41 changes: 39 additions & 2 deletions lib/mlir/Transforms/FoldIf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,46 @@ static LogicalResult process(mlir::AffineStoreOp storeOp, Value cond,

Value orig = b.create<mlir::AffineLoadOp>(loc, memref, affMap, mapOperands);
Value toStore = b.create<SelectOp>(
loc, cond, vmap.lookup(storeOp.getValueToStore()), orig);
loc, cond, vmap.lookupOrDefault(storeOp.getValueToStore()), orig);

b.create<mlir::AffineStoreOp>(loc, toStore, memref, affMap, mapOperands);

return success();
}

/// Work within the regions of the provided op. Find the AffineStoreOp, and
/// replace it with the select-based version.
/// TODO: can we have a rather unified implementation?
static LogicalResult replaceWithinRegion(Operation *parentOp, Value cond,
BlockAndValueMapping &vmap,
OpBuilder &b) {
for (Region &region : parentOp->getRegions()) {
for (Block &block : region.getBlocks()) {
/// TODO: is there a better way to cache the operations?
SmallVector<Operation *> ops;
for (Operation &op : block.getOperations())
ops.push_back(&op);

for (Operation *op : ops) {
if (auto storeOp = dyn_cast<mlir::AffineStoreOp>(op)) {
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(storeOp);

if (failed(process(storeOp, cond, vmap, b)))
return failure();

op->erase();
} else if (op->getNumRegions() >= 1) {
if (failed(replaceWithinRegion(op, cond, vmap, b)))
return failure();
}
}
}
}

return success();
}

/// TODO: filter invalid operations.
/// TODO: affine.load might load from invalid address.
static LogicalResult process(mlir::AffineIfOp ifOp, OpBuilder &b) {
Expand All @@ -125,7 +158,11 @@ static LogicalResult process(mlir::AffineIfOp ifOp, OpBuilder &b) {
if (failed(process(storeOp, cond, vmap, b)))
return failure();
} else {
b.clone(op, vmap);
Operation *cloned = b.clone(op, vmap);
if (cloned->getNumRegions() >= 1) {
if (failed(replaceWithinRegion(cloned, cond, vmap, b)))
return failure();
}
}
}

Expand Down
23 changes: 23 additions & 0 deletions test/mlir/Transforms/FoldIfPass/fold-if-with-blocks.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: phism-opt %s -fold-if | FileCheck %s

#set = affine_set<(d0) : (d0 - 5 == 0)>

func @foo(%A: memref<?xf32>, %i: index, %a: f32) {
affine.if #set(%i) {
affine.store %a, %A[%i] : memref<?xf32>
affine.for %j = 9 to 10 {
affine.store %a, %A[%j] : memref<?xf32>
}
}
return
}


// CHECK: func @foo
// CHECK: affine.load
// CHECK-NEXT: %[[v0:.*]] = select
// CHECK-NEXT: affine.store %[[v0]]
// CHECK: affine.for
// CHECK: affine.load
// CHECK-NEXT: %[[v0:.*]] = select
// CHECK-NEXT: affine.store %[[v0]]

0 comments on commit 66525ba

Please sign in to comment.