Skip to content

Commit f25825f

Browse files
committed
First version of the fix for eraseIfSafe + small improvements to the canonical ForOp lowering (TODO: clean up + write tests)
1 parent 580fbec commit f25825f

File tree

3 files changed

+63
-31
lines changed

3 files changed

+63
-31
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -303,16 +303,26 @@ void SCFLoop::transferToSCFForOp() {
303303
"Not support lowering loop with break, continue or if yet");
304304
// Replace the IV usage to scf loop induction variable.
305305
if (isIVLoad(op, ivAddr)) {
306-
// Replace CIR IV load with arith.addi scf.IV, 0.
307-
// The replacement makes the SCF IV can be automatically propogated
308-
// by OpAdaptor for individual IV user lowering.
309-
// The redundant arith.addi can be removed by later MLIR passes.
310-
rewriter->setInsertionPoint(op);
311-
auto newIV = plusConstant(scfForOp.getInductionVar(), loc, 0);
312-
rewriter->replaceOp(op, newIV.getDefiningOp());
306+
// Replace CIR IV load with scf.IV
307+
// (i.e. remove the load op and replace the uses of the result of the CIR
308+
// IV load with the scf.IV)
309+
rewriter->replaceOp(op, scfForOp.getInductionVar());
313310
}
314311
return mlir::WalkResult::advance();
315312
});
313+
// If the IV was declared in the for op all uses have been replaced by the
314+
// scf.IV and we can remove the alloca + initial store
315+
316+
// The operations before the loop have been transferred to MLIR.
317+
// So we need to go through getRemappedValue to find the value.
318+
auto remapAddr = rewriter->getRemappedValue(ivAddr);
319+
// If the IV has any use besides the initial store op, keep it
320+
if (!remapAddr || !remapAddr.hasOneUse())
321+
return;
322+
323+
// otherwise remove the alloca + initial store op
324+
rewriter->eraseOp(remapAddr.getDefiningOp());
325+
rewriter->eraseOp(*remapAddr.user_begin());
316326
}
317327

318328
void SCFLoop::transformToSCFWhileOp() {

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Dialect/SCF/IR/SCF.h"
2929
#include "mlir/Dialect/Vector/IR/VectorOps.h"
3030
#include "mlir/IR/BuiltinDialect.h"
31+
#include "mlir/IR/BuiltinOps.h"
3132
#include "mlir/IR/BuiltinTypes.h"
3233
#include "mlir/IR/Operation.h"
3334
#include "mlir/IR/Region.h"
@@ -36,6 +37,7 @@
3637
#include "mlir/IR/ValueRange.h"
3738
#include "mlir/Pass/Pass.h"
3839
#include "mlir/Pass/PassManager.h"
40+
#include "mlir/Support/LLVM.h"
3941
#include "mlir/Support/LogicalResult.h"
4042
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
4143
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -51,6 +53,7 @@
5153
#include "llvm/ADT/STLExtras.h"
5254
#include "llvm/ADT/SmallVector.h"
5355
#include "llvm/ADT/TypeSwitch.h"
56+
#include "llvm/IR/Value.h"
5457
#include "llvm/Support/TimeProfiler.h"
5558

5659
using namespace cir;
@@ -209,16 +212,45 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
209212
static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
210213
SmallVector<mlir::Operation *> &eraseList,
211214
mlir::ConversionPatternRewriter &rewriter) {
215+
newAddr.getDefiningOp()->getParentOfType<mlir::ModuleOp>()->dump();
216+
oldAddr.dump();
217+
newAddr.dump();
218+
212219
unsigned oldUsedNum =
213220
std::distance(oldAddr.getUses().begin(), oldAddr.getUses().end());
214221
unsigned newUsedNum = 0;
215222
for (auto *user : newAddr.getUsers()) {
216-
if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
217-
++newUsedNum;
223+
user->dump();
224+
if (auto loadOpUser = mlir::dyn_cast_or_null<mlir::memref::LoadOp>(*user)) {
225+
if (auto strideVal = loadOpUser.getIndices()[0]) {
226+
strideVal.dump();
227+
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
228+
.getOffsets()[0]
229+
.dump();
230+
if (strideVal ==
231+
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
232+
.getOffsets()[0])
233+
++newUsedNum;
234+
}
235+
} else if (auto storeOpUser =
236+
mlir::dyn_cast_or_null<mlir::memref::StoreOp>(*user)) {
237+
if (auto strideVal = storeOpUser.getIndices()[0]) {
238+
strideVal.dump();
239+
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
240+
.getOffsets()[0]
241+
.dump();
242+
if (strideVal ==
243+
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
244+
.getOffsets()[0])
245+
++newUsedNum;
246+
}
247+
}
218248
}
219249
if (oldUsedNum == newUsedNum) {
220-
for (auto op : eraseList)
250+
for (auto op : eraseList) {
251+
op->dump();
221252
rewriter.eraseOp(op);
253+
}
222254
}
223255
}
224256

@@ -237,7 +269,7 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
237269
rewriter)) {
238270
newLoad = rewriter.create<mlir::memref::LoadOp>(
239271
op.getLoc(), base, indices, op.getIsNontemporal());
240-
// rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
272+
newLoad->dump();
241273
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
242274
} else
243275
newLoad = rewriter.create<mlir::memref::LoadOp>(
@@ -756,6 +788,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
756788
return mlir::success();
757789
}
758790

791+
// TODO: evaluate if a different mlir core dialect op is better suited for
792+
// this
759793
for (auto &block : scopeOp.getScopeRegion()) {
760794
rewriter.setInsertionPointToEnd(&block);
761795
auto *terminator = block.getTerminator();
@@ -1451,8 +1485,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
14511485

14521486
auto result = !mlir::failed(pm.run(theModule));
14531487
if (!result) {
1454-
//just for debugging purposes
1455-
//TODO: remove before creating a PR
1488+
// just for debugging purposes
1489+
// TODO: remove before creating a PR
14561490
theModule->dump();
14571491
report_fatal_error(
14581492
"The pass manager failed to lower CIR to MLIR standard dialects!");

clang/test/CIR/Lowering/ThroughMLIR/for.cpp

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@ void constantLoopBound() {
1414
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] : i32 {
1515
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
1616
// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
17-
// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
18-
// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0_i32]] : i32
19-
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
17+
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[I]] : i32 to index
2018
// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
2119
// CHECK: }
2220

@@ -33,9 +31,7 @@ void constantLoopBound_LE() {
3331
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C101]] step %[[C1_STEP]] : i32 {
3432
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
3533
// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
36-
// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
37-
// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0_i32]] : i32
38-
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
34+
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[I]] : i32 to index
3935
// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
4036
// CHECK: }
4137

@@ -52,9 +48,7 @@ void variableLoopBound(int l, int u) {
5248
// CHECK: scf.for %[[I:.*]] = %[[LOWER]] to %[[UPPER]] step %[[C1]] : i32 {
5349
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
5450
// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
55-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
56-
// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0]] : i32
57-
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
51+
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[I]] : i32 to index
5852
// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
5953
// CHECK: }
6054

@@ -73,9 +67,7 @@ void ariableLoopBound_LE(int l, int u) {
7367
// CHECK: scf.for %[[I:.*]] = %[[LOWER]] to %[[UPPER]] step %[[C4]] : i32 {
7468
// CHECK: %[[C3:.*]] = arith.constant 3 : i32
7569
// CHECK: %[[BASE:.*]] = memref.get_global @a : memref<101xi32>
76-
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
77-
// CHECK: %[[IV:.*]] = arith.addi %[[I]], %[[C0]] : i32
78-
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[IV]] : i32 to index
70+
// CHECK: %[[INDEX:.*]] = arith.index_cast %[[I]] : i32 to index
7971
// CHECK: memref.store %[[C3]], %[[BASE]][%[[INDEX]]] : memref<101xi32>
8072
// CHECK: }
8173

@@ -89,14 +81,10 @@ void incArray() {
8981
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
9082
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C100]] step %[[C1]] : i32 {
9183
// CHECK: %[[B:.*]] = memref.get_global @b : memref<101xi32>
92-
// CHECK: %[[C0_2:.*]] = arith.constant 0 : i32
93-
// CHECK: %[[IV2:.*]] = arith.addi %[[I]], %[[C0_2]] : i32
94-
// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[IV2]] : i32 to index
84+
// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[I]] : i32 to index
9585
// CHECK: %[[B_VALUE:.*]] = memref.load %[[B]][%[[INDEX_2]]] : memref<101xi32>
9686
// CHECK: %[[A:.*]] = memref.get_global @a : memref<101xi32>
97-
// CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
98-
// CHECK: %[[IV1:.*]] = arith.addi %[[I]], %[[C0_1]] : i32
99-
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[IV1]] : i32 to index
87+
// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[I]] : i32 to index
10088
// CHECK: %[[A_VALUE:.*]] = memref.load %[[A]][%[[INDEX_1]]] : memref<101xi32>
10189
// CHECK: %[[SUM:.*]] = arith.addi %[[A_VALUE]], %[[B_VALUE]] : i32
10290
// CHECK: memref.store %[[SUM]], %[[A]][%[[INDEX_1]]] : memref<101xi32>

0 commit comments

Comments
 (0)