2828#include " mlir/Dialect/SCF/IR/SCF.h"
2929#include " mlir/Dialect/Vector/IR/VectorOps.h"
3030#include " mlir/IR/BuiltinDialect.h"
31+ #include " mlir/IR/BuiltinOps.h"
3132#include " mlir/IR/BuiltinTypes.h"
3233#include " mlir/IR/Operation.h"
3334#include " mlir/IR/Region.h"
3637#include " mlir/IR/ValueRange.h"
3738#include " mlir/Pass/Pass.h"
3839#include " mlir/Pass/PassManager.h"
40+ #include " mlir/Support/LLVM.h"
3941#include " mlir/Support/LogicalResult.h"
4042#include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
4143#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
5153#include " llvm/ADT/STLExtras.h"
5254#include " llvm/ADT/SmallVector.h"
5355#include " llvm/ADT/TypeSwitch.h"
56+ #include " llvm/IR/Value.h"
5457#include " llvm/Support/TimeProfiler.h"
5558
5659using namespace cir ;
@@ -209,16 +212,45 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
209212static void eraseIfSafe (mlir::Value oldAddr, mlir::Value newAddr,
210213 SmallVector<mlir::Operation *> &eraseList,
211214 mlir::ConversionPatternRewriter &rewriter) {
215+ newAddr.getDefiningOp ()->getParentOfType <mlir::ModuleOp>()->dump ();
216+ oldAddr.dump ();
217+ newAddr.dump ();
218+
212219 unsigned oldUsedNum =
213220 std::distance (oldAddr.getUses ().begin (), oldAddr.getUses ().end ());
214221 unsigned newUsedNum = 0 ;
215222 for (auto *user : newAddr.getUsers ()) {
216- if (isa<mlir::memref::LoadOp>(*user) || isa<mlir::memref::StoreOp>(*user))
217- ++newUsedNum;
223+ user->dump ();
224+ if (auto loadOpUser = mlir::dyn_cast_or_null<mlir::memref::LoadOp>(*user)) {
225+ if (auto strideVal = loadOpUser.getIndices ()[0 ]) {
226+ strideVal.dump ();
227+ mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back ())
228+ .getOffsets ()[0 ]
229+ .dump ();
230+ if (strideVal ==
231+ mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back ())
232+ .getOffsets ()[0 ])
233+ ++newUsedNum;
234+ }
235+ } else if (auto storeOpUser =
236+ mlir::dyn_cast_or_null<mlir::memref::StoreOp>(*user)) {
237+ if (auto strideVal = storeOpUser.getIndices ()[0 ]) {
238+ strideVal.dump ();
239+ mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back ())
240+ .getOffsets ()[0 ]
241+ .dump ();
242+ if (strideVal ==
243+ mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back ())
244+ .getOffsets ()[0 ])
245+ ++newUsedNum;
246+ }
247+ }
218248 }
219249 if (oldUsedNum == newUsedNum) {
220- for (auto op : eraseList)
250+ for (auto op : eraseList) {
251+ op->dump ();
221252 rewriter.eraseOp (op);
253+ }
222254 }
223255}
224256
@@ -237,7 +269,7 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
237269 rewriter)) {
238270 newLoad = rewriter.create <mlir::memref::LoadOp>(
239271 op.getLoc (), base, indices, op.getIsNontemporal ());
240- // rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices );
272+ newLoad-> dump ( );
241273 eraseIfSafe (op.getAddr (), base, eraseList, rewriter);
242274 } else
243275 newLoad = rewriter.create <mlir::memref::LoadOp>(
@@ -756,6 +788,8 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
756788 return mlir::success ();
757789 }
758790
791+ // TODO: evaluate if a different mlir core dialect op is better suited for
792+ // this
759793 for (auto &block : scopeOp.getScopeRegion ()) {
760794 rewriter.setInsertionPointToEnd (&block);
761795 auto *terminator = block.getTerminator ();
@@ -1451,8 +1485,8 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
14511485
14521486 auto result = !mlir::failed (pm.run (theModule));
14531487 if (!result) {
1454- // just for debugging purposes
1455- // TODO: remove before creating a PR
1488+ // just for debugging purposes
1489+ // TODO: remove before creating a PR
14561490 theModule->dump ();
14571491 report_fatal_error (
14581492 " The pass manager failed to lower CIR to MLIR standard dialects!" );
0 commit comments