Skip to content

Commit 4e97d8d

Browse files
committed
[CIR][ThroughMLIR] Handle ContinueOp directly under a WhileOp
1 parent 209df5a commit 4e97d8d

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ class SCFWhileLoop {
6363
SCFWhileLoop(cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
6464
mlir::ConversionPatternRewriter *rewriter)
6565
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
66-
void transferToSCFWhileOp();
66+
mlir::scf::WhileOp transferToSCFWhileOp();
6767

6868
private:
6969
cir::WhileOp whileOp;
7070
cir::WhileOp::Adaptor adaptor;
71+
mlir::scf::WhileOp scfWhileOp;
7172
mlir::ConversionPatternRewriter *rewriter;
7273
};
7374

@@ -337,7 +338,7 @@ void SCFLoop::transformToSCFWhileOp() {
337338
scfWhileOp.getAfterBody()->end());
338339
}
339340

340-
void SCFWhileLoop::transferToSCFWhileOp() {
341+
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() {
341342
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
342343
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
343344
rewriter->createBlock(&scfWhileOp.getBefore());
@@ -348,6 +349,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
348349
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
349350
scfWhileOp.getAfterBody(),
350351
scfWhileOp.getAfterBody()->end());
352+
return scfWhileOp;
351353
}
352354

353355
void SCFDoLoop::transferToSCFWhileOp() {
@@ -393,14 +395,51 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
393395
};
394396

395397
class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
398+
void rewriteContinue(mlir::scf::WhileOp whileOp, mlir::ConversionPatternRewriter &rewriter) const {
399+
// Collect all ContinueOp inside this while.
400+
llvm::SmallVector<cir::ContinueOp> continues;
401+
whileOp->walk([&](mlir::Operation *op) {
402+
if (auto continueOp = dyn_cast<ContinueOp>(op))
403+
continues.push_back(continueOp);
404+
});
405+
406+
if (continues.empty())
407+
return;
408+
409+
for (auto continueOp : continues) {
410+
// When the break is under an IfOp, a direct replacement of `scf.yield` won't work:
411+
// the yield would jump out of that IfOp instead.
412+
// We might need to change the whileOp itself to achieve the same effect.
413+
for (mlir::Operation *parent = continueOp->getParentOp(); parent != whileOp; parent = parent->getParentOp()) {
414+
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
415+
llvm_unreachable("NYI");
416+
}
417+
418+
// Operations after this break has to be removed.
419+
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
420+
mlir::Operation *next = runner->getNextNode();
421+
runner->erase();
422+
runner = next;
423+
}
424+
425+
// Blocks after this break also has to be removed.
426+
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
427+
mlir::Block *next = block->getNextNode();
428+
block->erase();
429+
block = next;
430+
}
431+
}
432+
}
433+
396434
public:
397435
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
398436

399437
mlir::LogicalResult
400438
matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor,
401439
mlir::ConversionPatternRewriter &rewriter) const override {
402440
SCFWhileLoop loop(op, adaptor, &rewriter);
403-
loop.transferToSCFWhileOp();
441+
auto whileOp = loop.transferToSCFWhileOp();
442+
rewriteContinue(whileOp, rewriter);
404443
rewriter.eraseOp(op);
405444
return mlir::success();
406445
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void for_with_break() {
5+
int i = 0;
6+
while (i < 100) {
7+
i++;
8+
continue;
9+
i++;
10+
}
11+
// Only the first `i++` will be emitted.
12+
13+
// CHECK: scf.while : () -> () {
14+
// CHECK: %[[TMP0:.+]] = memref.load %alloca[]
15+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
16+
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP0]], %[[HUNDRED]]
17+
// CHECK: scf.condition(%[[TMP1]])
18+
// CHECK: } do {
19+
// CHECK: memref.alloca_scope {
20+
// CHECK: %[[TMP2:.+]] = memref.load %alloca[]
21+
// CHECK: %[[ONE:.+]] = arith.constant 1
22+
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
23+
// CHECK: memref.store %[[TMP3]], %alloca[]
24+
// CHECK: }
25+
// CHECK: scf.yield
26+
// CHECK: }
27+
}

0 commit comments

Comments
 (0)