@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
141141
142142 for (mlir::Value local : loop.getLocalVars ())
143143 liveIns.push_back (local);
144+
145+ for (mlir::Value reduce : loop.getReduceVars ())
146+ liveIns.push_back (reduce);
144147}
145148
146149// / Collects values that are local to a loop: "loop-local values". A loop-local
@@ -273,7 +276,7 @@ class DoConcurrentConversion
273276 targetOp =
274277 genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
275278 targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
276- genTeamsOp (doLoop. getLoc (), rewriter );
279+ genTeamsOp (rewriter, loop, mapper );
277280 }
278281
279282 mlir::omp::ParallelOp parallelOp =
@@ -491,46 +494,7 @@ class DoConcurrentConversion
491494 if (!mapToDevice)
492495 genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
493496
494- if (!loop.getReduceVars ().empty ()) {
495- for (auto [op, byRef, sym, arg] : llvm::zip_equal (
496- loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
497- loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
498- loop.getRegionReduceArgs ())) {
499- auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
500- sym.getLeafReference ());
501-
502- mlir::OpBuilder::InsertionGuard guard (rewriter);
503- rewriter.setInsertionPointAfter (firReducer);
504- std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
505-
506- auto ompReducer =
507- moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
508- rewriter.getStringAttr (ompReducerName));
509-
510- if (!ompReducer) {
511- ompReducer = mlir::omp::DeclareReductionOp::create (
512- rewriter, firReducer.getLoc (), ompReducerName,
513- firReducer.getTypeAttr ().getValue ());
514-
515- cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
516- ompReducer.getAllocRegion ());
517- cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
518- ompReducer.getInitializerRegion ());
519- cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
520- ompReducer.getReductionRegion ());
521- cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
522- ompReducer.getAtomicReductionRegion ());
523- cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
524- ompReducer.getCleanupRegion ());
525- moduleSymbolTable.insert (ompReducer);
526- }
527-
528- wsloopClauseOps.reductionVars .push_back (op);
529- wsloopClauseOps.reductionByref .push_back (byRef);
530- wsloopClauseOps.reductionSyms .push_back (
531- mlir::SymbolRefAttr::get (ompReducer));
532- }
533- }
497+ genReductions (rewriter, mapper, loop, wsloopClauseOps);
534498
535499 auto wsloopOp =
536500 mlir::omp::WsloopOp::create (rewriter, loop.getLoc (), wsloopClauseOps);
@@ -552,8 +516,6 @@ class DoConcurrentConversion
552516
553517 rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
554518 mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
555- loop->getParentOfType <mlir::ModuleOp>().print (
556- llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
557519
558520 return {loopNestOp, wsloopOp};
559521 }
@@ -774,15 +736,26 @@ class DoConcurrentConversion
774736 liveInName, shape);
775737 }
776738
777- mlir::omp::TeamsOp
778- genTeamsOp (mlir::Location loc,
779- mlir::ConversionPatternRewriter &rewriter) const {
780- auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(
781- loc, /* clauses=*/ mlir::omp::TeamsOperands{});
739+ mlir::omp::TeamsOp genTeamsOp (mlir::ConversionPatternRewriter &rewriter,
740+ fir::DoConcurrentLoopOp loop,
741+ mlir::IRMapping &mapper) const {
742+ mlir::omp::TeamsOperands teamsOps;
743+ genReductions (rewriter, mapper, loop, teamsOps);
744+
745+ mlir::Location loc = loop.getLoc ();
746+ auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(loc, teamsOps);
747+ Fortran::common::openmp::EntryBlockArgs teamsArgs;
748+ teamsArgs.reduction .vars = teamsOps.reductionVars ;
749+ Fortran::common::openmp::genEntryBlock (rewriter, teamsArgs,
750+ teamsOp.getRegion ());
782751
783- rewriter.createBlock (&teamsOp.getRegion ());
784752 rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
785753
754+ for (auto [loopVar, teamsArg] : llvm::zip_equal (
755+ loop.getReduceVars (), teamsOp.getRegion ().getArguments ())) {
756+ mapper.map (loopVar, teamsArg);
757+ }
758+
786759 return teamsOp;
787760 }
788761
@@ -849,6 +822,52 @@ class DoConcurrentConversion
849822 }
850823 }
851824
825+ void genReductions (mlir::ConversionPatternRewriter &rewriter,
826+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
827+ mlir::omp::ReductionClauseOps &reductionClauseOps) const {
828+ if (!loop.getReduceVars ().empty ()) {
829+ for (auto [var, byRef, sym, arg] : llvm::zip_equal (
830+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
831+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
832+ loop.getRegionReduceArgs ())) {
833+ auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
834+ sym.getLeafReference ());
835+
836+ mlir::OpBuilder::InsertionGuard guard (rewriter);
837+ rewriter.setInsertionPointAfter (firReducer);
838+ std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
839+
840+ auto ompReducer =
841+ moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
842+ rewriter.getStringAttr (ompReducerName));
843+
844+ if (!ompReducer) {
845+ ompReducer = mlir::omp::DeclareReductionOp::create (
846+ rewriter, firReducer.getLoc (), ompReducerName,
847+ firReducer.getTypeAttr ().getValue ());
848+
849+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
850+ ompReducer.getAllocRegion ());
851+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
852+ ompReducer.getInitializerRegion ());
853+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
854+ ompReducer.getReductionRegion ());
855+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
856+ ompReducer.getAtomicReductionRegion ());
857+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
858+ ompReducer.getCleanupRegion ());
859+ moduleSymbolTable.insert (ompReducer);
860+ }
861+
862+ reductionClauseOps.reductionVars .push_back (
863+ mapToDevice ? mapper.lookup (var) : var);
864+ reductionClauseOps.reductionByref .push_back (byRef);
865+ reductionClauseOps.reductionSyms .push_back (
866+ mlir::SymbolRefAttr::get (ompReducer));
867+ }
868+ }
869+ }
870+
852871 bool mapToDevice;
853872 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
854873 mlir::SymbolTable &moduleSymbolTable;
0 commit comments