Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 68 additions & 49 deletions flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,

for (mlir::Value local : loop.getLocalVars())
liveIns.push_back(local);

for (mlir::Value reduce : loop.getReduceVars())
liveIns.push_back(reduce);
}

/// Collects values that are local to a loop: "loop-local values". A loop-local
Expand Down Expand Up @@ -319,7 +322,7 @@ class DoConcurrentConversion
targetOp =
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
genTeamsOp(doLoop.getLoc(), rewriter);
genTeamsOp(rewriter, loop, mapper);
}

mlir::omp::ParallelOp parallelOp =
Expand Down Expand Up @@ -492,46 +495,7 @@ class DoConcurrentConversion
if (!mapToDevice)
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);

if (!loop.getReduceVars().empty()) {
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
loop.getRegionReduceArgs())) {
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
sym.getLeafReference());

mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(firReducer);
std::string ompReducerName = sym.getLeafReference().str() + ".omp";

auto ompReducer =
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
rewriter.getStringAttr(ompReducerName));

if (!ompReducer) {
ompReducer = mlir::omp::DeclareReductionOp::create(
rewriter, firReducer.getLoc(), ompReducerName,
firReducer.getTypeAttr().getValue());

cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
ompReducer.getAllocRegion());
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
ompReducer.getInitializerRegion());
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
ompReducer.getReductionRegion());
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
ompReducer.getAtomicReductionRegion());
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
ompReducer.getCleanupRegion());
moduleSymbolTable.insert(ompReducer);
}

wsloopClauseOps.reductionVars.push_back(op);
wsloopClauseOps.reductionByref.push_back(byRef);
wsloopClauseOps.reductionSyms.push_back(
mlir::SymbolRefAttr::get(ompReducer));
}
}
genReductions(rewriter, mapper, loop, wsloopClauseOps);

auto wsloopOp =
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
Expand All @@ -553,8 +517,6 @@ class DoConcurrentConversion

rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
loop->getParentOfType<mlir::ModuleOp>().print(
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());

return {loopNestOp, wsloopOp};
}
Expand Down Expand Up @@ -778,15 +740,26 @@ class DoConcurrentConversion
liveInName, shape);
}

/// Creates an `omp.teams` op at the rewriter's current insertion point and
/// attaches the reductions of \p loop to it.
///
/// \param rewriter rewriter used to create the ops. On return, its insertion
///        point is set right before the terminator created for the teams
///        region.
/// \param loop the `fir.do_concurrent.loop` being converted; provides the
///        location and the reduction operands.
/// \param mapper passed on to `genReductions`; afterwards updated so that
///        every reduce variable of \p loop maps to the corresponding block
///        argument of the teams region.
/// \return the newly created `omp.teams` op.
mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
                              fir::DoConcurrentLoopOp loop,
                              mlir::IRMapping &mapper) const {
  mlir::omp::TeamsOperands teamsOps;
  genReductions(rewriter, mapper, loop, teamsOps);

  mlir::Location loc = loop.getLoc();
  auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);

  // The region's block is produced by genEntryBlock from the reduction
  // variables (expected to yield one block argument per variable), so no
  // separate createBlock call is made: a second block would leave the entry
  // block without a terminator.
  Fortran::common::openmp::EntryBlockArgs teamsArgs;
  teamsArgs.reduction.vars = teamsOps.reductionVars;
  Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
                                         teamsOp.getRegion());

  // Terminate the region and keep inserting nested ops ahead of the
  // terminator.
  rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));

  // Code nested in the teams region must refer to the region's reduction
  // arguments rather than to the loop's original reduce operands.
  for (auto [loopVar, teamsArg] : llvm::zip_equal(
           loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
    mapper.map(loopVar, teamsArg);
  }

  return teamsOp;
}

Expand Down Expand Up @@ -861,6 +834,52 @@ class DoConcurrentConversion
}
}

/// Appends one reduction entry (variable, by-ref flag, reducer symbol) to
/// \p reductionClauseOps for every `reduce` operand of \p loop.
///
/// For each reduction symbol, the `fir.declare_reduction` op it names is
/// looked up in the module symbol table. Unless an `omp.declare_reduction`
/// named "<fir-name>.omp" already exists, one is created right after the FIR
/// reducer by cloning its alloc, initializer, reduction, atomic-reduction and
/// cleanup regions, and is registered in the symbol table.
///
/// \param rewriter rewriter used to create the OpenMP reducer; its insertion
///        point is restored after each reduction has been processed.
/// \param mapper only consulted when mapping to the device, to translate the
///        loop's reduce variable into its mapped value.
/// \param loop the `fir.do_concurrent.loop` whose reductions are converted.
/// \param reductionClauseOps clause operands that receive the new entries.
void genReductions(mlir::ConversionPatternRewriter &rewriter,
                   mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
                   mlir::omp::ReductionClauseOps &reductionClauseOps) const {
  // Nothing to do, and the reduce attributes are not touched, when the loop
  // carries no reductions.
  if (loop.getReduceVars().empty())
    return;

  // The region arguments take part in the zip only so that zip_equal also
  // checks their count; they are not otherwise used here.
  for (auto [reduceVar, isByRef, reducerSym, regionArg] : llvm::zip_equal(
           loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
           loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
           loop.getRegionReduceArgs())) {
    auto firReducerName = reducerSym.getLeafReference();
    // NOTE(review): the lookup result is used unchecked below; this assumes
    // the referenced fir.declare_reduction always exists -- confirm.
    auto firReducer =
        moduleSymbolTable.lookup<fir::DeclareReductionOp>(firReducerName);

    // Limit the insertion-point change to this iteration.
    mlir::OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPointAfter(firReducer);

    std::string ompReducerName = firReducerName.str() + ".omp";
    auto ompReducerAttr = rewriter.getStringAttr(ompReducerName);
    auto ompReducer =
        moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
            ompReducerAttr);

    // Create the OpenMP twin of the FIR reducer only once per symbol.
    if (!ompReducer) {
      ompReducer = mlir::omp::DeclareReductionOp::create(
          rewriter, firReducer.getLoc(), ompReducerName,
          firReducer.getTypeAttr().getValue());

      cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
                          ompReducer.getAllocRegion());
      cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
                          ompReducer.getInitializerRegion());
      cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
                          ompReducer.getReductionRegion());
      cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
                          ompReducer.getAtomicReductionRegion());
      cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
                          ompReducer.getCleanupRegion());
      moduleSymbolTable.insert(ompReducer);
    }

    // On the device path the clause refers to the mapped value instead of
    // the loop's own operand.
    mlir::Value clauseVar = reduceVar;
    if (mapToDevice)
      clauseVar = mapper.lookup(reduceVar);

    reductionClauseOps.reductionVars.push_back(clauseVar);
    reductionClauseOps.reductionByref.push_back(isByRef);
    reductionClauseOps.reductionSyms.push_back(
        mlir::SymbolRefAttr::get(ompReducer));
  }
}

bool mapToDevice;
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
mlir::SymbolTable &moduleSymbolTable;
Expand Down
53 changes: 53 additions & 0 deletions flang/test/Transforms/DoConcurrent/reduce_device.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s

// f32 reduction declaration used by the loop below: the init region yields
// the constant 0.0 and the combiner region yields the sum of its two
// arguments.
fir.declare_reduction @add_reduction_f32 : f32 init {
^bb0(%arg0: f32):
%cst = arith.constant 0.000000e+00 : f32
fir.yield(%cst : f32)
} combiner {
^bb0(%arg0: f32, %arg1: f32):
%0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
fir.yield(%0 : f32)
}

// Test input: a `do concurrent` loop over 1..10 that accumulates into the
// scalar `s` through the @add_reduction_f32 reduction.
func.func @_QPfoo() {
  %0 = fir.dummy_scope : !fir.dscope
  %3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
  %c1 = arith.constant 1 : index
  // Upper bound; the value matches the SSA name.
  %c10 = arith.constant 10 : index
  fir.do_concurrent {
    %7 = fir.alloca i32 {bindc_name = "i"}
    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
      %9 = fir.convert %arg0 : (index) -> i32
      fir.store %9 to %8#0 : !fir.ref<i32>
      // The loop body works on the reduction block argument, not on %4.
      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
      %11 = fir.load %10#0 : !fir.ref<f32>
      %cst = arith.constant 1.000000e+00 : f32
      %12 = arith.addf %11, %cst fastmath<contract> : f32
      hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
    }
  }
  return
}

// An OpenMP twin of the FIR reducer is emitted with an ".omp" suffix.
// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32

// Host side: the reduction variable is declared and mapped. The declare
// operand is matched with a wildcard so the test does not depend on SSA
// value numbering.
// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFfooEs"}
// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1

// Device side: the mapped variable feeds the teams reduction, and the teams
// reduction argument feeds the wsloop reduction.
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
// CHECK: omp.parallel {
// CHECK: omp.distribute {
// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: }
Loading