Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FMV][GlobalOpt] Statically resolve calls to versioned functions. #87939

Merged
merged 16 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,13 @@ class TargetTransformInfo {
/// false, but it shouldn't matter what it returns anyway.
bool hasArmWideBranch(bool Thumb) const;

/// Returns a bitmask constructed from the target-features or fmv-features
/// metadata of a function.
uint64_t getFeatureMask(const Function &F) const;

/// Returns true if this is an instance of a function with multiple versions.
bool isMultiversionedFunction(const Function &F) const;

/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

Expand Down Expand Up @@ -2312,6 +2319,8 @@ class TargetTransformInfo::Concept {
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual uint64_t getFeatureMask(const Function &F) const = 0;
virtual bool isMultiversionedFunction(const Function &F) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
virtual unsigned getNumBytesToPadGlobalArray(unsigned Size,
Type *ArrayType) const = 0;
Expand Down Expand Up @@ -3144,6 +3153,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.hasArmWideBranch(Thumb);
}

uint64_t getFeatureMask(const Function &F) const override {
return Impl.getFeatureMask(F);
}

bool isMultiversionedFunction(const Function &F) const override {
return Impl.isMultiversionedFunction(F);
}

unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,10 @@ class TargetTransformInfoImplBase {

bool hasArmWideBranch(bool) const { return false; }

uint64_t getFeatureMask(const Function &F) const { return 0; }

bool isMultiversionedFunction(const Function &F) const { return false; }

unsigned getMaxNumArgs() const { return UINT_MAX; }

unsigned getNumBytesToPadGlobalArray(unsigned Size, Type *ArrayType) const {
Expand Down
13 changes: 8 additions & 5 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,16 @@ void fillValidCPUArchList(SmallVectorImpl<StringRef> &Values);

bool isX18ReservedByDefault(const Triple &TT);

// Return the priority for a given set of FMV features.
// For a given set of feature names, which can be either target-features, or
// fmv-features metadata, expand their dependencies and then return a bitmask
// corresponding to the entries of AArch64::FeatPriorities.
uint64_t getFMVPriority(ArrayRef<StringRef> Features);

// For given feature names, return a bitmask corresponding to the entries of
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks themselves,
// they are sequential (0, 1, 2, 3, ...). The resulting bitmask is used at
// runtime to test whether a certain FMV feature is available on the host.
// For a given set of FMV feature names, expand their dependencies and then
// return a bitmask corresponding to the entries of AArch64::CPUFeatures.
// The values in CPUFeatures are not bitmasks themselves, they are sequential
// (0, 1, 2, 3, ...). The resulting bitmask is used at runtime to test whether
// a certain FMV feature is available on the host.
uint64_t getCpuSupportsMask(ArrayRef<StringRef> Features);

void PrintSupportedExtensions();
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,14 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
return TTIImpl->hasArmWideBranch(Thumb);
}

uint64_t TargetTransformInfo::getFeatureMask(const Function &F) const {
return TTIImpl->getFeatureMask(F);
}

bool TargetTransformInfo::isMultiversionedFunction(const Function &F) const {
return TTIImpl->isMultiversionedFunction(F);
}

unsigned TargetTransformInfo::getMaxNumArgs() const {
return TTIImpl->getMaxNumArgs();
}
Expand Down
14 changes: 14 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
Expand Down Expand Up @@ -248,6 +249,19 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
return false;
}

uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
StringRef AttributeStr =
isMultiversionedFunction(F) ? "fmv-features" : "target-features";
StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
SmallVector<StringRef, 8> Features;
FeatureStr.split(Features, ",");
return AArch64::getFMVPriority(Features);
}

bool AArch64TTIImpl::isMultiversionedFunction(const Function &F) const {
return F.hasFnAttribute("fmv-features");
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

uint64_t getFeatureMask(const Function &F) const;

bool isMultiversionedFunction(const Function &F) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
31 changes: 26 additions & 5 deletions llvm/lib/TargetParser/AArch64TargetParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,33 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
return {};
}

std::optional<AArch64::FMVInfo> lookupFMVByID(AArch64::ArchExtKind ExtID) {
for (const AArch64::FMVInfo &Info : AArch64::getFMVInfo())
if (Info.ID && *Info.ID == ExtID)
return Info;
return {};
}

uint64_t AArch64::getFMVPriority(ArrayRef<StringRef> Features) {
uint64_t Priority = 0;
for (StringRef Feature : Features)
if (std::optional<FMVInfo> Info = parseFMVExtension(Feature))
Priority |= (1ULL << Info->PriorityBit);
return Priority;
// Transitively enable the Arch Extensions which correspond to each feature.
ExtensionSet FeatureBits;
for (const StringRef Feature : Features) {
std::optional<FMVInfo> FMV = parseFMVExtension(Feature);
if (!FMV) {
if (std::optional<ExtensionInfo> Info = targetFeatureToExtension(Feature))
FMV = lookupFMVByID(Info->ID);
}
if (FMV && FMV->ID)
FeatureBits.enable(*FMV->ID);
}

// Construct a bitmask for all the transitively enabled Arch Extensions.
uint64_t PriorityMask = 0;
for (const FMVInfo &Info : getFMVInfo())
if (Info.ID && FeatureBits.Enabled.test(*Info.ID))
PriorityMask |= (1ULL << Info.PriorityBit);

return PriorityMask;
}

uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> Features) {
Expand Down
162 changes: 162 additions & 0 deletions llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2641,6 +2641,165 @@ DeleteDeadIFuncs(Module &M,
return Changed;
}

// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions. Return true on successful
// use-def chain traversal, false otherwise.
static bool collectVersions(TargetTransformInfo &TTI, Value *V,
SmallVectorImpl<Function *> &Versions) {
if (auto *F = dyn_cast<Function>(V)) {
if (!TTI.isMultiversionedFunction(*F))
return false;
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
if (!collectVersions(TTI, Sel->getTrueValue(), Versions))
return false;
if (!collectVersions(TTI, Sel->getFalseValue(), Versions))
return false;
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
if (!collectVersions(TTI, Phi->getIncomingValue(I), Versions))
return false;
} else {
// Unknown instruction type. Bail.
return false;
}
labrinea marked this conversation as resolved.
Show resolved Hide resolved
return true;
}

// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
// deduce whether the optimization is legal we need to compare the target
// features between caller and callee versions. The criteria for bypassing
// the resolver are the following:
//
// * If the callee's feature set is a subset of the caller's feature set,
// then the callee is a candidate for direct call.
//
// * Among such candidates the one of highest priority is the best match
// and it shall be picked, unless there is a version of the callee with
// higher priority than the best match which cannot be picked from a
// higher priority caller (directly or through the resolver).
//
// * For every higher priority callee version than the best match, there
// is a higher priority caller version whose feature set availability
// is implied by the callee's feature set.
//
static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
DenseMap<Function *, uint64_t> FeatureMask;

for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
continue;

Function *Resolver = IF.getResolverFunction();
if (!Resolver)
continue;

if (Resolver->isInterposable())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);

// Discover the callee versions.
SmallVector<Function *> Callees;
if (any_of(*Resolver, [&TTI, &Callees](BasicBlock &BB) {
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
if (!collectVersions(TTI, Ret->getReturnValue(), Callees))
return true;
return false;
}))
continue;

assert(!Callees.empty() && "Expecting successful collection of versions");

// Cache the feature mask for each callee.
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
if (Inserted)
It->second = TTI.getFeatureMask(*Callee);
}

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
labrinea marked this conversation as resolved.
Show resolved Hide resolved
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
labrinea marked this conversation as resolved.
Show resolved Hide resolved
});

auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
assert(I < Callees.size() && "Found callers of equal priority");

Function *Callee = Callees[I];
uint64_t CallerBits = FeatureMask[Caller];
uint64_t CalleeBits = FeatureMask[Callee];

// In the case of FMV callers, we know that all higher priority callers
// than the current one did not get selected at runtime, which helps
// reason about the callees (if they have versions that mandate presence
// of the features which we already know are unavailable on this target).
if (TTI.isMultiversionedFunction(*Caller)) {
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
continue;
}
} else {
// We can't reason much about non-FMV callers. Just pick the highest
// priority callee if it matches, otherwise bail.
if (I > 0 || !implies(CallerBits, CalleeBits))
continue;
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls)
CS->setCalledOperand(Callee);
Changed = true;
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is probably a leftover from the time we had ifunc aliases. Subject to removal.

NumIFuncsResolved++;
}
return Changed;
}

static bool
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
Expand Down Expand Up @@ -2707,6 +2866,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
// Optimize IFuncs whose callee's are statically known.
LocalChange |= OptimizeStaticIFuncs(M);

// Optimize IFuncs based on the target features of the caller.
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);

// Remove any IFuncs that are now dead.
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);

Expand Down
Loading
Loading