Skip to content

Commit

Permalink
[OpenACC] Implement 'reduction' for combined constructs.
Browse files Browse the repository at this point in the history
Once again, this is a clause on a combined construct that does almost
exactly what the loop/compute construct version does, only with some sl
ightly different evaluation rules/sema rules as it doesn't have to
consider the parent, just the 'combined' construct.  The two sets of
rules for reduction on loop and compute are fine together, so this
ensures they are all enforced for this too.

The 'gangs' 'num_gangs' 'reduction' diagnostic (Dim>1) had to be applied
to num_gangs as well, as it previously wasn't permissible to get in this
situation, but we now can.
  • Loading branch information
erichkeane committed Dec 9, 2024
1 parent 44cd8f0 commit 7d89ebf
Show file tree
Hide file tree
Showing 10 changed files with 445 additions and 97 deletions.
15 changes: 8 additions & 7 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12716,9 +12716,9 @@ def err_acc_clause_cannot_combine
: Error<"OpenACC clause '%0' may not appear on the same construct as a "
"'%1' clause on a '%2' construct">;
def err_acc_reduction_num_gangs_conflict
: Error<
"OpenACC 'reduction' clause may not appear on a 'parallel' construct "
"with a 'num_gangs' clause with more than 1 argument, have %0">;
: Error<"OpenACC '%1' clause %select{|with more than 1 argument }0may not "
"appear on a '%2' construct "
"with a '%3' clause%select{ with more than 1 argument|}0">;
def err_acc_reduction_type
: Error<"OpenACC 'reduction' variable must be of scalar type, sub-array, or a "
"composite of scalar types;%select{| sub-array base}1 type is %0">;
Expand Down Expand Up @@ -12779,13 +12779,14 @@ def err_acc_clause_in_clause_region
def err_acc_gang_reduction_conflict
: Error<"%select{OpenACC 'gang' clause with a 'dim' value greater than "
"1|OpenACC 'reduction' clause}0 cannot "
"appear on the same 'loop' construct as a %select{'reduction' "
"appear on the same '%1' construct as a %select{'reduction' "
"clause|'gang' clause with a 'dim' value greater than 1}0">;
def err_acc_gang_reduction_numgangs_conflict
: Error<"OpenACC '%0' clause cannot appear on the same 'loop' construct "
"as a '%1' clause inside a compute construct with a "
: Error<"OpenACC '%0' clause cannot appear on the same '%2' construct as a "
"'%1' clause %select{inside a compute construct with a|and a}3 "
"'num_gangs' clause with more than one argument">;
def err_reduction_op_mismatch

def err_reduction_op_mismatch
: Error<"OpenACC 'reduction' variable must have the same operator in all "
"nested constructs (%0 vs %1)">;
def err_acc_loop_variable_type
Expand Down
3 changes: 2 additions & 1 deletion clang/include/clang/Sema/SemaOpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,8 @@ class SemaOpenACC : public SemaBase {
// Does the checking for a 'gang' clause that needs to be done in dependent
// and not dependent cases.
OpenACCClause *
CheckGangClause(ArrayRef<const OpenACCClause *> ExistingClauses,
CheckGangClause(OpenACCDirectiveKind DirKind,
ArrayRef<const OpenACCClause *> ExistingClauses,
SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc);
Expand Down
175 changes: 102 additions & 73 deletions clang/lib/Sema/SemaOpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,27 +719,52 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitNumGangsClause(
<< /*NoArgs=*/1 << Clause.getDirectiveKind() << MaxArgs
<< Clause.getIntExprs().size();

// OpenACC 3.3 Section 2.9.11: A reduction clause may not appear on a loop
// directive that has a gang clause and is within a compute construct that has
// a num_gangs clause with more than one explicit argument.
if (Clause.getIntExprs().size() > 1 &&
isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) {
auto *GangClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCGangClause>);
auto *ReductionClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCReductionClause>);

if (GangClauseItr != ExistingClauses.end() &&
ReductionClauseItr != ExistingClauses.end()) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_gang_reduction_numgangs_conflict)
<< OpenACCClauseKind::Reduction << OpenACCClauseKind::Gang
<< Clause.getDirectiveKind() << /*is on combined directive=*/1;
SemaRef.Diag((*ReductionClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
SemaRef.Diag((*GangClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
}
}

// OpenACC 3.3 Section 2.5.4:
// A reduction clause may not appear on a parallel construct with a
// num_gangs clause that has more than one argument.
// TODO: OpenACC: Reduction on Combined Construct needs to do this too.
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel &&
if ((Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel ||
Clause.getDirectiveKind() == OpenACCDirectiveKind::ParallelLoop) &&
Clause.getIntExprs().size() > 1) {
auto *Parallel =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCReductionClause>);

if (Parallel != ExistingClauses.end()) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_reduction_num_gangs_conflict)
<< Clause.getIntExprs().size();
<< /*>1 arg in first loc=*/1 << Clause.getClauseKind()
<< Clause.getDirectiveKind() << OpenACCClauseKind::Reduction;
SemaRef.Diag((*Parallel)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
}
}

// OpenACC 3.3 Section 2.9.2:
// An argument with no keyword or with the 'num' wkeyword is allowed only when
// An argument with no keyword or with the 'num' keyword is allowed only when
// the 'num_gangs' does not appear on the 'kernel' construct.
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::KernelsLoop) {
auto GangClauses = llvm::make_filter_range(
Expand Down Expand Up @@ -1457,32 +1482,36 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
// OpenACC 3.3 Section 2.9.11: A reduction clause may not appear on a loop
// directive that has a gang clause and is within a compute construct that has
// a num_gangs clause with more than one explicit argument.
// TODO OpenACC: When we implement reduction on combined constructs, we need
// to do this too.
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Loop &&
SemaRef.getActiveComputeConstructInfo().Kind !=
OpenACCDirectiveKind::Invalid) {
if ((Clause.getDirectiveKind() == OpenACCDirectiveKind::Loop &&
SemaRef.getActiveComputeConstructInfo().Kind !=
OpenACCDirectiveKind::Invalid) ||
isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) {
// num_gangs clause on the active compute construct.
auto *NumGangsClauseItr =
llvm::find_if(SemaRef.getActiveComputeConstructInfo().Clauses,
llvm::IsaPred<OpenACCNumGangsClause>);

auto *ReductionClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCReductionClause>);

if (ReductionClauseItr != ExistingClauses.end() &&
NumGangsClauseItr !=
SemaRef.getActiveComputeConstructInfo().Clauses.end() &&
auto ActiveComputeConstructContainer =
isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())
? ExistingClauses
: SemaRef.getActiveComputeConstructInfo().Clauses;
auto *NumGangsClauseItr = llvm::find_if(
ActiveComputeConstructContainer, llvm::IsaPred<OpenACCNumGangsClause>);

if (NumGangsClauseItr != ActiveComputeConstructContainer.end() &&
cast<OpenACCNumGangsClause>(*NumGangsClauseItr)->getIntExprs().size() >
1) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_gang_reduction_numgangs_conflict)
<< OpenACCClauseKind::Gang << OpenACCClauseKind::Reduction;
SemaRef.Diag((*ReductionClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
SemaRef.Diag((*NumGangsClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
auto *ReductionClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCReductionClause>);

if (ReductionClauseItr != ExistingClauses.end()) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_gang_reduction_numgangs_conflict)
<< OpenACCClauseKind::Gang << OpenACCClauseKind::Reduction
<< Clause.getDirectiveKind()
<< isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind());
SemaRef.Diag((*ReductionClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
SemaRef.Diag((*NumGangsClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
}
}
}

Expand Down Expand Up @@ -1563,9 +1592,9 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitGangClause(
}
}

return SemaRef.CheckGangClause(ExistingClauses, Clause.getBeginLoc(),
Clause.getLParenLoc(), GangKinds, IntExprs,
Clause.getEndLoc());
return SemaRef.CheckGangClause(Clause.getDirectiveKind(), ExistingClauses,
Clause.getBeginLoc(), Clause.getLParenLoc(),
GangKinds, IntExprs, Clause.getEndLoc());
}

OpenACCClause *SemaOpenACCClauseVisitor::VisitSeqClause(
Expand Down Expand Up @@ -1609,41 +1638,39 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitSeqClause(

OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
SemaOpenACC::OpenACCParsedClause &Clause) {
// Restrictions only properly implemented on 'compute' constructs, and
// 'compute' constructs are the only construct that can do anything with
// this yet, so skip/treat as unimplemented in this case.
// TODO: OpenACC: Remove check once we get combined constructs for this clause.
if (!isOpenACCComputeDirectiveKind(Clause.getDirectiveKind()) &&
Clause.getDirectiveKind() != OpenACCDirectiveKind::Loop)
return isNotImplemented();

// OpenACC 3.3 Section 2.9.11: A reduction clause may not appear on a loop
// directive that has a gang clause and is within a compute construct that has
// a num_gangs clause with more than one explicit argument.
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Loop &&
SemaRef.getActiveComputeConstructInfo().Kind !=
OpenACCDirectiveKind::Invalid) {
if ((Clause.getDirectiveKind() == OpenACCDirectiveKind::Loop &&
SemaRef.getActiveComputeConstructInfo().Kind !=
OpenACCDirectiveKind::Invalid) ||
isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())) {
// num_gangs clause on the active compute construct.
auto *NumGangsClauseItr =
llvm::find_if(SemaRef.getActiveComputeConstructInfo().Clauses,
llvm::IsaPred<OpenACCNumGangsClause>);

auto *GangClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCGangClause>);

if (GangClauseItr != ExistingClauses.end() &&
NumGangsClauseItr !=
SemaRef.getActiveComputeConstructInfo().Clauses.end() &&
auto ActiveComputeConstructContainer =
isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind())
? ExistingClauses
: SemaRef.getActiveComputeConstructInfo().Clauses;
auto *NumGangsClauseItr = llvm::find_if(
ActiveComputeConstructContainer, llvm::IsaPred<OpenACCNumGangsClause>);

if (NumGangsClauseItr != ActiveComputeConstructContainer.end() &&
cast<OpenACCNumGangsClause>(*NumGangsClauseItr)->getIntExprs().size() >
1) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_gang_reduction_numgangs_conflict)
<< OpenACCClauseKind::Reduction << OpenACCClauseKind::Gang;
SemaRef.Diag((*GangClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
SemaRef.Diag((*NumGangsClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
auto *GangClauseItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCGangClause>);

if (GangClauseItr != ExistingClauses.end()) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_gang_reduction_numgangs_conflict)
<< OpenACCClauseKind::Reduction << OpenACCClauseKind::Gang
<< Clause.getDirectiveKind()
<< isOpenACCCombinedDirectiveKind(Clause.getDirectiveKind());
SemaRef.Diag((*GangClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
SemaRef.Diag((*NumGangsClauseItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
}
}
}

Expand All @@ -1667,7 +1694,8 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
// OpenACC 3.3 Section 2.5.4:
// A reduction clause may not appear on a parallel construct with a
// num_gangs clause that has more than one argument.
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel) {
if (Clause.getDirectiveKind() == OpenACCDirectiveKind::Parallel ||
Clause.getDirectiveKind() == OpenACCDirectiveKind::ParallelLoop) {
auto NumGangsClauses = llvm::make_filter_range(
ExistingClauses, llvm::IsaPred<OpenACCNumGangsClause>);

Expand All @@ -1678,7 +1706,8 @@ OpenACCClause *SemaOpenACCClauseVisitor::VisitReductionClause(
if (NumExprs > 1) {
SemaRef.Diag(Clause.getBeginLoc(),
diag::err_acc_reduction_num_gangs_conflict)
<< NumExprs;
<< /*>1 arg in first loc=*/0 << Clause.getClauseKind()
<< Clause.getDirectiveKind() << OpenACCClauseKind::NumGangs;
SemaRef.Diag(NGC->getBeginLoc(), diag::note_acc_previous_clause_here);
return nullptr;
}
Expand Down Expand Up @@ -2624,7 +2653,8 @@ SemaOpenACC::CheckGangExpr(ArrayRef<const OpenACCClause *> ExistingClauses,
}

OpenACCClause *
SemaOpenACC::CheckGangClause(ArrayRef<const OpenACCClause *> ExistingClauses,
SemaOpenACC::CheckGangClause(OpenACCDirectiveKind DirKind,
ArrayRef<const OpenACCClause *> ExistingClauses,
SourceLocation BeginLoc, SourceLocation LParenLoc,
ArrayRef<OpenACCGangKind> GangKinds,
ArrayRef<Expr *> IntExprs, SourceLocation EndLoc) {
Expand All @@ -2649,7 +2679,7 @@ SemaOpenACC::CheckGangClause(ArrayRef<const OpenACCClause *> ExistingClauses,
if (const auto *DimVal = dyn_cast<ConstantExpr>(DimExpr);
DimVal && DimVal->getResultAsAPSInt() > 1) {
Diag(DimVal->getBeginLoc(), diag::err_acc_gang_reduction_conflict)
<< /*gang/reduction=*/0;
<< /*gang/reduction=*/0 << DirKind;
Diag((*ReductionItr)->getBeginLoc(),
diag::note_acc_previous_clause_here);
return nullptr;
Expand All @@ -2666,30 +2696,29 @@ OpenACCClause *SemaOpenACC::CheckReductionClause(
OpenACCDirectiveKind DirectiveKind, SourceLocation BeginLoc,
SourceLocation LParenLoc, OpenACCReductionOperator ReductionOp,
ArrayRef<Expr *> Vars, SourceLocation EndLoc) {
if (DirectiveKind == OpenACCDirectiveKind::Loop) {
if (DirectiveKind == OpenACCDirectiveKind::Loop ||
isOpenACCCombinedDirectiveKind(DirectiveKind)) {
// OpenACC 3.3 2.9.11: A reduction clause may not appear on a loop directive
// that has a gang clause with a dim: argument whose value is greater
// than 1.
const auto *GangItr =
llvm::find_if(ExistingClauses, llvm::IsaPred<OpenACCGangClause>);
const auto GangClauses = llvm::make_filter_range(
ExistingClauses, llvm::IsaPred<OpenACCGangClause>);

while (GangItr != ExistingClauses.end()) {
auto *GangClause = cast<OpenACCGangClause>(*GangItr);
for (auto *GC : GangClauses) {
const auto *GangClause = cast<OpenACCGangClause>(GC);
for (unsigned I = 0; I < GangClause->getNumExprs(); ++I) {
std::pair<OpenACCGangKind, const Expr *> EPair = GangClause->getExpr(I);
// We know there is only 1 on this gang, so move onto the next gang.
if (EPair.first != OpenACCGangKind::Dim)
break;
continue;

if (const auto *DimVal = dyn_cast<ConstantExpr>(EPair.second);
DimVal && DimVal->getResultAsAPSInt() > 1) {
Diag(BeginLoc, diag::err_acc_gang_reduction_conflict)
<< /*reduction/gang=*/1;
Diag((*GangItr)->getBeginLoc(), diag::note_acc_previous_clause_here);
<< /*reduction/gang=*/1 << DirectiveKind;
Diag(GangClause->getBeginLoc(), diag::note_acc_previous_clause_here);
return nullptr;
}
}
++GangItr;
}
}

Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Sema/TreeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -12037,7 +12037,8 @@ void OpenACCClauseTransform<Derived>::VisitGangClause(
}

NewClause = Self.getSema().OpenACC().CheckGangClause(
ExistingClauses, ParsedClause.getBeginLoc(), ParsedClause.getLParenLoc(),
ParsedClause.getDirectiveKind(), ExistingClauses,
ParsedClause.getBeginLoc(), ParsedClause.getLParenLoc(),
TransformedGangKinds, TransformedIntExprs, ParsedClause.getEndLoc());
}
} // namespace
Expand Down
27 changes: 27 additions & 0 deletions clang/test/AST/ast-print-openacc-combined-construct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,4 +386,31 @@ void foo() {
#pragma acc serial loop vector
for(int i = 0;i<5;++i);

//CHECK: #pragma acc parallel loop reduction(+: iPtr)
#pragma acc parallel loop reduction(+: iPtr)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc serial loop reduction(*: i)
#pragma acc serial loop reduction(*: i)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc kernels loop reduction(max: SomeB)
#pragma acc kernels loop reduction(max: SomeB)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc parallel loop reduction(min: iPtr)
#pragma acc parallel loop reduction(min: iPtr)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc serial loop reduction(&: i)
#pragma acc serial loop reduction(&: i)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc kernels loop reduction(|: SomeB)
#pragma acc kernels loop reduction(|: SomeB)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc parallel loop reduction(^: iPtr)
#pragma acc parallel loop reduction(^: iPtr)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc serial loop reduction(&&: i)
#pragma acc serial loop reduction(&&: i)
for(int i = 0;i<5;++i);
//CHECK: #pragma acc kernels loop reduction(||: SomeB)
#pragma acc kernels loop reduction(||: SomeB)
for(int i = 0;i<5;++i);
}
Loading

0 comments on commit 7d89ebf

Please sign in to comment.