Skip to content

Commit f53e340

Browse files
committed
[CIR][ThroughMLIR] Handle ContinueOp directly under a WhileOp
1 parent 209df5a commit f53e340

File tree

2 files changed

+71
-3
lines changed

2 files changed

+71
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,12 @@ class SCFWhileLoop {
6363
SCFWhileLoop(cir::WhileOp op, cir::WhileOp::Adaptor adaptor,
6464
mlir::ConversionPatternRewriter *rewriter)
6565
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
66-
void transferToSCFWhileOp();
66+
mlir::scf::WhileOp transferToSCFWhileOp();
6767

6868
private:
6969
cir::WhileOp whileOp;
7070
cir::WhileOp::Adaptor adaptor;
71+
mlir::scf::WhileOp scfWhileOp;
7172
mlir::ConversionPatternRewriter *rewriter;
7273
};
7374

@@ -337,7 +338,7 @@ void SCFLoop::transformToSCFWhileOp() {
337338
scfWhileOp.getAfterBody()->end());
338339
}
339340

340-
void SCFWhileLoop::transferToSCFWhileOp() {
341+
mlir::scf::WhileOp SCFWhileLoop::transferToSCFWhileOp() {
341342
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
342343
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
343344
rewriter->createBlock(&scfWhileOp.getBefore());
@@ -348,6 +349,7 @@ void SCFWhileLoop::transferToSCFWhileOp() {
348349
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
349350
scfWhileOp.getAfterBody(),
350351
scfWhileOp.getAfterBody()->end());
352+
return scfWhileOp;
351353
}
352354

353355
void SCFDoLoop::transferToSCFWhileOp() {
@@ -393,14 +395,53 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
393395
};
394396

395397
class CIRWhileOpLowering : public mlir::OpConversionPattern<cir::WhileOp> {
398+
void rewriteContinue(mlir::scf::WhileOp whileOp,
399+
mlir::ConversionPatternRewriter &rewriter) const {
400+
// Collect all ContinueOp inside this while.
401+
llvm::SmallVector<cir::ContinueOp> continues;
402+
whileOp->walk([&](mlir::Operation *op) {
403+
if (auto continueOp = dyn_cast<ContinueOp>(op))
404+
continues.push_back(continueOp);
405+
});
406+
407+
if (continues.empty())
408+
return;
409+
410+
for (auto continueOp : continues) {
411+
// When the break is under an IfOp, a direct replacement of `scf.yield`
412+
// won't work: the yield would jump out of that IfOp instead. We might
413+
// need to change the whileOp itself to achieve the same effect.
414+
for (mlir::Operation *parent = continueOp->getParentOp();
415+
parent != whileOp; parent = parent->getParentOp()) {
416+
if (isa<mlir::scf::IfOp>(parent) || isa<cir::IfOp>(parent))
417+
llvm_unreachable("NYI");
418+
}
419+
420+
// Operations after this break has to be removed.
421+
for (mlir::Operation *runner = continueOp->getNextNode(); runner;) {
422+
mlir::Operation *next = runner->getNextNode();
423+
runner->erase();
424+
runner = next;
425+
}
426+
427+
// Blocks after this break also has to be removed.
428+
for (mlir::Block *block = continueOp->getBlock()->getNextNode(); block;) {
429+
mlir::Block *next = block->getNextNode();
430+
block->erase();
431+
block = next;
432+
}
433+
}
434+
}
435+
396436
public:
397437
using OpConversionPattern<cir::WhileOp>::OpConversionPattern;
398438

399439
mlir::LogicalResult
400440
matchAndRewrite(cir::WhileOp op, OpAdaptor adaptor,
401441
mlir::ConversionPatternRewriter &rewriter) const override {
402442
SCFWhileLoop loop(op, adaptor, &rewriter);
403-
loop.transferToSCFWhileOp();
443+
auto whileOp = loop.transferToSCFWhileOp();
444+
rewriteContinue(whileOp, rewriter);
404445
rewriter.eraseOp(op);
405446
return mlir::success();
406447
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void for_with_break() {
5+
int i = 0;
6+
while (i < 100) {
7+
i++;
8+
continue;
9+
i++;
10+
}
11+
// Only the first `i++` will be emitted.
12+
13+
// CHECK: scf.while : () -> () {
14+
// CHECK: %[[TMP0:.+]] = memref.load %alloca[]
15+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
16+
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %[[TMP0]], %[[HUNDRED]]
17+
// CHECK: scf.condition(%[[TMP1]])
18+
// CHECK: } do {
19+
// CHECK: memref.alloca_scope {
20+
// CHECK: %[[TMP2:.+]] = memref.load %alloca[]
21+
// CHECK: %[[ONE:.+]] = arith.constant 1
22+
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
23+
// CHECK: memref.store %[[TMP3]], %alloca[]
24+
// CHECK: }
25+
// CHECK: scf.yield
26+
// CHECK: }
27+
}

0 commit comments

Comments
 (0)