Skip to content

Commit afd1552

Browse files
committed
[flang][OpenMP] do concurrent: support reduce on device
Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to OpenMP `reduction` clauses. The changes attach two `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wsloop` part.
1 parent 0c93791 commit afd1552

File tree

2 files changed

+121
-49
lines changed

2 files changed

+121
-49
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
141141

142142
for (mlir::Value local : loop.getLocalVars())
143143
liveIns.push_back(local);
144+
145+
for (mlir::Value reduce : loop.getReduceVars())
146+
liveIns.push_back(reduce);
144147
}
145148

146149
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -273,7 +276,7 @@ class DoConcurrentConversion
273276
targetOp =
274277
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
275278
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
276-
genTeamsOp(doLoop.getLoc(), rewriter);
279+
genTeamsOp(rewriter, loop, mapper);
277280
}
278281

279282
mlir::omp::ParallelOp parallelOp =
@@ -491,46 +494,7 @@ class DoConcurrentConversion
491494
if (!mapToDevice)
492495
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
493496

494-
if (!loop.getReduceVars().empty()) {
495-
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
496-
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
497-
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
498-
loop.getRegionReduceArgs())) {
499-
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
500-
sym.getLeafReference());
501-
502-
mlir::OpBuilder::InsertionGuard guard(rewriter);
503-
rewriter.setInsertionPointAfter(firReducer);
504-
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
505-
506-
auto ompReducer =
507-
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
508-
rewriter.getStringAttr(ompReducerName));
509-
510-
if (!ompReducer) {
511-
ompReducer = mlir::omp::DeclareReductionOp::create(
512-
rewriter, firReducer.getLoc(), ompReducerName,
513-
firReducer.getTypeAttr().getValue());
514-
515-
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
516-
ompReducer.getAllocRegion());
517-
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
518-
ompReducer.getInitializerRegion());
519-
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
520-
ompReducer.getReductionRegion());
521-
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
522-
ompReducer.getAtomicReductionRegion());
523-
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
524-
ompReducer.getCleanupRegion());
525-
moduleSymbolTable.insert(ompReducer);
526-
}
527-
528-
wsloopClauseOps.reductionVars.push_back(op);
529-
wsloopClauseOps.reductionByref.push_back(byRef);
530-
wsloopClauseOps.reductionSyms.push_back(
531-
mlir::SymbolRefAttr::get(ompReducer));
532-
}
533-
}
497+
genReductions(rewriter, mapper, loop, wsloopClauseOps);
534498

535499
auto wsloopOp =
536500
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -552,8 +516,6 @@ class DoConcurrentConversion
552516

553517
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
554518
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
555-
loop->getParentOfType<mlir::ModuleOp>().print(
556-
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
557519

558520
return {loopNestOp, wsloopOp};
559521
}
@@ -774,15 +736,26 @@ class DoConcurrentConversion
774736
liveInName, shape);
775737
}
776738

777-
mlir::omp::TeamsOp
778-
genTeamsOp(mlir::Location loc,
779-
mlir::ConversionPatternRewriter &rewriter) const {
780-
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
781-
loc, /*clauses=*/mlir::omp::TeamsOperands{});
739+
mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
740+
fir::DoConcurrentLoopOp loop,
741+
mlir::IRMapping &mapper) const {
742+
mlir::omp::TeamsOperands teamsOps;
743+
genReductions(rewriter, mapper, loop, teamsOps);
744+
745+
mlir::Location loc = loop.getLoc();
746+
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
747+
Fortran::common::openmp::EntryBlockArgs teamsArgs;
748+
teamsArgs.reduction.vars = teamsOps.reductionVars;
749+
Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
750+
teamsOp.getRegion());
782751

783-
rewriter.createBlock(&teamsOp.getRegion());
784752
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
785753

754+
for (auto [loopVar, teamsArg] : llvm::zip_equal(
755+
loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
756+
mapper.map(loopVar, teamsArg);
757+
}
758+
786759
return teamsOp;
787760
}
788761

@@ -849,6 +822,52 @@ class DoConcurrentConversion
849822
}
850823
}
851824

825+
void genReductions(mlir::ConversionPatternRewriter &rewriter,
826+
mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
827+
mlir::omp::ReductionClauseOps &reductionClauseOps) const {
828+
if (!loop.getReduceVars().empty()) {
829+
for (auto [var, byRef, sym, arg] : llvm::zip_equal(
830+
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
831+
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
832+
loop.getRegionReduceArgs())) {
833+
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
834+
sym.getLeafReference());
835+
836+
mlir::OpBuilder::InsertionGuard guard(rewriter);
837+
rewriter.setInsertionPointAfter(firReducer);
838+
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
839+
840+
auto ompReducer =
841+
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
842+
rewriter.getStringAttr(ompReducerName));
843+
844+
if (!ompReducer) {
845+
ompReducer = mlir::omp::DeclareReductionOp::create(
846+
rewriter, firReducer.getLoc(), ompReducerName,
847+
firReducer.getTypeAttr().getValue());
848+
849+
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
850+
ompReducer.getAllocRegion());
851+
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
852+
ompReducer.getInitializerRegion());
853+
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
854+
ompReducer.getReductionRegion());
855+
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
856+
ompReducer.getAtomicReductionRegion());
857+
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
858+
ompReducer.getCleanupRegion());
859+
moduleSymbolTable.insert(ompReducer);
860+
}
861+
862+
reductionClauseOps.reductionVars.push_back(
863+
mapToDevice ? mapper.lookup(var) : var);
864+
reductionClauseOps.reductionByref.push_back(byRef);
865+
reductionClauseOps.reductionSyms.push_back(
866+
mlir::SymbolRefAttr::get(ompReducer));
867+
}
868+
}
869+
}
870+
852871
bool mapToDevice;
853872
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
854873
mlir::SymbolTable &moduleSymbolTable;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
2+
3+
fir.declare_reduction @add_reduction_f32 : f32 init {
4+
^bb0(%arg0: f32):
5+
%cst = arith.constant 0.000000e+00 : f32
6+
fir.yield(%cst : f32)
7+
} combiner {
8+
^bb0(%arg0: f32, %arg1: f32):
9+
%0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
10+
fir.yield(%0 : f32)
11+
}
12+
13+
func.func @_QPfoo() {
14+
%0 = fir.dummy_scope : !fir.dscope
15+
%3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
16+
%4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
17+
%c1 = arith.constant 1 : index
18+
%c10 = arith.constant 1 : index
19+
fir.do_concurrent {
20+
%7 = fir.alloca i32 {bindc_name = "i"}
21+
%8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
22+
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
23+
%9 = fir.convert %arg0 : (index) -> i32
24+
fir.store %9 to %8#0 : !fir.ref<i32>
25+
%10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
26+
%11 = fir.load %10#0 : !fir.ref<f32>
27+
%cst = arith.constant 1.000000e+00 : f32
28+
%12 = arith.addf %11, %cst fastmath<contract> : f32
29+
hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
30+
}
31+
}
32+
return
33+
}
34+
35+
// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32
36+
37+
// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"}
38+
// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
39+
40+
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
41+
// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
42+
// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
43+
// CHECK: omp.parallel {
44+
// CHECK: omp.distribute {
45+
// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
46+
// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
47+
// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
48+
// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
49+
// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
50+
// CHECK: }
51+
// CHECK: }
52+
// CHECK: }
53+
// CHECK: }

0 commit comments

Comments
 (0)