Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -649,13 +649,6 @@ def ForallOp : SCF_Op<"forall", [

/// Returns true if the mapping specified for this forall op is linear.
bool usesLinearMapping();

/// RegionBranchOpInterface

OperandRange getEntrySuccessorOperands(RegionSuccessor successor) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: This now falls back to the default implementation of getEntrySuccessorOperands: an empty operand range.

return getInits();
}

}];
}

Expand All @@ -667,6 +660,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
Pure,
Terminator,
DeclareOpInterfaceMethods<InParallelOpInterface>,
ReturnLike,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This is a shortcut for RegionBranchTerminatorOpInterface.

HasParent<"ForallOp">,
] # GraphRegionNoTerminator.traits> {
let summary = "terminates a `forall` block";
Expand Down
31 changes: 18 additions & 13 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2013,21 +2013,26 @@ void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
ForallOpReplaceConstantInductionVar>(context);
}

/// Given the region at `index`, or the parent operation if `index` is None,
/// return the successor regions. These are the regions that may be selected
/// during the flow of control. `operands` is a set of optional attributes that
/// correspond to a constant value for each operand, or null if that operand is
/// not a constant.
void ForallOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// In accordance with the semantics of forall, its body is executed in
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
else
regions.push_back(
RegionSuccessor(getOperation(), getOperation()->getResults()));
// There are two region branch points:
// 1. "parent": entering the forall op for the first time.
// 2. scf.in_parallel terminator
if (point.isParent()) {
// When first entering the forall op, the control flow typically branches
// into the forall body. (In parallel for multiple threads.)
regions.push_back(RegionSuccessor(&getRegion()));
// However, when there are 0 threads, the control flow may branch back to
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test the 0 thread case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The getSuccessorRegions implementation doesn't actually check the number of iterations of the loop. I.e., doesn't matter if 0 or more iterations. It just (conservatively) populates both region successors. (This could be improved across the entire SCF dialect. E.g., the implementation for scf.for also doesn't look at the number of iterations...)

I updated the test case to use an SSA value for the number of threads. In that case, we must populate both region successors because we do not statically know the branching behavior. Is that good enough? (That test case used to produce incorrect results before.)

// the parent immediately.
regions.emplace_back(getOperation(),
ResultRange{getResults().end(), getResults().end()});
} else {
// In accordance with the semantics of forall, its body is executed in
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
regions.emplace_back(getOperation(),
ResultRange{getResults().end(), getResults().end()});
}
}

//===----------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Analysis/DataFlow/test-dead-code-analysis.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,17 @@ func.func @test_dca_doesnt_crash() -> () {
func.func @test_dca_doesnt_crash_2() -> () attributes {symbol = @notexistant} {
return
}

func.func @test_forall_op_control_flow(%num_threads: index) {
// CHECK: test_forall_op_control_flow:
// CHECK: region #0
// CHECK: ^bb0 = live
// CHECK: region_preds: (all) predecessors:
// CHECK: scf.forall (%{{.*}}) in (%{{.*}}) {...} {tag = "test_forall_op_control_flow"}
// CHECK: op_preds: (all) predecessors:
// CHECK: scf.forall (%{{.*}}) in (%{{.*}}) {...} {tag = "test_forall_op_control_flow"}
// CHECK: scf.forall.in_parallel {...}
scf.forall (%arg0) in (%num_threads) {
} {tag = "test_forall_op_control_flow"}
return
}