Skip to content

Commit

Permalink
Initial support for omp.teams reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
skatrak committed Aug 22, 2024
1 parent cb01ff8 commit b996d13
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 66 deletions.
36 changes: 24 additions & 12 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,17 +1698,19 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
loc, llvm::omp::Directive::OMPD_taskwait);
}

static void
genTeamsClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps) {
static void genTeamsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, bool evalOutsideTarget,
mlir::omp::TeamsOperands &clauseOps,
mlir::omp::NumTeamsClauseOps &numTeamsClauseOps,
mlir::omp::ThreadLimitClauseOps &threadLimitClauseOps,
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
cp.processReduction(loc, clauseOps, &reductionTypes, &reductionSyms);

// Evaluate NUM_TEAMS and THREAD_LIMIT on the host device, if currently inside
// of an omp.target operation.
Expand All @@ -1723,8 +1725,6 @@ genTeamsClauses(lower::AbstractConverter &converter,
cp.processNumTeams(stmtCtx, numTeamsClauseOps);
cp.processThreadLimit(stmtCtx, threadLimitClauseOps);
}

// cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams);
}

static void genWsloopClauses(
Expand Down Expand Up @@ -2496,14 +2496,22 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::omp::TeamsOperands clauseOps;
mlir::omp::NumTeamsClauseOps numTeamsClauseOps;
mlir::omp::ThreadLimitClauseOps threadLimitClauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
llvm::SmallVector<mlir::Type> reductionTypes;
genTeamsClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
evalOutsideTarget, clauseOps, numTeamsClauseOps,
threadLimitClauseOps);
threadLimitClauseOps, reductionTypes, reductionSyms);

auto reductionCallback = [&](mlir::Operation *op) {
genReductionVars(op, converter, loc, reductionSyms, reductionTypes);
return llvm::SmallVector<const semantics::Symbol *>(reductionSyms);
};

auto teamsOp = genOpWithBody<mlir::omp::TeamsOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_teams)
.setClauses(&item->clauses),
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback),
queue, item, clauseOps);

if (numTeamsClauseOps.numTeamsUpper) {
Expand Down Expand Up @@ -2721,6 +2729,8 @@ static void genCompositeDistributeParallelDo(
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
Expand Down Expand Up @@ -2804,6 +2814,8 @@ static void genCompositeDistributeParallelDoSimd(
mlir::omp::SimdOperands simdClauseOps;
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);

// Pass the innermost leaf construct's clauses because that's where COLLAPSE
// is placed by construct decomposition.
mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
Expand Down
74 changes: 28 additions & 46 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,16 +472,19 @@ static void printOrderClause(OpAsmPrinter &p, Operation *op,
//===----------------------------------------------------------------------===//

static ParseResult parseClauseWithRegionArgs(
OpAsmParser &parser, Region &region,
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, ArrayAttr &symbols,
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
bool parseParens = true) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();

OpAsmParser::Delimiter delimiter = parseParens ? OpAsmParser::Delimiter::Paren
: OpAsmParser::Delimiter::None;
if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
parser.parseCommaSeparatedList(delimiter, [&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
Expand Down Expand Up @@ -536,17 +539,17 @@ static ParseResult parseParallelRegion(
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;

if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionVars,
reductionTypes, reductionByref,
reductionSyms, regionPrivateArgs)))
if (failed(parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
reductionByref, reductionSyms,
regionPrivateArgs)))
return failure();
}

if (succeeded(parser.parseOptionalKeyword("private"))) {
auto privateByref = DenseBoolArrayAttr::get(parser.getContext(), {});
if (failed(parseClauseWithRegionArgs(parser, region, privateVars,
privateTypes, privateByref,
privateSyms, regionPrivateArgs)))
if (failed(parseClauseWithRegionArgs(parser, privateVars, privateTypes,
privateByref, privateSyms,
regionPrivateArgs)))
return failure();
if (llvm::any_of(privateByref.asArrayRef(),
[](bool byref) { return byref; })) {
Expand Down Expand Up @@ -597,45 +600,24 @@ static ParseResult parseReductionVarList(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
ArrayAttr &reductionSyms) {
SmallVector<SymbolRefAttr> reductionVec;
SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(reductionVars.emplace_back()) ||
parser.parseColonType(reductionTypes.emplace_back()))
return failure();
isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
reductionByref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSyms = ArrayAttr::get(parser.getContext(), reductions);
return success();
llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
return parseClauseWithRegionArgs(parser, reductionVars, reductionTypes,
reductionByref, reductionSyms,
regionPrivateArgs, /*parseParens=*/false);
}

/// Print Reduction clause
static void
printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars, TypeRange reductionTypes,
std::optional<DenseBoolArrayAttr> reductionByref,
std::optional<ArrayAttr> reductionSyms) {
auto getByRef = [&](unsigned i) -> const char * {
if (!reductionByref || !*reductionByref)
return "";
assert(reductionByref->empty() || i < reductionByref->size());
if (!reductionByref->empty() && (*reductionByref)[i])
return "byref ";
return "";
};

for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
p << getByRef(i) << (*reductionSyms)[i] << " -> " << reductionVars[i]
<< " : " << reductionVars[i].getType();
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
if (reductionSyms) {
auto *argsBegin = op->getRegion(0).front().getArguments().begin();
MutableArrayRef argsSubrange(argsBegin, argsBegin + reductionTypes.size());
printClauseWithRegionArgs(p, op, argsSubrange, llvm::StringRef(),
reductionVars, reductionTypes, reductionByref,
reductionSyms);
}
}

Expand Down Expand Up @@ -1850,7 +1832,7 @@ parseWsloop(OpAsmParser &parser, Region &region,
// Parse an optional reduction clause
llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parser.parseOptionalKeyword("reduction"))) {
if (failed(parseClauseWithRegionArgs(parser, region, reductionOperands,
if (failed(parseClauseWithRegionArgs(parser, reductionOperands,
reductionTypes, reductionByRef,
reductionSymbols, privates)))
return failure();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,8 @@ static LogicalResult createReductionsAndCleanup(
SmallVector<OwningReductionGen> &owningReductionGens,
SmallVector<OwningAtomicReductionGen> &owningAtomicReductionGens,
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos,
bool isTeamsReduction = false, bool hasDistribute = false) {
bool IsNowait = false, bool isTeamsReduction = false,
bool hasDistribute = false) {
// Process the reductions if required.
if (op.getNumReductionVars() == 0)
return success();
Expand All @@ -884,7 +885,7 @@ static LogicalResult createReductionsAndCleanup(
builder.SetInsertPoint(tempTerminator);
llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
isByRef, op.getNowait(), isTeamsReduction,
isByRef, IsNowait, isTeamsReduction,
hasDistribute);
if (!contInsertPoint.getBlock())
return op->emitOpError() << "failed to convert reductions";
Expand Down Expand Up @@ -1083,7 +1084,7 @@ convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
return createReductionsAndCleanup(
sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos);
owningAtomicReductionGens, reductionInfos, sectionsOp.getNowait());
}

/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
Expand Down Expand Up @@ -1127,10 +1128,36 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
if (!op.getAllocatorVars().empty() || op.getReductionSyms() ||
!op.getPrivateVars().empty() || op.getPrivateSyms())
if (!op.getAllocatorVars().empty() || !op.getPrivateVars().empty() ||
op.getPrivateSyms())
return op.emitError("unhandled clauses for translation to LLVM IR");

llvm::ArrayRef<bool> isByRef = getIsByRef(op.getReductionByref());
assert(isByRef.size() == op.getNumReductionVars());

SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(op, reductionDecls);
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

SmallVector<llvm::Value *> privateReductionVariables(
op.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;

MutableArrayRef<BlockArgument> reductionArgs = op.getRegion().getArguments();

if (failed(allocAndInitializeReductionVars(
op, reductionArgs, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables, reductionVariableMap,
isByRef)))
return failure();

// Store the mapping between reduction variables and their private copies on
// ModuleTranslation stack. It can be then recovered when translating
// omp.reduce operations in a separate call.
LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
moduleTranslation, reductionVariableMap);

auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);
Expand Down Expand Up @@ -1160,7 +1187,17 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
builder.restoreIP(ompBuilder->createTeams(
ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr));

return bodyGenStatus;
if (failed(bodyGenStatus))
return bodyGenStatus;

// Process the reductions if required.
SmallVector<OwningReductionGen> owningReductionGens;
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
return createReductionsAndCleanup(op, builder, moduleTranslation, allocaIP,
reductionDecls, privateReductionVariables,
isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos);
}

static void
Expand Down Expand Up @@ -1430,8 +1467,8 @@ static LogicalResult convertOmpWsloop(
return createReductionsAndCleanup(
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, isByRef, owningReductionGens,
owningAtomicReductionGens, reductionInfos, /*isTeamsReduction=*/false,
distributeCodeGen);
owningAtomicReductionGens, reductionInfos, wsloopOp.getNowait(),
/*isTeamsReduction=*/false, distributeCodeGen);
}

static LogicalResult
Expand Down

0 comments on commit b996d13

Please sign in to comment.