Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 142 additions & 94 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2488,20 +2488,21 @@ DeleteDeadIFuncs(Module &M,
// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions. Return true on successful
// use-def chain traversal, false otherwise.
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
SmallVectorImpl<Function *> &Versions) {
static bool
collectVersions(Value *V, SmallVectorImpl<Function *> &Versions,
function_ref<TargetTransformInfo &(Function &)> GetTTI) {
if (auto *F = dyn_cast<Function>(V)) {
if (!TTI.isMultiversionedFunction(*F))
if (!GetTTI(*F).isMultiversionedFunction(*F))
return false;
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
if (!collectVersions(Sel->getTrueValue(), Versions, GetTTI))
return false;
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
if (!collectVersions(Sel->getFalseValue(), Versions, GetTTI))
return false;
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
if (!collectVersions(Phi->getIncomingValue(I), Versions, GetTTI))
return false;
} else {
// Unknown instruction type. Bail.
Expand All @@ -2510,31 +2511,37 @@ static bool collectVersions(TargetTransformInfo &TTI, Value *V,
return true;
}

// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
// deduce whether the optimization is legal we need to compare the target
// features between caller and callee versions. The criteria for bypassing
// the resolver are the following:
//
// * If the callee's feature set is a subset of the caller's feature set,
// then the callee is a candidate for direct call.
//
// * Among such candidates the one of highest priority is the best match
// and it shall be picked, unless there is a version of the callee with
// higher priority than the best match which cannot be picked from a
// higher priority caller (directly or through the resolver).
//
// * For every higher priority callee version than the best match, there
// is a higher priority caller version whose feature set availability
// is implied by the callee's feature set.
// Try to statically resolve calls to versioned functions when possible. First
// we identify the function versions which are associated with an IFUNC symbol.
// We do that by examining the resolver function of the IFUNC. Once we have
// collected all the function versions, we sort them in decreasing priority
// order. This is necessary for identifying the highest priority callee version
// for a given caller version. We then collect all the callsites to versioned
// functions. The static resolution is performed by comparing the feature sets
// between callers and callees. Versions of the callee may be skipped if they
// depend on features we already know are unavailable. This information can
// be deduced on each subsequent iteration of the set of caller versions: prior
// iterations correspond to higher priority caller versions which would not have
// been selected in a hypothetical runtime execution.
//
// Presentation in EuroLLVM2025:
// https://www.youtube.com/watch?v=k54MFimPz-A&t=867s
static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
// Map containing the feature bits for a given function.
DenseMap<Function *, APInt> FeatureMask;
// Map containing all the function versions corresponding to an IFunc symbol.
DenseMap<GlobalIFunc *, SmallVector<Function *>> VersionedFuncs;
// Map containing the IFunc symbol a function is version of.
DenseMap<Function *, GlobalIFunc *> VersionOf;
// List of all the interesting IFuncs found in the module.
SmallVector<GlobalIFunc *> IFuncs;

for (GlobalIFunc &IF : M.ifuncs()) {
LLVM_DEBUG(dbgs() << "Examining IFUNC " << IF.getName() << "\n");

if (IF.isInterposable())
continue;

Expand All @@ -2545,107 +2552,148 @@ static bool OptimizeNonTrivialIFuncs(
if (Resolver->isInterposable())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);

// Discover the callee versions.
SmallVector<Function *> Callees;
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
SmallVector<Function *> Versions;
// Discover the versioned functions.
if (any_of(*Resolver, [&](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
if (!collectVersions(Ret->getReturnValue(), Versions, GetTTI))
return true;
return false;
}))
continue;

if (Callees.empty())
if (Versions.empty())
continue;

LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< Resolver->getName() << "\n");

// Cache the feature mask for each callee.
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
for (Function *V : Versions) {
VersionOf.insert({V, &IF});
auto [It, Inserted] = FeatureMask.try_emplace(V);
if (Inserted)
It->second = TTI.getFeatureMask(*Callee);
It->second = GetTTI(*V).getFeatureMask(*V);
}

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
// Sort function versions in decreasing priority order.
sort(Versions, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
IFuncs.push_back(&IF);
VersionedFuncs.try_emplace(&IF, std::move(Versions));
}

for (GlobalIFunc *CalleeIF : IFuncs) {
SmallVector<Function *> NonFMVCallers;
DenseSet<GlobalIFunc *> CallerIFuncs;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {

// Find the callsites.
for (User *U : CalleeIF->users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
if (CB->getCalledOperand() == CalleeIF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
GlobalIFunc *CallerIF = nullptr;
TargetTransformInfo &TTI = GetTTI(*Caller);
bool CallerIsFMV = TTI.isMultiversionedFunction(*Caller);
// The caller is a version of a known IFunc.
if (auto It = VersionOf.find(Caller); It != VersionOf.end())
CallerIF = It->second;
else if (!CallerIsFMV && OptimizeNonFMVCallers) {
// The caller is non-FMV.
auto [It, Inserted] = FeatureMask.try_emplace(Caller);
if (Inserted)
It->second = TTI.getFeatureMask(*Caller);
} else
// The caller is none of the above, skip.
continue;
auto [It, Inserted] = CallSites.try_emplace(Caller);
if (Inserted) {
if (CallerIsFMV)
CallerIFuncs.insert(CallerIF);
else
NonFMVCallers.push_back(Caller);
}
It->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS].ugt(FeatureMask[RHS]);
});

auto implies = [](APInt A, APInt B) { return B.isSubsetOf(A); };
if (CallSites.empty())
continue;

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
assert(I < Callees.size() && "Found callers of equal priority");
LLVM_DEBUG(dbgs() << "Statically resolving calls to function "
<< CalleeIF->getResolverFunction()->getName() << "\n");

// The complexity of this algorithm is linear: O(NumCallers + NumCallees).
// TODO
// A limitation it has is that we are not using information about the
// current caller to deduce why an earlier caller of higher priority was
// skipped. For example let's say the current caller is aes+sve2 and a
// previous caller was mops+sve2. Knowing that sve2 is available we could
// infer that mops is unavailable. This would allow us to skip callee
// versions which depend on mops. I tried implementing this but the
// complexity was cubic :/
auto redirectCalls = [&](ArrayRef<Function *> Callers,
ArrayRef<Function *> Callees) {
// Index to the highest callee candidate.
unsigned I = 0;

for (Function *const &Caller : Callers) {
bool CallerIsFMV = GetTTI(*Caller).isMultiversionedFunction(*Caller);

LLVM_DEBUG(dbgs() << " Examining "
<< (CallerIsFMV ? "FMV" : "regular") << " caller "
<< Caller->getName() << "\n");

if (I == Callees.size())
break;

Function *Callee = Callees[I];
APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callee];
Function *Callee = Callees[I];
APInt CallerBits = FeatureMask[Caller];
APInt CalleeBits = FeatureMask[Callee];

// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this target).
if (TTI.isMultiversionedFunction(*Caller)) {
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
// callee then all the callsites can be statically resolved.
if (CalleeBits.isSubsetOf(CallerBits)) {
// Not all caller versions are necessarily users of the callee IFUNC.
if (auto It = CallSites.find(Caller); It != CallSites.end()) {
for (CallBase *CS : It->second) {
LLVM_DEBUG(dbgs() << " Redirecting call " << Caller->getName()
<< " -> " << Callee->getName() << "\n");
CS->setCalledOperand(Callee);
}
Changed = true;
}
continue;
}
} else {
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
if (!OptimizeNonFMVCallers || I > 0 || !implies(CallerBits, CalleeBits))

// Nothing else to do about non-FMV callers.
if (!CallerIsFMV)
continue;

// Subsequent iterations of the outermost loop (set of callers)
// will consider the caller of the current iteration unavailable.
// Therefore we can skip all those callees which depend on it.
while (CallerBits.isSubsetOf(CalleeBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls) {
LLVM_DEBUG(dbgs() << "Redirecting call " << Caller->getName() << " -> "
<< Callee->getName() << "\n");
CS->setCalledOperand(Callee);
}
Changed = true;
};

auto &Callees = VersionedFuncs[CalleeIF];

// Optimize non-FMV calls.
if (OptimizeNonFMVCallers)
redirectCalls(NonFMVCallers, Callees);

// Optimize FMV calls.
for (GlobalIFunc *CallerIF : CallerIFuncs) {
auto &Callers = VersionedFuncs[CallerIF];
redirectCalls(Callers, Callees);
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))

if (CalleeIF->use_empty() ||
all_of(CalleeIF->users(), [](User *U) { return isa<GlobalAlias>(U); }))
NumIFuncsResolved++;
}
return Changed;
Expand Down
Loading
Loading