diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 26582df013f571a..7938647563a22c2 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1582,18 +1582,20 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter, loc, llvm::omp::Directive::OMPD_taskwait); } -static void -genTeamsClauses(lower::AbstractConverter &converter, - semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List &clauses, - mlir::Location loc, bool evalOutsideTarget, - mlir::omp::TeamsOperands &clauseOps, - mlir::omp::NumTeamsClauseOps &numTeamsClauseOps, - mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) { +static void genTeamsClauses( + lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, const List &clauses, + mlir::Location loc, bool evalOutsideTarget, + mlir::omp::TeamsOperands &clauseOps, + mlir::omp::NumTeamsClauseOps &numTeamsClauseOps, + mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps, + llvm::SmallVectorImpl &reductionTypes, + llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processDefault(); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); + cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms); // TODO Support delayed privatization. // Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside @@ -1609,8 +1611,6 @@ genTeamsClauses(lower::AbstractConverter &converter, cp.processNumTeams(stmtCtx, numTeamsClauseOps); cp.processThreadLimit(stmtCtx, threadLimitClauseOps); } - - // cp.processTODO(loc, llvm::omp::Directive::OMPD_teams); } static void genWsloopClauses( @@ -2384,14 +2384,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, mlir::omp::TeamsOperands clauseOps; mlir::omp::NumTeamsClauseOps numTeamsClauseOps; mlir::omp::ThreadLimitClauseOps threadLimitClauseOps; + llvm::SmallVector reductionSyms; + llvm::SmallVector reductionTypes; genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc, evalOutsideTarget, clauseOps, numTeamsClauseOps, - threadLimitClauseOps); + threadLimitClauseOps, reductionTypes, reductionSyms); + + auto reductionCallback = [&](mlir::Operation *op) { + genReductionVars(op, converter, loc, reductionSyms, reductionTypes); + return llvm::SmallVector(reductionSyms); + }; auto teamsOp = genOpWithBody( OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, llvm::omp::Directive::OMPD_teams) - .setClauses(&item->clauses), + .setClauses(&item->clauses) + .setGenRegionEntryCb(reductionCallback), queue, item, clauseOps); if (numTeamsClauseOps.numTeamsUpper) { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 82efe62525ce122..56afdf3e885e7a0 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -472,16 +472,21 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op, //===----------------------------------------------------------------------===// static ParseResult parseClauseWithRegionArgs( - OpAsmParser &parser, Region ®ion, + OpAsmParser &parser, SmallVectorImpl &operands, SmallVectorImpl &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols, - SmallVectorImpl ®ionPrivateArgs) { + SmallVectorImpl ®ionPrivateArgs, + bool parseParens = true) { SmallVector reductionVec; SmallVector isByRefVec; unsigned regionArgOffset = regionPrivateArgs.size(); + OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren + : OpAsmParser::Delimiter::None; + + // TODO: Optionally parse parens. if (failed( - parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() { + parser.parseCommaSeparatedList(delimiter, [&]() { ParseResult optionalByref = parser.parseOptionalKeyword("byref"); if (parser.parseAttribute(reductionVec.emplace_back()) || parser.parseOperand(operands.emplace_back()) || @@ -536,17 +541,17 @@ static ParseResult parseParallelRegion( llvm::SmallVector regionPrivateArgs; if (succeeded(parser.parseOptionalKeyword("reduction"))) { - if (failed(parseClauseWithRegionArgs(parser, region, reductionVars, - reductionTypes, reductionByref, - reductionSyms, regionPrivateArgs))) + if (failed(parseClauseWithRegionArgs(parser, reductionVars, reductionTypes, + reductionByref, reductionSyms, + regionPrivateArgs))) return failure(); } if (succeeded(parser.parseOptionalKeyword("private"))) { auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {}); - if (failed(parseClauseWithRegionArgs(parser, region, privateVars, - privateTypes, privateByref, - privateSyms, regionPrivateArgs))) + if (failed(parseClauseWithRegionArgs(parser, privateVars, privateTypes, + privateByref, privateSyms, + regionPrivateArgs))) return failure(); if (llvm::any_of(privateByref.asArrayRef(), [](bool byref) { return byref; })) { @@ -597,45 +602,25 @@ static ParseResult parseReductionVarList( SmallVectorImpl &reductionVars, SmallVectorImpl &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms) { - SmallVector reductionVec; - SmallVector isByRefVec; - if (failed(parser.parseCommaSeparatedList([&]() { - ParseResult optionalByref = parser.parseOptionalKeyword("byref"); - if (parser.parseAttribute(reductionVec.emplace_back()) || - parser.parseArrow() || - parser.parseOperand(reductionVars.emplace_back()) || - parser.parseColonType(reductionTypes.emplace_back())) - return failure(); - isByRefVec.push_back(optionalByref.succeeded()); - return success(); - }))) - return failure(); - reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); - SmallVector reductions(reductionVec.begin(), reductionVec.end()); - reductionSyms = ArrayAttr::get(parser.getContext(), reductions); - return success(); + llvm::SmallVector regionPrivateArgs; + return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes, + reductionByref, reductionSyms, + regionPrivateArgs, /*parseParens=*/false); } /// Print Reduction clause -static void -printReductionVarList(OpAsmPrinter &p, Operation *op, - OperandRange reductionVars, TypeRange reductionTypes, - std::optional reductionByref, - std::optional reductionSyms) { - auto getByRef = [&](unsigned i) -> const char * { - if (!reductionByref || !*reductionByref) - return ""; - assert(reductionByref->empty() || i < reductionByref->size()); - if (!reductionByref->empty() && (*reductionByref)[i]) - return "byref "; - return ""; - }; - - for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i] - << " : " << reductionVars[i].getType(); +static void printReductionVarList(OpAsmPrinter &p, Operation *op, + OperandRange reductionVars, + TypeRange reductionTypes, + DenseBoolArrayAttr reductionByref, + ArrayAttr reductionSyms) { + if (reductionSyms) { + // TODO: Do not print entry block args. + auto *argsBegin = op->getRegion(0).front().getArguments().begin(); + MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size()); + printClauseWithRegionArgs(p, op, argsSubrange, llvm::StringRef(), + reductionVars, reductionTypes, reductionByref, + reductionSyms); } } @@ -1828,7 +1813,7 @@ parseWsloop(OpAsmParser &parser, Region ®ion, // Parse an optional reduction clause llvm::SmallVector privates; if (succeeded(parser.parseOptionalKeyword("reduction"))) { - if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands, + if (failed(parseClauseWithRegionArgs(parser, reductionOperands, reductionTypes, reductionByRef, reductionSymbols, privates))) return failure(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 3aec1a56559756e..506cdd249de24e6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -864,7 +864,8 @@ static LogicalResult createReductionsAndCleanup( SmallVector &owningReductionGens, SmallVector &owningAtomicReductionGens, SmallVector &reductionInfos, - bool isTeamsReduction = false, bool hasDistribute = false) { + bool IsNowait = false, bool isTeamsReduction = false, + bool hasDistribute = false) { // Process the reductions if required. if (op.getNumReductionVars() == 0) return success(); @@ -884,7 +885,7 @@ static LogicalResult createReductionsAndCleanup( builder.SetInsertPoint(tempTerminator); llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint = ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos, - isByRef, op.getNowait(), isTeamsReduction, + isByRef, IsNowait, isTeamsReduction, hasDistribute); if (!contInsertPoint.getBlock()) return op->emitOpError() << "failed to convert reductions"; @@ -1083,7 +1084,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder, return createReductionsAndCleanup( sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls, privateReductionVariables, isByRef, owningReductionGens, - owningAtomicReductionGens, reductionInfos); + owningAtomicReductionGens, reductionInfos, sectionsOp.getNowait()); } /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder. @@ -1127,10 +1128,36 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; LogicalResult bodyGenStatus = success(); - if (!op.getAllocatorVars().empty() || op.getReductionSyms() || - !op.getPrivateVars().empty() || op.getPrivateSyms()) + if (!op.getAllocatorVars().empty() || !op.getPrivateVars().empty() || + op.getPrivateSyms()) return op.emitError("unhandled clauses for translation to LLVM IR"); + llvm::ArrayRef isByRef = getIsByRef(op.getReductionByref()); + assert(isByRef.size() == op.getNumReductionVars()); + + SmallVector reductionDecls; + collectReductionDecls(op, reductionDecls); + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = + findAllocaInsertPoint(builder, moduleTranslation); + + SmallVector privateReductionVariables( + op.getNumReductionVars()); + DenseMap reductionVariableMap; + + MutableArrayRef reductionArgs = op.getRegion().getArguments(); + + if (failed(allocAndInitializeReductionVars( + op, reductionArgs, builder, moduleTranslation, allocaIP, + reductionDecls, privateReductionVariables, reductionVariableMap, + isByRef))) + return failure(); + + // Store the mapping between reduction variables and their private copies on + // ModuleTranslation stack. It can be then recovered when translating + // omp.reduce operations in a separate call. + LLVM::ModuleTranslation::SaveStack mappingGuard( + moduleTranslation, reductionVariableMap); + auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) { LLVM::ModuleTranslation::SaveStack frame( moduleTranslation, allocaIP); @@ -1160,7 +1187,17 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder, builder.restoreIP(ompBuilder->createTeams( ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr)); - return bodyGenStatus; + if (failed(bodyGenStatus)) + return bodyGenStatus; + + // Process the reductions if required. + SmallVector owningReductionGens; + SmallVector owningAtomicReductionGens; + SmallVector reductionInfos; + return createReductionsAndCleanup(op, builder, moduleTranslation, allocaIP, + reductionDecls, privateReductionVariables, + isByRef, owningReductionGens, + owningAtomicReductionGens, reductionInfos); } static void @@ -1430,8 +1467,8 @@ static LogicalResult convertOmpWsloop( return createReductionsAndCleanup( wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls, privateReductionVariables, isByRef, owningReductionGens, - owningAtomicReductionGens, reductionInfos, /*isTeamsReduction=*/false, - distributeCodeGen); + owningAtomicReductionGens, reductionInfos, wsloopOp.getNowait(), + /*isTeamsReduction=*/false, distributeCodeGen); } static LogicalResult