Skip to content

Commit

Permalink
[FMV][GlobalOpt] Bypass the IFunc Resolver of MultiVersioned functions.
Browse files Browse the repository at this point in the history
To deduce whether the optimization is legal we need to compare the target
features between caller and callee versions. The criteria for bypassing
the resolver are the following:

 * If the callee's feature set is a subset of the caller's feature set,
   then the callee is a candidate for direct call.

 * Among such candidates the one of highest priority is the best match
   and it shall be picked, unless there is a version of the callee with
   higher priority than the best match which cannot be picked from a
   higher priority caller (directly or through the resolver).

 * For every higher priority callee version than the best match, there
   is a higher priority caller version whose feature set availability
   is implied by the callee's feature set.

Example:

Callers and Callees are ordered in decreasing priority.
The arrows indicate successful call redirections.

  Caller        Callee      Explanation
=========================================================================
mops+sve2 --+--> mops       all the callee versions are subsets of the
            |               caller but mops has the highest priority
            |
     mops --+    sve2       between mops and default callees, mops wins

      sve        sve        between sve and default callees, sve wins
                            but sve2 does not have a high priority caller

  default -----> default    sve (callee) implies sve (caller),
                            sve2(callee) implies sve (caller),
                            mops(callee) implies mops(caller)
  • Loading branch information
labrinea committed Apr 9, 2024
1 parent a522dbb commit 02bd5a7
Show file tree
Hide file tree
Showing 9 changed files with 604 additions and 6 deletions.
14 changes: 14 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1762,6 +1762,12 @@ class TargetTransformInfo {
/// false, but it shouldn't matter what it returns anyway.
bool hasArmWideBranch(bool Thumb) const;

/// Returns true if the target supports Function MultiVersioning.
bool hasFMV() const;

/// Returns a bitmask constructed from the target features of a function.
uint64_t getFeatureMask(Function &F) const;

/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

Expand Down Expand Up @@ -2152,6 +2158,8 @@ class TargetTransformInfo::Concept {
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual bool hasFMV() const = 0;
virtual uint64_t getFeatureMask(Function &F) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
};

Expand Down Expand Up @@ -2904,6 +2912,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.hasArmWideBranch(Thumb);
}

bool hasFMV() const override { return Impl.hasFMV(); }

uint64_t getFeatureMask(Function &F) const override {
return Impl.getFeatureMask(F);
}

unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,10 @@ class TargetTransformInfoImplBase {

bool hasArmWideBranch(bool) const { return false; }

bool hasFMV() const { return false; }

uint64_t getFeatureMask(Function &F) const { return 0; }

unsigned getMaxNumArgs() const { return UINT_MAX; }

protected:
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ const ArchInfo *getArchForCpu(StringRef CPU);
// Parser
const ArchInfo *parseArch(StringRef Arch);
std::optional<ExtensionInfo> parseArchExtension(StringRef Extension);
std::optional<ExtensionInfo> parseTargetFeature(StringRef Feature);
// Given the name of a CPU or alias, return the correponding CpuInfo.
std::optional<CpuInfo> parseCpu(StringRef Name);
// Used by target parser tests
Expand All @@ -856,7 +857,8 @@ bool isX18ReservedByDefault(const Triple &TT);
// For given feature names, return a bitmask corresponding to the entries of
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks
// themselves, they are sequential (0, 1, 2, 3, ...).
uint64_t getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs);
uint64_t getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs,
bool IsBackEndFeature = false);

void PrintSupportedExtensions(StringMap<StringRef> DescMap);

Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,12 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
return TTIImpl->hasArmWideBranch(Thumb);
}

bool TargetTransformInfo::hasFMV() const { return TTIImpl->hasFMV(); }

uint64_t TargetTransformInfo::getFeatureMask(Function &F) const {
return TTIImpl->getFeatureMask(F);
}

unsigned TargetTransformInfo::getMaxNumArgs() const {
return TTIImpl->getMaxNumArgs();
}
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
Expand Down Expand Up @@ -231,6 +232,13 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
return false;
}

uint64_t AArch64TTIImpl::getFeatureMask(Function &F) const {
StringRef FeatureStr = F.getFnAttribute("target-features").getValueAsString();
SmallVector<StringRef, 8> Features;
FeatureStr.split(Features, ",");
return AArch64::getCpuSupportsMask(Features, /*IsBackEndFeature = */ true);
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

bool hasFMV() const { return ST->hasFMV(); }

uint64_t getFeatureMask(Function &F) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
17 changes: 13 additions & 4 deletions llvm/lib/TargetParser/AArch64TargetParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
return {};
}

uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs) {
uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs,
bool IsBackEndFeature) {
uint64_t FeaturesMask = 0;
for (const StringRef &FeatureStr : FeatureStrs) {
if (auto Ext = parseArchExtension(FeatureStr))
for (const StringRef FeatureStr : FeatureStrs)
if (auto Ext = IsBackEndFeature ? parseTargetFeature(FeatureStr)
: parseArchExtension(FeatureStr))
FeaturesMask |= (1ULL << Ext->CPUFeature);
}
return FeaturesMask;
}

Expand Down Expand Up @@ -132,6 +133,14 @@ std::optional<AArch64::ExtensionInfo> AArch64::parseArchExtension(StringRef Arch
return {};
}

std::optional<AArch64::ExtensionInfo>
AArch64::parseTargetFeature(StringRef Feature) {
for (const auto &E : Extensions)
if (Feature == E.Feature)
return E;
return {};
}

std::optional<AArch64::CpuInfo> AArch64::parseCpu(StringRef Name) {
// Resolve aliases first.
Name = resolveCPUAlias(Name);
Expand Down
141 changes: 140 additions & 1 deletion llvm/lib/Transforms/IPO/GlobalOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ STATISTIC(NumAliasesRemoved, "Number of global aliases eliminated");
STATISTIC(NumCXXDtorsRemoved, "Number of global C++ destructors removed");
STATISTIC(NumInternalFunc, "Number of internal functions");
STATISTIC(NumColdCC, "Number of functions marked coldcc");
STATISTIC(NumIFuncsResolved, "Number of statically resolved IFuncs");
STATISTIC(NumIFuncsResolved, "Number of resolved IFuncs");
STATISTIC(NumIFuncsDeleted, "Number of IFuncs removed");

static cl::opt<bool>
Expand Down Expand Up @@ -2462,6 +2462,142 @@ DeleteDeadIFuncs(Module &M,
return Changed;
}

// Follows the use-def chain of \p V backwards until it finds a Function,
// in which case it collects in \p Versions.
static void collectVersions(Value *V, SmallVectorImpl<Function *> &Versions) {
if (auto *F = dyn_cast<Function>(V)) {
Versions.push_back(F);
} else if (auto *Sel = dyn_cast<SelectInst>(V)) {
collectVersions(Sel->getTrueValue(), Versions);
collectVersions(Sel->getFalseValue(), Versions);
} else if (auto *Phi = dyn_cast<PHINode>(V)) {
for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I)
collectVersions(Phi->getIncomingValue(I), Versions);
}
}

// Bypass the IFunc Resolver of MultiVersioned functions when possible. To
// deduce whether the optimization is legal we need to compare the target
// features between caller and callee versions. The criteria for bypassing
// the resolver are the following:
//
// * If the callee's feature set is a subset of the caller's feature set,
// then the callee is a candidate for direct call.
//
// * Among such candidates the one of highest priority is the best match
// and it shall be picked, unless there is a version of the callee with
// higher priority than the best match which cannot be picked from a
// higher priority caller (directly or through the resolver).
//
// * For every higher priority callee version than the best match, there
// is a higher priority caller version whose feature set availability
// is implied by the callee's feature set.
//
static bool OptimizeNonTrivialIFuncs(
Module &M, function_ref<TargetTransformInfo &(Function &)> GetTTI) {
bool Changed = false;

// Cache containing the mask constructed from a function's target features.
DenseMap<Function *, uint64_t> FeatureMask;

for (GlobalIFunc &IF : M.ifuncs()) {
if (IF.isInterposable())
continue;

Function *Resolver = IF.getResolverFunction();
if (!Resolver)
continue;

if (Resolver->isInterposable())
continue;

TargetTransformInfo &TTI = GetTTI(*Resolver);
if (!TTI.hasFMV())
return false;

// Discover the callee versions.
SmallVector<Function *> Callees;
for (BasicBlock &BB : *Resolver)
if (auto *Ret = dyn_cast_or_null<ReturnInst>(BB.getTerminator()))
collectVersions(Ret->getReturnValue(), Callees);

if (Callees.empty())
continue;

// Cache the feature mask for each callee.
for (Function *Callee : Callees) {
auto [It, Inserted] = FeatureMask.try_emplace(Callee);
if (Inserted)
It->second = TTI.getFeatureMask(*Callee);
}

// Sort the callee versions in decreasing priority order.
sort(Callees, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
});

// Find the callsites and cache the feature mask for each caller.
SmallVector<Function *> Callers;
DenseMap<Function *, SmallVector<CallBase *>> CallSites;
for (User *U : IF.users()) {
if (auto *CB = dyn_cast<CallBase>(U)) {
if (CB->getCalledOperand() == &IF) {
Function *Caller = CB->getFunction();
auto [FeatIt, FeatInserted] = FeatureMask.try_emplace(Caller);
if (FeatInserted)
FeatIt->second = TTI.getFeatureMask(*Caller);
auto [CallIt, CallInserted] = CallSites.try_emplace(Caller);
if (CallInserted)
Callers.push_back(Caller);
CallIt->second.push_back(CB);
}
}
}

// Sort the caller versions in decreasing priority order.
sort(Callers, [&](auto *LHS, auto *RHS) {
return FeatureMask[LHS] > FeatureMask[RHS];
});

auto implies = [](uint64_t A, uint64_t B) { return (A & B) == B; };

// Index to the highest priority candidate.
unsigned I = 0;
// Now try to redirect calls starting from higher priority callers.
for (Function *Caller : Callers) {
// Getting here means we found callers of equal priority.
if (I == Callees.size())
break;
Function *Callee = Callees[I];
uint64_t CallerBits = FeatureMask[Caller];
uint64_t CalleeBits = FeatureMask[Callee];
// If the feature set of the caller implies the feature set of the
// highest priority candidate then it shall be picked. In case of
// identical sets advance the candidate index one position.
if (CallerBits == CalleeBits)
++I;
else if (!implies(CallerBits, CalleeBits)) {
// Keep advancing the candidate index as long as the caller's
// features are a subset of the current candidate's.
while (implies(CalleeBits, CallerBits)) {
if (++I == Callees.size())
break;
CalleeBits = FeatureMask[Callees[I]];
}
continue;
}
auto &Calls = CallSites[Caller];
for (CallBase *CS : Calls)
CS->setCalledOperand(Callee);
Changed = true;
}
if (IF.use_empty() ||
all_of(IF.users(), [](User *U) { return isa<GlobalAlias>(U); }))
NumIFuncsResolved++;
}
return Changed;
}

static bool
optimizeGlobalsInModule(Module &M, const DataLayout &DL,
function_ref<TargetLibraryInfo &(Function &)> GetTLI,
Expand Down Expand Up @@ -2525,6 +2661,9 @@ optimizeGlobalsInModule(Module &M, const DataLayout &DL,
// Optimize IFuncs whose callee's are statically known.
LocalChange |= OptimizeStaticIFuncs(M);

// Optimize IFuncs based on the target features of the caller.
LocalChange |= OptimizeNonTrivialIFuncs(M, GetTTI);

// Remove any IFuncs that are now dead.
LocalChange |= DeleteDeadIFuncs(M, NotDiscardableComdats);

Expand Down
Loading

0 comments on commit 02bd5a7

Please sign in to comment.