PR60985: Fix merging of lambda closure types across modules.
Previously, distinct lambdas would get merged, and multiple definitions
of the same lambda would not get merged, because we attempted to
identify lambdas by their ordinal position within their lexical
DeclContext. This failed for lambdas within namespace-scope variables
and within variable templates, where the lexical position in the context
containing the variable didn't uniquely identify the lambda.
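
For concreteness, here is a minimal sketch of the two kinds of declaration described above: lambdas in namespace-scope variable initializers and in a variable template. The names `first`, `second` and `sized` are hypothetical and do not come from the patch or its tests; the sketch assumes C++17 for inline variables.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Two namespace-scope variables, each initialized by its own lambda.
// Every lambda-expression has a distinct closure type, so these two
// types must never be merged with one another.
inline auto first = [] { return 1; };
inline auto second = [] { return 2; };

// A variable template whose initializer is a lambda. The lambda lives
// inside the variable, not directly inside the enclosing namespace.
template <typename T>
inline auto sized = [] { return sizeof(T); };

static_assert(!std::is_same_v<decltype(first), decltype(second)>,
              "distinct lambdas have distinct closure types");
```

If two modules each contain these definitions, each closure type should be merged with its counterpart in the other module and with nothing else.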

In this patch, we instead identify lambda closure types by index within
their context declaration, which does uniquely identify them in a way
that's consistent across modules.
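
The difference between the two identification schemes can be sketched with a toy model. Everything here (`ToyLambda`, `moduleA`, `moduleB`, `lexicalKey`, `contextKey`) is hypothetical and greatly simplified; it is not the ASTReader's data model, only an illustration of why a (context declaration, index) pair is stable across modules while a lexical ordinal is not.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified description of one lambda as one module sees it.
struct ToyLambda {
  std::string ContextVar;   // variable whose initializer contains the lambda
  unsigned IndexInContext;  // position among the lambdas of that variable
};

// Module A declares only `x`; module B declares `y` first and then `x`.
inline std::vector<ToyLambda> moduleA() { return {ToyLambda{"x", 0}}; }
inline std::vector<ToyLambda> moduleB() {
  return {ToyLambda{"y", 0}, ToyLambda{"x", 0}};
}

// Old scheme, simplified: the key is the lambda's ordinal position in the
// enclosing lexical context, regardless of which variable owns it.
inline unsigned lexicalKey(const std::vector<ToyLambda> &Module,
                           const std::string &Var) {
  for (unsigned I = 0; I != Module.size(); ++I)
    if (Module[I].ContextVar == Var)
      return I;
  assert(false && "variable not found in module");
  return ~0u;
}

// New scheme: the key is (context declaration, index within it).
using ContextKey = std::pair<std::string, unsigned>;
inline ContextKey contextKey(const std::vector<ToyLambda> &Module,
                             const std::string &Var) {
  for (const ToyLambda &L : Module)
    if (L.ContextVar == Var)
      return ContextKey(L.ContextVar, L.IndexInContext);
  assert(false && "variable not found in module");
  return ContextKey("", ~0u);
}
```

In this model the lexical key of the lambda in `x` is 0 in module A but 1 in module B, while module B's `y` lambda also gets 0, so the old scheme both misses a real match and produces a false one. The context key is `("x", 0)` in both modules.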

This change causes a deserialization cycle between the type of a
variable with deduced type and a lambda appearing as the initializer of
the variable -- reading the variable's type requires reading and merging
the lambda, and reading the lambda requires reading and merging the
variable. This is addressed by deferring loading the deduced type of a
variable until after we finish recursive deserialization.
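
The deferral described above follows a common pattern: queue the work that would close the cycle and run it only once the outermost deserialization step has finished. A minimal sketch, assuming a hypothetical `ToyReader` with a nesting counter and a pending list (this is not the ASTReader interface):

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Toy reader: tracks nesting depth and defers queued actions until the
// outermost deserialization step completes.
class ToyReader {
  unsigned Depth = 0;
  std::vector<std::function<void()>> Pending;

  void finishPending() {
    // Deferred actions may queue more actions, so loop until drained.
    while (!Pending.empty()) {
      std::function<void()> Action = std::move(Pending.back());
      Pending.pop_back();
      Action();
    }
  }

public:
  // Run Work as one (possibly nested) deserialization step.
  void deserialize(const std::function<void()> &Work) {
    ++Depth;
    Work();
    if (--Depth == 0)
      finishPending();
  }

  // Postpone Action until recursive deserialization is complete.
  void defer(std::function<void()> Action) {
    Pending.push_back(std::move(Action));
  }
};

// Models loading a variable whose initializer is a lambda: the deduced
// type is deferred, so reading the lambda cannot re-enter it.
inline std::vector<std::string> loadVariableWithLambda() {
  ToyReader R;
  std::vector<std::string> Events;
  R.deserialize([&] {
    Events.push_back("begin variable");
    R.defer([&] { Events.push_back("read deduced type"); });
    R.deserialize([&] { Events.push_back("read lambda"); });
    Events.push_back("end variable");
  });
  return Events;
}
```

The deduced type is read last, after both the variable and the lambda have been fully loaded.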

This also exposes a pre-existing subtle issue where loading a
variable declaration would trigger immediate loading of its initializer,
which could recursively refer back to properties of the variable. This
particularly causes problems if the initializer contains a
lambda-expression, but can be problematic in general. That is addressed
by switching to lazily loading the initializers of variables rather than
always loading them with the variable declaration. As well as fixing a
deserialization cycle, that should improve laziness of deserialization
in general.
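
As a rough illustration of lazy initializer loading, the sketch below uses a hypothetical `ToyVar` whose initializer is produced by a callback the first time it is requested. The callback stands in for reading from an AST file; none of these names exist in Clang.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <utility>

// Toy variable declaration with a lazily loaded initializer.
class ToyVar {
  std::function<int()> LoadInit;  // stands in for reading the AST file
  std::optional<int> Init;        // cached once loaded

public:
  unsigned Loads = 0;  // how many times the loader actually ran

  explicit ToyVar(std::function<int()> Loader)
      : LoadInit(std::move(Loader)) {}

  // Creating the variable does no loading; the first call here does.
  int getInit() {
    if (!Init) {
      ++Loads;
      Init = LoadInit();
    }
    return *Init;
  }
};
```

Because construction never runs the loader, an initializer that refers back to the variable only sees it after the variable itself has been set up.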

LambdaDefinitionData had 63 spare bits in it, presumably caused by an
off-by-one error in some previous change. This change claims 32 of those bits
as a counter for the lambda within its context. We could probably move the
numbering to separate storage, like we do for the device-side mangling number,
to optimize the likely-common case where all three numbers (host-side mangling
number, device-side mangling number, and index within the context declaration)
are zero, but that's not done in this change.
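
The bit-field arithmetic can be sketched as follows. The widths of `NumCaptures`, `NumExplicitCaptures`, `HasKnownInternalLinkage` and `ManglingNumber` are taken from the diff; `OtherFlags : 5` is an assumption standing in for the flag fields that precede them and are not shown here. Under that assumption the old fields total 65 bits, one more than fits in 64, and narrowing `NumExplicitCaptures` to 12 bits brings the total to exactly 64.

```cpp
#include <cassert>

// Old layout: 5 + 15 + 13 + 1 + 31 = 65 bits of bit-fields.
struct Before {
  unsigned OtherFlags : 5;  // ASSUMED combined width of earlier flags
  unsigned NumCaptures : 15;
  unsigned NumExplicitCaptures : 13;
  unsigned HasKnownInternalLinkage : 1;
  unsigned ManglingNumber : 31;
};

// New layout: 5 + 15 + 12 + 1 + 31 = 64 bits, plus a full 32-bit index.
struct After {
  unsigned OtherFlags : 5;  // ASSUMED, as above
  unsigned NumCaptures : 15;
  unsigned NumExplicitCaptures : 12;
  unsigned HasKnownInternalLinkage : 1;
  unsigned ManglingNumber : 31;
  unsigned IndexInContext;
};

// Bit-field layout is implementation-defined; this holds on ABIs that
// allocate `unsigned` bit-fields in non-straddling 32-bit units.
constexpr bool NewFieldFitsInOldFootprint = sizeof(After) <= sizeof(Before);
```

On such ABIs the new counter should fit in space that the old layout left unused, though the exact sizes depend on the target.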

Fixes #60985.

Reviewed By: #clang-language-wg, aaron.ballman

Differential Revision: https://reviews.llvm.org/D145737
zygoloid committed Mar 30, 2023
1 parent 72c662a commit bc73ef0
Showing 27 changed files with 593 additions and 248 deletions.
2 changes: 2 additions & 0 deletions clang/docs/ReleaseNotes.rst
@@ -257,6 +257,8 @@ Bug Fixes in This Version
- Fix handling of comments in function-like macros so they are ignored in -CC
mode.
(`#60887 <https://github.com/llvm/llvm-project/issues/60887>`_)
- Fix incorrect merging of lambdas across modules.
(`#60985 <https://github.com/llvm/llvm-project/issues/60985>`_)


Bug Fixes to Compiler Builtins
2 changes: 1 addition & 1 deletion clang/include/clang/AST/Decl.h
@@ -900,7 +900,7 @@ struct EvaluatedStmt {
bool HasICEInit : 1;
bool CheckedForICEInit : 1;

Stmt *Value;
LazyDeclStmtPtr Value;
APValue Evaluated;

EvaluatedStmt()
39 changes: 28 additions & 11 deletions clang/include/clang/AST/DeclCXX.h
@@ -395,7 +395,7 @@ class CXXRecordDecl : public RecordDecl {
unsigned NumCaptures : 15;

/// The number of explicit captures in this lambda.
unsigned NumExplicitCaptures : 13;
unsigned NumExplicitCaptures : 12;

/// Has known `internal` linkage.
unsigned HasKnownInternalLinkage : 1;
@@ -404,6 +404,10 @@ class CXXRecordDecl : public RecordDecl {
/// mangling in the Itanium C++ ABI.
unsigned ManglingNumber : 31;

/// The index of this lambda within its context declaration. This is not in
/// general the same as the mangling number.
unsigned IndexInContext;

/// The declaration that provides context for this lambda, if the
/// actual DeclContext does not suffice. This is used for lambdas that
/// occur within default arguments of function parameters within the class
@@ -424,7 +428,7 @@ class CXXRecordDecl : public RecordDecl {
: DefinitionData(D), DependencyKind(DK), IsGenericLambda(IsGeneric),
CaptureDefault(CaptureDefault), NumCaptures(0),
NumExplicitCaptures(0), HasKnownInternalLinkage(0), ManglingNumber(0),
MethodTyInfo(Info) {
IndexInContext(0), MethodTyInfo(Info) {
IsLambda = true;

// C++1z [expr.prim.lambda]p4:
@@ -1772,18 +1776,31 @@ class CXXRecordDecl : public RecordDecl {
/// the declaration context suffices.
Decl *getLambdaContextDecl() const;

/// Set the mangling number and context declaration for a lambda
/// class.
void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl,
bool HasKnownInternalLinkage = false) {
/// Retrieve the index of this lambda within the context declaration returned
/// by getLambdaContextDecl().
unsigned getLambdaIndexInContext() const {
assert(isLambda() && "Not a lambda closure type!");
getLambdaData().ManglingNumber = ManglingNumber;
getLambdaData().ContextDecl = ContextDecl;
getLambdaData().HasKnownInternalLinkage = HasKnownInternalLinkage;
return getLambdaData().IndexInContext;
}

/// Set the device side mangling number.
void setDeviceLambdaManglingNumber(unsigned Num) const;
/// Information about how a lambda is numbered within its context.
struct LambdaNumbering {
Decl *ContextDecl = nullptr;
unsigned IndexInContext = 0;
unsigned ManglingNumber = 0;
unsigned DeviceManglingNumber = 0;
bool HasKnownInternalLinkage = false;
};

/// Set the mangling numbers and context declaration for a lambda class.
void setLambdaNumbering(LambdaNumbering Numbering);

// Get the mangling numbers and context declaration for a lambda class.
LambdaNumbering getLambdaNumbering() const {
return {getLambdaContextDecl(), getLambdaIndexInContext(),
getLambdaManglingNumber(), getDeviceLambdaManglingNumber(),
hasKnownLambdaInternalLinkage()};
}

/// Retrieve the device side mangling number.
unsigned getDeviceLambdaManglingNumber() const;
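
Bundling the numbering into one struct lets a caller copy it, adjust a single field, and write it back in one call, which is the pattern the ASTImporter hunk in this commit uses. A minimal sketch with the same field names as `LambdaNumbering` above; `ToyDecl`, `ToyLambdaClass` and `copyNumbering` are hypothetical stand-ins, not Clang classes.

```cpp
#include <cassert>

struct ToyDecl {};  // hypothetical stand-in for a declaration

// Same fields as the LambdaNumbering struct in the diff.
struct LambdaNumbering {
  ToyDecl *ContextDecl = nullptr;
  unsigned IndexInContext = 0;
  unsigned ManglingNumber = 0;
  unsigned DeviceManglingNumber = 0;
  bool HasKnownInternalLinkage = false;
};

// Toy closure type that stores its numbering as a whole.
class ToyLambdaClass {
  LambdaNumbering Numbering;

public:
  void setLambdaNumbering(LambdaNumbering N) { Numbering = N; }
  LambdaNumbering getLambdaNumbering() const { return Numbering; }
};

// Copy the numbering from one closure type to another, remapping only the
// context declaration and leaving every other field untouched.
inline void copyNumbering(const ToyLambdaClass &From, ToyLambdaClass &To,
                          ToyDecl *ImportedContext) {
  LambdaNumbering N = From.getLambdaNumbering();
  N.ContextDecl = ImportedContext;
  To.setLambdaNumbering(N);
}
```

Adding a field to the struct then requires no change at call sites that only forward the numbering.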
10 changes: 9 additions & 1 deletion clang/include/clang/AST/ExternalASTSource.h
@@ -371,14 +371,22 @@ struct LazyOffsetPtr {
/// \param Source the external AST source.
///
/// \returns a pointer to the AST node.
T* get(ExternalASTSource *Source) const {
T *get(ExternalASTSource *Source) const {
if (isOffset()) {
assert(Source &&
"Cannot deserialize a lazy pointer without an AST source");
Ptr = reinterpret_cast<uint64_t>((Source->*Get)(Ptr >> 1));
}
return reinterpret_cast<T*>(Ptr);
}

/// Retrieve the address of the AST node pointer. Deserializes the pointee if
/// necessary.
T **getAddressOfPointer(ExternalASTSource *Source) const {
// Ensure the integer is in pointer form.
(void)get(Source);
return reinterpret_cast<T**>(&Ptr);
}
};
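
The `get` body above relies on a tagged representation: the stored integer is shifted right by one when treated as an offset, and is overwritten with the real pointer once resolved. The sketch below models that idea with a hypothetical `ToyLazyPtr`; it assumes the low bit marks an offset (consistent with the `Ptr >> 1` above, though `isOffset` itself is not shown in this hunk) and takes a plain function pointer instead of an `ExternalASTSource`.

```cpp
#include <cassert>
#include <cstdint>

// Toy lazy pointer: low bit set means "offset still to be resolved",
// low bit clear means "already a real pointer".
template <typename T> class ToyLazyPtr {
  mutable std::uint64_t Ptr = 0;

public:
  using Loader = T *(*)(std::uint64_t Offset);

  static ToyLazyPtr fromOffset(std::uint64_t Offset) {
    ToyLazyPtr P;
    P.Ptr = (Offset << 1) | 1;
    return P;
  }

  bool isOffset() const { return (Ptr & 1) != 0; }

  // Resolve on first use and cache the pointer in place of the offset.
  T *get(Loader Load) const {
    if (isOffset()) {
      assert(Load && "cannot resolve an offset without a loader");
      Ptr = reinterpret_cast<std::uint64_t>(Load(Ptr >> 1));
    }
    return reinterpret_cast<T *>(Ptr);
  }
};

// Hypothetical loader: maps an offset to a slot in a small static table.
inline int *loadDemoNode(std::uint64_t Offset) {
  static int Nodes[4] = {10, 20, 30, 40};
  return &Nodes[Offset % 4];
}
```

This is also why an accessor that hands out the address of the stored value has to resolve it first: until then the storage holds a tagged offset, not a pointer.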

/// A lazy value (of type T) that is within an AST node of type Owner,
8 changes: 8 additions & 0 deletions clang/include/clang/AST/MangleNumberingContext.h
@@ -27,6 +27,9 @@ class VarDecl;
/// Keeps track of the mangled names of lambda expressions and block
/// literals within a particular context.
class MangleNumberingContext {
// The index of the next lambda we encounter in this context.
unsigned LambdaIndex = 0;

public:
virtual ~MangleNumberingContext() {}

@@ -55,6 +58,11 @@ class MangleNumberingContext {
/// given call operator within the device context. No device number is
/// assigned if no device numbering context is associated.
virtual unsigned getDeviceManglingNumber(const CXXMethodDecl *) { return 0; }

// Retrieve the index of the next lambda appearing in this context, which is
// used for deduplicating lambdas across modules. Note that this is a simple
// sequence number and is not ABI-dependent.
unsigned getNextLambdaIndex() { return LambdaIndex++; }
};

} // end namespace clang
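
The index handed out by `getNextLambdaIndex()` is a per-context sequence number. A minimal sketch of that behaviour, using a hypothetical `ToyNumberingContext` that keeps only the counter from the class above:

```cpp
#include <cassert>

// Toy numbering context: one counter per context, starting at zero.
class ToyNumberingContext {
  unsigned LambdaIndex = 0;  // index of the next lambda in this context

public:
  // Return the current index, then advance to the next one.
  unsigned getNextLambdaIndex() { return LambdaIndex++; }
};
```

Two separate contexts each begin at zero, so an index is only meaningful together with the context declaration it was assigned in.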
25 changes: 25 additions & 0 deletions clang/include/clang/AST/Stmt.h
@@ -2092,6 +2092,11 @@ class IfStmt final
: nullptr;
}

void setConditionVariableDeclStmt(DeclStmt *CondVar) {
assert(hasVarStorage());
getTrailingObjects<Stmt *>()[varOffset()] = CondVar;
}

Stmt *getInit() {
return hasInitStorage() ? getTrailingObjects<Stmt *>()[initOffset()]
: nullptr;
@@ -2324,6 +2329,11 @@ class SwitchStmt final : public Stmt,
: nullptr;
}

void setConditionVariableDeclStmt(DeclStmt *CondVar) {
assert(hasVarStorage());
getTrailingObjects<Stmt *>()[varOffset()] = CondVar;
}

SwitchCase *getSwitchCaseList() { return FirstCase; }
const SwitchCase *getSwitchCaseList() const { return FirstCase; }
void setSwitchCaseList(SwitchCase *SC) { FirstCase = SC; }
@@ -2487,6 +2497,11 @@ class WhileStmt final : public Stmt,
: nullptr;
}

void setConditionVariableDeclStmt(DeclStmt *CondVar) {
assert(hasVarStorage());
getTrailingObjects<Stmt *>()[varOffset()] = CondVar;
}

SourceLocation getWhileLoc() const { return WhileStmtBits.WhileLoc; }
void setWhileLoc(SourceLocation L) { WhileStmtBits.WhileLoc = L; }

@@ -2576,6 +2591,8 @@ class DoStmt : public Stmt {
/// the init/cond/inc parts of the ForStmt will be null if they were not
/// specified in the source.
class ForStmt : public Stmt {
friend class ASTStmtReader;

enum { INIT, CONDVAR, COND, INC, BODY, END_EXPR };
Stmt* SubExprs[END_EXPR]; // SubExprs[INIT] is an expression or declstmt.
SourceLocation LParenLoc, RParenLoc;
@@ -2603,10 +2620,18 @@ class ForStmt : public Stmt {

/// If this ForStmt has a condition variable, return the faux DeclStmt
/// associated with the creation of that condition variable.
DeclStmt *getConditionVariableDeclStmt() {
return reinterpret_cast<DeclStmt*>(SubExprs[CONDVAR]);
}

const DeclStmt *getConditionVariableDeclStmt() const {
return reinterpret_cast<DeclStmt*>(SubExprs[CONDVAR]);
}

void setConditionVariableDeclStmt(DeclStmt *CondVar) {
SubExprs[CONDVAR] = CondVar;
}

Expr *getCond() { return reinterpret_cast<Expr*>(SubExprs[COND]); }
Expr *getInc() { return reinterpret_cast<Expr*>(SubExprs[INC]); }
Stmt *getBody() { return SubExprs[BODY]; }
5 changes: 5 additions & 0 deletions clang/include/clang/Sema/ExternalSemaSource.h
@@ -230,6 +230,11 @@ class ExternalSemaSource : public ExternalASTSource {
return false;
}

/// Notify the external source that a lambda was assigned a mangling number.
/// This enables the external source to track the correspondence between
/// lambdas and mangling numbers if necessary.
virtual void AssignedLambdaNumbering(const CXXRecordDecl *Lambda) {}

/// LLVM-style RTTI.
/// \{
bool isA(const void *ClassID) const override {
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/MultiplexExternalSemaSource.h
@@ -360,6 +360,9 @@ class MultiplexExternalSemaSource : public ExternalSemaSource {
bool MaybeDiagnoseMissingCompleteType(SourceLocation Loc,
QualType T) override;

// Inform all attached sources that a mangling number was assigned.
void AssignedLambdaNumbering(const CXXRecordDecl *Lambda) override;

/// LLVM-style RTTI.
/// \{
bool isA(const void *ClassID) const override {
7 changes: 3 additions & 4 deletions clang/include/clang/Sema/Sema.h
@@ -7108,10 +7108,9 @@ class Sema final {
Expr *TrailingRequiresClause);

/// Number lambda for linkage purposes if necessary.
void handleLambdaNumbering(
CXXRecordDecl *Class, CXXMethodDecl *Method,
std::optional<std::tuple<bool, unsigned, unsigned, Decl *>> Mangling =
std::nullopt);
void handleLambdaNumbering(CXXRecordDecl *Class, CXXMethodDecl *Method,
std::optional<CXXRecordDecl::LambdaNumbering>
NumberingOverride = std::nullopt);

/// Endow the lambda scope info with the relevant properties.
void buildLambdaScope(sema::LambdaScopeInfo *LSI, CXXMethodDecl *CallOperator,
19 changes: 18 additions & 1 deletion clang/include/clang/Serialization/ASTReader.h
@@ -560,6 +560,10 @@ class ASTReader
llvm::DenseMap<Decl*, llvm::SmallVector<NamedDecl*, 2>>
AnonymousDeclarationsForMerging;

/// Map from numbering information for lambdas to the corresponding lambdas.
llvm::DenseMap<std::pair<const Decl *, unsigned>, NamedDecl *>
LambdaDeclarationsForMerging;

/// Key used to identify LifetimeExtendedTemporaryDecl for merging,
/// containing the lifetime-extending declaration and the mangling number.
using LETemporaryKey = std::pair<Decl *, unsigned>;
@@ -1101,7 +1105,13 @@ class ASTReader
/// they might contain a deduced return type that refers to a local type
/// declared within the function.
SmallVector<std::pair<FunctionDecl *, serialization::TypeID>, 16>
PendingFunctionTypes;
PendingDeducedFunctionTypes;

/// The list of deduced variable types that we have not yet read, because
/// they might contain a deduced type that refers to a local type declared
/// within the variable.
SmallVector<std::pair<VarDecl *, serialization::TypeID>, 16>
PendingDeducedVarTypes;

/// The list of redeclaration chains that still need to be
/// reconstructed, and the local offset to the corresponding list
@@ -1139,6 +1149,11 @@ class ASTReader
2>
PendingObjCExtensionIvarRedeclarations;

/// Members that have been added to classes, for which the class has not yet
/// been notified. CXXRecordDecl::addedMember will be called for each of
/// these once recursive deserialization is complete.
SmallVector<std::pair<CXXRecordDecl*, Decl*>, 4> PendingAddedClassMembers;

/// The set of NamedDecls that have been loaded, but are members of a
/// context that has been merged into another context where the corresponding
/// declaration is either missing or has not yet been loaded.
@@ -2082,6 +2097,8 @@ class ASTReader
llvm::MapVector<const FunctionDecl *, std::unique_ptr<LateParsedTemplate>>
&LPTMap) override;

void AssignedLambdaNumbering(const CXXRecordDecl *Lambda) override;

/// Load a selector from disk, registering its ID if it exists.
void LoadSelector(Selector Sel);

9 changes: 9 additions & 0 deletions clang/lib/AST/ASTContext.cpp
@@ -6622,6 +6622,10 @@ bool ASTContext::FriendsDifferByConstraints(const FunctionDecl *X,
}

bool ASTContext::isSameEntity(const NamedDecl *X, const NamedDecl *Y) const {
// Caution: this function is called by the AST reader during deserialization,
// so it cannot rely on AST invariants being met. Non-trivial accessors
// should be avoided, along with any traversal of redeclaration chains.

if (X == Y)
return true;

@@ -6757,6 +6761,11 @@ bool ASTContext::isSameEntity(const NamedDecl *X, const NamedDecl *Y) const {
if (const auto *VarX = dyn_cast<VarDecl>(X)) {
const auto *VarY = cast<VarDecl>(Y);
if (VarX->getLinkageInternal() == VarY->getLinkageInternal()) {
// During deserialization, we might compare variables before we load
// their types. Assume the types will end up being the same.
if (VarX->getType().isNull() || VarY->getType().isNull())
return true;

if (hasSameType(VarX->getType(), VarY->getType()))
return true;
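
The null-type check above treats a type that has not been loaded yet as a provisional match. A toy version of that rule, using a hypothetical `ToyVarInfo` with an optional type string; the real function compares linkage and `QualType`s, which this sketch replaces with a name and a string:

```cpp
#include <cassert>
#include <optional>
#include <string>

// Toy variable description: the type is empty while loading is deferred.
struct ToyVarInfo {
  std::string Name;
  std::optional<std::string> Type;
};

// Decide whether two descriptions denote the same variable.
inline bool sameVariable(const ToyVarInfo &X, const ToyVarInfo &Y) {
  if (X.Name != Y.Name)
    return false;
  // A missing type means "not read yet": assume it will turn out equal.
  if (!X.Type || !Y.Type)
    return true;
  return *X.Type == *Y.Type;
}
```

The trade-off is that a genuine type mismatch is not detected at this point if either type is still pending.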

9 changes: 4 additions & 5 deletions clang/lib/AST/ASTImporter.cpp
@@ -2919,13 +2919,12 @@ ExpectedDecl ASTNodeImporter::VisitRecordDecl(RecordDecl *D) {
DC, *TInfoOrErr, Loc, DCXX->getLambdaDependencyKind(),
DCXX->isGenericLambda(), DCXX->getLambdaCaptureDefault()))
return D2CXX;
ExpectedDecl CDeclOrErr = import(DCXX->getLambdaContextDecl());
CXXRecordDecl::LambdaNumbering Numbering = DCXX->getLambdaNumbering();
ExpectedDecl CDeclOrErr = import(Numbering.ContextDecl);
if (!CDeclOrErr)
return CDeclOrErr.takeError();
D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr,
DCXX->hasKnownLambdaInternalLinkage());
D2CXX->setDeviceLambdaManglingNumber(
DCXX->getDeviceLambdaManglingNumber());
Numbering.ContextDecl = *CDeclOrErr;
D2CXX->setLambdaNumbering(Numbering);
} else if (DCXX->isInjectedClassName()) {
// We have to be careful to do a similar dance to the one in
// Sema::ActOnStartCXXMemberDeclarations
11 changes: 7 additions & 4 deletions clang/lib/AST/Decl.cpp
@@ -2356,12 +2356,15 @@ Expr *VarDecl::getInit() {
if (auto *S = Init.dyn_cast<Stmt *>())
return cast<Expr>(S);

return cast_or_null<Expr>(Init.get<EvaluatedStmt *>()->Value);
auto *Eval = getEvaluatedStmt();
return cast<Expr>(Eval->Value.isOffset()
? Eval->Value.get(getASTContext().getExternalSource())
: Eval->Value.get(nullptr));
}

Stmt **VarDecl::getInitAddress() {
if (auto *ES = Init.dyn_cast<EvaluatedStmt *>())
return &ES->Value;
return ES->Value.getAddressOfPointer(getASTContext().getExternalSource());

return Init.getAddrOfPtr1();
}
@@ -2498,7 +2501,7 @@ APValue *VarDecl::evaluateValueImpl(SmallVectorImpl<PartialDiagnosticAt> &Notes,
bool IsConstantInitialization) const {
EvaluatedStmt *Eval = ensureEvaluatedStmt();

const auto *Init = cast<Expr>(Eval->Value);
const auto *Init = getInit();
assert(!Init->isValueDependent());

// We only produce notes indicating why an initializer is non-constant the
@@ -2582,7 +2585,7 @@ bool VarDecl::checkForConstantInitialization(
"already evaluated var value before checking for constant init");
assert(getASTContext().getLangOpts().CPlusPlus && "only meaningful in C++");

assert(!cast<Expr>(Eval->Value)->isValueDependent());
assert(!getInit()->isValueDependent());

// Evaluate the initializer to check whether it's a constant expression.
Eval->HasConstantInitialization =
11 changes: 8 additions & 3 deletions clang/lib/AST/DeclCXX.cpp
@@ -1646,10 +1646,15 @@ Decl *CXXRecordDecl::getLambdaContextDecl() const {
return getLambdaData().ContextDecl.get(Source);
}

void CXXRecordDecl::setDeviceLambdaManglingNumber(unsigned Num) const {
void CXXRecordDecl::setLambdaNumbering(LambdaNumbering Numbering) {
assert(isLambda() && "Not a lambda closure type!");
if (Num)
getASTContext().DeviceLambdaManglingNumbers[this] = Num;
getLambdaData().ManglingNumber = Numbering.ManglingNumber;
if (Numbering.DeviceManglingNumber)
getASTContext().DeviceLambdaManglingNumbers[this] =
Numbering.DeviceManglingNumber;
getLambdaData().IndexInContext = Numbering.IndexInContext;
getLambdaData().ContextDecl = Numbering.ContextDecl;
getLambdaData().HasKnownInternalLinkage = Numbering.HasKnownInternalLinkage;
}

unsigned CXXRecordDecl::getDeviceLambdaManglingNumber() const {
1 change: 1 addition & 0 deletions clang/lib/AST/ODRDiagsEmitter.cpp
@@ -1742,6 +1742,7 @@ bool ODRDiagsEmitter::diagnoseMismatch(
return true;
}

// Note, these calls can trigger deserialization.
const Expr *FirstInit = FirstParam->getInit();
const Expr *SecondInit = SecondParam->getInit();
if ((FirstInit == nullptr) != (SecondInit == nullptr)) {
6 changes: 6 additions & 0 deletions clang/lib/Sema/MultiplexExternalSemaSource.cpp
@@ -341,3 +341,9 @@ bool MultiplexExternalSemaSource::MaybeDiagnoseMissingCompleteType(
}
return false;
}

void MultiplexExternalSemaSource::AssignedLambdaNumbering(
const CXXRecordDecl *Lambda) {
for (auto *Source : Sources)
Source->AssignedLambdaNumbering(Lambda);
}
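
The forwarding loop above is a plain fan-out: the multiplexer is itself a source and relays each notification to every attached source. A self-contained sketch of that shape; `ToyLambda`, `ToySource`, `CountingSource` and `ToyMultiplexer` are hypothetical and only mirror the virtual hook's name.

```cpp
#include <cassert>
#include <vector>

struct ToyLambda {};  // hypothetical stand-in for a closure type

// Base interface: the default implementation ignores the notification.
class ToySource {
public:
  virtual ~ToySource() = default;
  virtual void AssignedLambdaNumbering(const ToyLambda *) {}
};

// A source that counts how many notifications it received.
class CountingSource : public ToySource {
public:
  unsigned Seen = 0;
  void AssignedLambdaNumbering(const ToyLambda *) override { ++Seen; }
};

// Relays each notification to all attached sources, in attachment order.
class ToyMultiplexer : public ToySource {
  std::vector<ToySource *> Sources;

public:
  void addSource(ToySource &S) { Sources.push_back(&S); }

  void AssignedLambdaNumbering(const ToyLambda *Lambda) override {
    for (ToySource *S : Sources)
      S->AssignedLambdaNumbering(Lambda);
  }
};
```

Sources that do not care about lambda numbering need no changes, since the base class hook does nothing by default.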