-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[flang][OpenMP] do concurrent: support reduce on device
#156610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
78fc5ed to
78e1013
Compare
f748bd2 to
6987182
Compare
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Kareem Ergawy (ergawy) ChangesExtends
Full diff: https://github.com/llvm/llvm-project/pull/156610.diff 2 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 66b778fecc208..135382abb0227 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -140,6 +140,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
for (mlir::Value local : loop.getLocalVars())
liveIns.push_back(local);
+
+ for (mlir::Value reduce : loop.getReduceVars())
+ liveIns.push_back(reduce);
}
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -272,7 +275,7 @@ class DoConcurrentConversion
targetOp =
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
- genTeamsOp(doLoop.getLoc(), rewriter);
+ genTeamsOp(rewriter, loop, mapper);
}
mlir::omp::ParallelOp parallelOp =
@@ -488,46 +491,7 @@ class DoConcurrentConversion
if (!mapToDevice)
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
- if (!loop.getReduceVars().empty()) {
- for (auto [op, byRef, sym, arg] : llvm::zip_equal(
- loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
- loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
- loop.getRegionReduceArgs())) {
- auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
- sym.getLeafReference());
-
- mlir::OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointAfter(firReducer);
- std::string ompReducerName = sym.getLeafReference().str() + ".omp";
-
- auto ompReducer =
- moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
- rewriter.getStringAttr(ompReducerName));
-
- if (!ompReducer) {
- ompReducer = mlir::omp::DeclareReductionOp::create(
- rewriter, firReducer.getLoc(), ompReducerName,
- firReducer.getTypeAttr().getValue());
-
- cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
- ompReducer.getAllocRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
- ompReducer.getInitializerRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
- ompReducer.getReductionRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
- ompReducer.getAtomicReductionRegion());
- cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
- ompReducer.getCleanupRegion());
- moduleSymbolTable.insert(ompReducer);
- }
-
- wsloopClauseOps.reductionVars.push_back(op);
- wsloopClauseOps.reductionByref.push_back(byRef);
- wsloopClauseOps.reductionSyms.push_back(
- mlir::SymbolRefAttr::get(ompReducer));
- }
- }
+ genReductions(rewriter, mapper, loop, wsloopClauseOps);
auto wsloopOp =
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -549,8 +513,6 @@ class DoConcurrentConversion
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
- loop->getParentOfType<mlir::ModuleOp>().print(
- llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
return {loopNestOp, wsloopOp};
}
@@ -771,15 +733,26 @@ class DoConcurrentConversion
liveInName, shape);
}
- mlir::omp::TeamsOp
- genTeamsOp(mlir::Location loc,
- mlir::ConversionPatternRewriter &rewriter) const {
- auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
- loc, /*clauses=*/mlir::omp::TeamsOperands{});
+ mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
+ fir::DoConcurrentLoopOp loop,
+ mlir::IRMapping &mapper) const {
+ mlir::omp::TeamsOperands teamsOps;
+ genReductions(rewriter, mapper, loop, teamsOps);
+
+ mlir::Location loc = loop.getLoc();
+ auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
+ Fortran::common::openmp::EntryBlockArgs teamsArgs;
+ teamsArgs.reduction.vars = teamsOps.reductionVars;
+ Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
+ teamsOp.getRegion());
- rewriter.createBlock(&teamsOp.getRegion());
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
+ for (auto [loopVar, teamsArg] : llvm::zip_equal(
+ loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
+ mapper.map(loopVar, teamsArg);
+ }
+
return teamsOp;
}
@@ -846,6 +819,52 @@ class DoConcurrentConversion
}
}
+ void genReductions(mlir::ConversionPatternRewriter &rewriter,
+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+ mlir::omp::ReductionClauseOps &reductionClauseOps) const {
+ if (!loop.getReduceVars().empty()) {
+ for (auto [var, byRef, sym, arg] : llvm::zip_equal(
+ loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
+ loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+ loop.getRegionReduceArgs())) {
+ auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+ sym.getLeafReference());
+
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPointAfter(firReducer);
+ std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+
+ auto ompReducer =
+ moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+ rewriter.getStringAttr(ompReducerName));
+
+ if (!ompReducer) {
+ ompReducer = mlir::omp::DeclareReductionOp::create(
+ rewriter, firReducer.getLoc(), ompReducerName,
+ firReducer.getTypeAttr().getValue());
+
+ cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
+ ompReducer.getAllocRegion());
+ cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
+ ompReducer.getInitializerRegion());
+ cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
+ ompReducer.getReductionRegion());
+ cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
+ ompReducer.getAtomicReductionRegion());
+ cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
+ ompReducer.getCleanupRegion());
+ moduleSymbolTable.insert(ompReducer);
+ }
+
+ reductionClauseOps.reductionVars.push_back(
+ mapToDevice ? mapper.lookup(var) : var);
+ reductionClauseOps.reductionByref.push_back(byRef);
+ reductionClauseOps.reductionSyms.push_back(
+ mlir::SymbolRefAttr::get(ompReducer));
+ }
+ }
+ }
+
bool mapToDevice;
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/reduce_device.mlir b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
new file mode 100644
index 0000000000000..3e46692a15dca
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
@@ -0,0 +1,53 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.declare_reduction @add_reduction_f32 : f32 init {
+^bb0(%arg0: f32):
+ %cst = arith.constant 0.000000e+00 : f32
+ fir.yield(%cst : f32)
+} combiner {
+^bb0(%arg0: f32, %arg1: f32):
+ %0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
+ fir.yield(%0 : f32)
+}
+
+func.func @_QPfoo() {
+ %0 = fir.dummy_scope : !fir.dscope
+ %3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
+ %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 1 : index
+ fir.do_concurrent {
+ %7 = fir.alloca i32 {bindc_name = "i"}
+ %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
+ %9 = fir.convert %arg0 : (index) -> i32
+ fir.store %9 to %8#0 : !fir.ref<i32>
+ %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+ %11 = fir.load %10#0 : !fir.ref<f32>
+ %cst = arith.constant 1.000000e+00 : f32
+ %12 = arith.addf %11, %cst fastmath<contract> : f32
+ hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
+ }
+ }
+ return
+}
+
+// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32
+
+// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"}
+// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
+// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
+// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
+// CHECK: omp.parallel {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
+// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
+// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
+// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
+// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
|
78e1013 to
13f4544
Compare
6987182 to
9a7ae05
Compare
…ide values (#155754) Following up on #154483, this PR introduces further refactoring to extract some shared utils between OpenMP lowering and `do concurrent` conversion pass. In particular, this PR extracts 2 utils that handle mapping or cloning values used inside target regions but defined outside. Later `do concurrent` PR(s) will also use these utils. PR stack: - #155754◀️ - #155987 - #155992 - #155993 - #156589 - #156610 - #156837
13f4544 to
c29c8a2
Compare
9a7ae05 to
6d02115
Compare
… clone outside values (#155754) Following up on #154483, this PR introduces further refactoring to extract some shared utils between OpenMP lowering and `do concurrent` conversion pass. In particular, this PR extracts 2 utils that handle mapping or cloning values used inside target regions but defined outside. Later `do concurrent` PR(s) will also use these utils. PR stack: - llvm/llvm-project#155754◀️ - llvm/llvm-project#155987 - llvm/llvm-project#155992 - llvm/llvm-project#155993 - llvm/llvm-project#156589 - llvm/llvm-project#156610 - llvm/llvm-project#156837
c29c8a2 to
e681a9f
Compare
6d02115 to
7fb93a3
Compare
e681a9f to
0c93791
Compare
7fb93a3 to
afd1552
Compare
0c93791 to
cbb2c67
Compare
afd1552 to
31bf2c1
Compare
cbb2c67 to
723193b
Compare
31bf2c1 to
3b73016
Compare
…#155987) Upstreams further parts of `do concurrent` to OpenMP conversion pass from AMD's fork. This PR extends the pass by adding support for mapping to the device. PR stack: - llvm/llvm-project#155754 - llvm/llvm-project#155987◀️ - llvm/llvm-project#155992 - llvm/llvm-project#155993 - llvm/llvm-project#157638 - llvm/llvm-project#156610 - llvm/llvm-project#156837
5099595 to
d67632e
Compare
5b9f176 to
bdd9ab2
Compare
…nMP mapping (#155993) Adds end-to-end tests for `do concurrent` offloading to the device. PR stack: - llvm/llvm-project#155754 - llvm/llvm-project#155987 - llvm/llvm-project#155992 - llvm/llvm-project#155993◀️ - llvm/llvm-project#157638 - llvm/llvm-project#156610 - llvm/llvm-project#156837
d67632e to
c59436f
Compare
bdd9ab2 to
76c4b9a
Compare
Extends support for mapping `do concurrent` on the device by adding support for `local` specifiers. The changes in this PR map the local variable to the `omp.target` op and uses the mapped value as the `private` clause operand in the nested `omp.parallel` op. - #155754 - #155987 - #155992 - #155993 - #157638◀️ - #156610 - #156837
Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wloop` part.
76c4b9a to
f698b21
Compare
… (#157638) Extends support for mapping `do concurrent` on the device by adding support for `local` specifiers. The changes in this PR map the local variable to the `omp.target` op and uses the mapped value as the `private` clause operand in the nested `omp.parallel` op. - llvm/llvm-project#155754 - llvm/llvm-project#155987 - llvm/llvm-project#155992 - llvm/llvm-project#155993 - llvm/llvm-project#157638◀️ - llvm/llvm-project#156610 - llvm/llvm-project#156837
…e (#156610) Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wloop` part. - llvm/llvm-project#155754 - llvm/llvm-project#155987 - llvm/llvm-project#155992 - llvm/llvm-project#155993 - llvm/llvm-project#157638 - llvm/llvm-project#156610◀️ - llvm/llvm-project#156837
…ions on the GPU (#156837) Fixes a bug related to insertion points when inlining multi-block combiner reduction regions. The IP at the end of the inlined region was not used resulting in emitting BBs with multiple terminators. PR stack: - llvm/llvm-project#155754 - llvm/llvm-project#155987 - llvm/llvm-project#155992 - llvm/llvm-project#155993 - llvm/llvm-project#157638 - llvm/llvm-project#156610 - llvm/llvm-project#156837◀️
Extends
do concurrentto OpenMP device mapping by adding support for mappingreducespecifiers to ompreductionclauses. The changes attach 2reductionclauses to the mapped OpenMP construct: one on theteamspart of the construct and one on thewlooppart.do concurrentmapping to device #155987do concurrentto device mapping lit tests #155992do concurrent: supportlocalon device #157638do concurrent: supportreduceon device #156610