Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FMV][GlobalOpt] Statically resolve calls to versioned functions. #87939

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,9 @@ class TargetTransformInfo {
/// false, but it shouldn't matter what it returns anyway.
bool hasArmWideBranch(bool Thumb) const;

/// Returns a bitmask constructed from the target features of a function.
uint64_t getFeatureMask(Function &F) const;

/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

Expand Down Expand Up @@ -2253,6 +2256,7 @@ class TargetTransformInfo::Concept {
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual uint64_t getFeatureMask(Function &F) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
virtual unsigned getNumBytesToPadGlobalArray(unsigned Size,
Type *ArrayType) const = 0;
Expand Down Expand Up @@ -3061,6 +3065,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.hasArmWideBranch(Thumb);
}

uint64_t getFeatureMask(Function &F) const override {
return Impl.getFeatureMask(F);
}

unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,8 @@ class TargetTransformInfoImplBase {

bool hasArmWideBranch(bool) const { return false; }

uint64_t getFeatureMask(Function &F) const { return 0; }

unsigned getMaxNumArgs() const { return UINT_MAX; }

unsigned getNumBytesToPadGlobalArray(unsigned Size, Type *ArrayType) const {
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,10 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
return TTIImpl->hasArmWideBranch(Thumb);
}

uint64_t TargetTransformInfo::getFeatureMask(Function &F) const {
return TTIImpl->getFeatureMask(F);
}

unsigned TargetTransformInfo::getMaxNumArgs() const {
return TTIImpl->getMaxNumArgs();
}
Expand Down
10 changes: 10 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
Expand Down Expand Up @@ -245,6 +246,15 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
return false;
}

uint64_t AArch64TTIImpl::getFeatureMask(Function &F) const {
StringRef FeatureStr = F.getFnAttribute("target-features").getValueAsString();
SmallVector<StringRef, 8> Features;
FeatureStr.split(Features, ",");
if (none_of(Features, [](StringRef Feat) { return Feat == "+fmv"; }))
return 0;
return AArch64::getCpuSupportsMask(Features);
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

uint64_t getFeatureMask(Function &F) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
15 changes: 13 additions & 2 deletions llvm/lib/TargetParser/AArch64TargetParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,22 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
return {};
}

std::optional<AArch64::FMVInfo>
lookupFMVByID(llvm::AArch64::ArchExtKind ExtID) {
for (const auto &I : llvm::AArch64::getFMVInfo())
if (I.ID && *I.ID == ExtID)
return I;
return {};
}

uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs) {
uint64_t FeaturesMask = 0;
for (const StringRef &FeatureStr : FeatureStrs) {
if (auto Ext = parseFMVExtension(FeatureStr))
FeaturesMask |= (1ULL << Ext->Bit);
if (auto FMVExt = parseFMVExtension(FeatureStr))
FeaturesMask |= (1ULL << FMVExt->Bit);
else if (auto ArchExt = targetFeatureToExtension(FeatureStr))
if (auto FMVExt = lookupFMVByID(ArchExt->ID))
FeaturesMask |= (1ULL << FMVExt->Bit);
}
return FeaturesMask;
}
Expand Down
147 changes: 147 additions & 0 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2641,6 +2641,150 @@ DeleteDeadIFuncs(Module &M,
return Changed;
}

// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions.
static void collectVersions(Value *V, SmallVectorImpl<Function *> &Versions) {
if (auto *F = dyn_cast<Function>(V)) {
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
collectVersions(Sel->getTrueValue(), Versions);
collectVersions(Sel->getFalseValue(), Versions);
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
collectVersions(Phi->getIncomingValue(I), Versions);
}
labrinea marked this conversation as resolved.
Show resolved Hide resolved
}

// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
// deduce whether the optimization is legal we need to compare the target
// features between caller and callee versions. The criteria for bypassing
// the resolver are the following:
//
// * If the callee's feature set is a subset of the caller's feature set,
// then the callee is a candidate for direct call.
//
// * Among such candidates the one of highest priority is the best match
// and it shall be picked, unless there is a version of the callee with
// higher priority than the best match which cannot be picked from a
// higher priority caller (directly or through the resolver).
//
// * For every higher priority callee version than the best match, there
// is a higher priority caller version whose feature set availability
// is implied by the callee's feature set.
//
static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
DenseMap<Function *, uint64_t> FeatureMask;

for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
continue;

Function *Resolver = IF.getResolverFunction();
if (!Resolver)
continue;

if (Resolver->isInterposable())
continue;

// Discover the callee versions.
SmallVector<Function *> Callees;
for (BasicBlock &BB : *Resolver)
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
collectVersions(Ret->getReturnValue(), Callees);

if (Callees.empty())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);

// Cache the feature mask for each callee.
bool IsFMV = true;
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
if (Inserted) {
It->second = TTI.getFeatureMask(*Callee);
// Empty mask means this isn't an FMV callee.
if (It->second == 0) {
IsFMV = false;
break;
}
}
}

// This IFunc is not FMV.
if (!IsFMV)
continue;

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
labrinea marked this conversation as resolved.
Show resolved Hide resolved
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
labrinea marked this conversation as resolved.
Show resolved Hide resolved
});

auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
assert(I < Callees.size() && "Found callers of equal priority");

Function *Callee = Callees[I];
uint64_t CallerBits = FeatureMask[Caller];
uint64_t CalleeBits = FeatureMask[Callee];
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
continue;
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls)
CS->setCalledOperand(Callee);
Changed = true;
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is probably a leftover from the time we had ifunc aliases. Subject to removal.

NumIFuncsResolved++;
}
return Changed;
}

static bool
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
Expand Down Expand Up @@ -2707,6 +2851,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
// Optimize IFuncs whose callee's are statically known.
LocalChange |= OptimizeStaticIFuncs(M);

// Optimize IFuncs based on the target features of the caller.
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);

// Remove any IFuncs that are now dead.
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);

Expand Down
Loading
Loading