Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FMV][GlobalOpt] Statically resolve calls to versioned functions. #87939

Merged
merged 16 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,12 @@ class TargetTransformInfo {
/// false, but it shouldn't matter what it returns anyway.
bool hasArmWideBranch(bool Thumb) const;

/// Returns a bitmask constructed from the target features of a function.
/// The meaning of the individual bits is target specific. The default
/// implementation returns 0; the AArch64 implementation builds the mask from
/// the function's "target-features" attribute.
uint64_t getFeatureMask(Function &F) const;
labrinea marked this conversation as resolved.
Show resolved Hide resolved

/// Returns true if this is an instance of a function with multiple versions.
/// The default implementation returns false; the AArch64 implementation
/// returns true when the function's "target-features" attribute contains
/// "+fmv".
bool isMultiversionedFunction(Function &F) const;

/// \return The maximum number of function arguments the target supports.
unsigned getMaxNumArgs() const;

Expand Down Expand Up @@ -2266,6 +2272,8 @@ class TargetTransformInfo::Concept {
virtual VPLegalization
getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
virtual bool hasArmWideBranch(bool Thumb) const = 0;
virtual uint64_t getFeatureMask(Function &F) const = 0;
virtual bool isMultiversionedFunction(Function &F) const = 0;
virtual unsigned getMaxNumArgs() const = 0;
virtual unsigned getNumBytesToPadGlobalArray(unsigned Size,
Type *ArrayType) const = 0;
Expand Down Expand Up @@ -3082,6 +3090,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
return Impl.hasArmWideBranch(Thumb);
}

/// Delegates the feature-mask query to the wrapped target implementation.
uint64_t getFeatureMask(Function &F) const override {
  const uint64_t Mask = Impl.getFeatureMask(F);
  return Mask;
}

/// Delegates the multiversioning query to the wrapped target implementation.
bool isMultiversionedFunction(Function &F) const override {
  const bool IsVersioned = Impl.isMultiversionedFunction(F);
  return IsVersioned;
}

unsigned getMaxNumArgs() const override {
return Impl.getMaxNumArgs();
}
Expand Down
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,10 @@ class TargetTransformInfoImplBase {

bool hasArmWideBranch(bool) const { return false; }

/// Default implementation: reports an empty (all-zero) feature mask for any
/// function. Targets provide their own version where a mask is meaningful.
uint64_t getFeatureMask(Function &F) const {
  return 0;
}

/// Default implementation: no function is considered multiversioned.
bool isMultiversionedFunction(Function &F) const {
  return false;
}

unsigned getMaxNumArgs() const { return UINT_MAX; }

unsigned getNumBytesToPadGlobalArray(unsigned Size, Type *ArrayType) const {
Expand Down
66 changes: 66 additions & 0 deletions llvm/include/llvm/TargetParser/AArch64FeatPriorities.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- AArch64FeatPriorities.inc - AArch64 FMV Priorities enum --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file enumerates the AArch64 FMV features sorted in ascending priority.
//
//===----------------------------------------------------------------------===//

#ifndef AARCH64_FEAT_PRIORITIES_INC_H
#define AARCH64_FEAT_PRIORITIES_INC_H

// Function Multi Versioning feature priorities.
//
// Enumerators are listed in ascending priority order: an enumerator that
// appears later outranks every enumerator before it. The values are
// sequential indices (0, 1, 2, ...), not bitmasks; users convert them to mask
// bits with (1ULL << PRIOR_x) into a uint64_t, so the number of enumerators
// has to stay within 64.
enum FeatPriorities {
labrinea marked this conversation as resolved.
Show resolved Hide resolved
PRIOR_RNG,
PRIOR_FLAGM,
PRIOR_FLAGM2,
PRIOR_LSE,
PRIOR_FP,
PRIOR_SIMD,
PRIOR_DOTPROD,
PRIOR_SM4,
PRIOR_RDM,
PRIOR_CRC,
PRIOR_SHA2,
PRIOR_SHA3,
PRIOR_PMULL,
PRIOR_FP16,
PRIOR_FP16FML,
PRIOR_DIT,
PRIOR_DPB,
PRIOR_DPB2,
PRIOR_JSCVT,
PRIOR_FCMA,
PRIOR_RCPC,
PRIOR_RCPC2,
PRIOR_RCPC3,
PRIOR_FRINTTS,
PRIOR_I8MM,
PRIOR_BF16,
PRIOR_SVE,
PRIOR_SVE_F32MM,
PRIOR_SVE_F64MM,
PRIOR_SVE2,
PRIOR_SVE_PMULL128,
PRIOR_SVE_BITPERM,
PRIOR_SVE_SHA3,
PRIOR_SVE_SM4,
PRIOR_SME,
PRIOR_MEMTAG2,
PRIOR_SB,
PRIOR_PREDRES,
PRIOR_SSBS2,
PRIOR_BTI,
PRIOR_LS64_ACCDATA,
PRIOR_WFXT,
PRIOR_SME_F64,
PRIOR_SME_I64,
PRIOR_SME2,
PRIOR_MOPS
};

#endif
15 changes: 10 additions & 5 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct ArchInfo;
struct CpuInfo;

#include "llvm/TargetParser/AArch64CPUFeatures.inc"
#include "llvm/TargetParser/AArch64FeatPriorities.inc"

static_assert(FEAT_MAX < 62,
"Number of features in CPUFeatures are limited to 62 entries");
Expand Down Expand Up @@ -69,12 +70,12 @@ struct ExtensionInfo {

struct FMVInfo {
StringRef Name; // The target_version/target_clones spelling.
CPUFeatures Bit; // Index of the bit in the FMV feature bitset.
CPUFeatures FeatureBit; // Index of the bit in the FMV feature bitset.
std::optional<ArchExtKind> ID; // The architecture extension to enable.
unsigned Priority; // FMV priority.
FMVInfo(StringRef Name, CPUFeatures Bit, std::optional<ArchExtKind> ID,
unsigned Priority)
: Name(Name), Bit(Bit), ID(ID), Priority(Priority) {};
FeatPriorities PriorityBit; // FMV priority.
FMVInfo(StringRef Name, CPUFeatures FeatureBit, std::optional<ArchExtKind> ID,
FeatPriorities PriorityBit)
: Name(Name), FeatureBit(FeatureBit), ID(ID), PriorityBit(PriorityBit) {};
};

const std::vector<FMVInfo> &getFMVInfo();
Expand Down Expand Up @@ -271,6 +272,10 @@ bool isX18ReservedByDefault(const Triple &TT);
// Return the priority for a given set of FMV features.
unsigned getFMVPriority(ArrayRef<StringRef> Features);

// For given feature names, return a bitmask corresponding to the entries of
// AArch64::FeatPriorities. The values in FeatPriorities are not bitmasks
// themselves, they are sequential (0, 1, 2, 3, ...); bit N of the result is
// set when one of the names maps to the enumerator whose value is N. Names
// that do not map to any FMV feature contribute nothing to the mask.
uint64_t getPriorityMask(ArrayRef<StringRef> Features);

// For given feature names, return a bitmask corresponding to the entries of
// AArch64::CPUFeatures. The values in CPUFeatures are not bitmasks
// themselves, they are sequential (0, 1, 2, 3, ...).
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,14 @@ bool TargetTransformInfo::hasArmWideBranch(bool Thumb) const {
return TTIImpl->hasArmWideBranch(Thumb);
}

// Forwards the feature-mask query for F to the implementation held in
// TTIImpl.
uint64_t TargetTransformInfo::getFeatureMask(Function &F) const {
  const uint64_t Mask = TTIImpl->getFeatureMask(F);
  return Mask;
}

// Forwards the multiversioning query for F to the implementation held in
// TTIImpl.
bool TargetTransformInfo::isMultiversionedFunction(Function &F) const {
  const bool IsVersioned = TTIImpl->isMultiversionedFunction(F);
  return IsVersioned;
}

unsigned TargetTransformInfo::getMaxNumArgs() const {
return TTIImpl->getMaxNumArgs();
}
Expand Down
105 changes: 53 additions & 52 deletions llvm/lib/Target/AArch64/AArch64FMV.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,64 +22,65 @@


// Something you can add to target_version or target_clones.
class FMVExtension<string n, string b, int p> {
class FMVExtension<string name, string enumeration> {
// Name, as spelled in target_version or target_clones. e.g. "memtag".
string Name = n;
string Name = name;

// A C++ expression giving the number of the bit in the FMV ABI.
// Currently this is given as a value from the enum "CPUFeatures".
string Bit = b;
string FeatureBit = "FEAT_" # enumeration;

// SubtargetFeature enabled for codegen when this FMV feature is present.
string BackendFeature = n;
string BackendFeature = name;

// The FMV priority.
int Priority = p;
// A C++ expression giving the number of the priority bit.
// Currently this is given as a value from the enum "FeatPriorities".
string PriorityBit = "PRIOR_" # enumeration;
}

def : FMVExtension<"aes", "FEAT_PMULL", 150>;
def : FMVExtension<"bf16", "FEAT_BF16", 280>;
def : FMVExtension<"bti", "FEAT_BTI", 510>;
def : FMVExtension<"crc", "FEAT_CRC", 110>;
def : FMVExtension<"dit", "FEAT_DIT", 180>;
def : FMVExtension<"dotprod", "FEAT_DOTPROD", 104>;
let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "FEAT_DPB", 190>;
let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "FEAT_DPB2", 200>;
def : FMVExtension<"f32mm", "FEAT_SVE_F32MM", 350>;
def : FMVExtension<"f64mm", "FEAT_SVE_F64MM", 360>;
def : FMVExtension<"fcma", "FEAT_FCMA", 220>;
def : FMVExtension<"flagm", "FEAT_FLAGM", 20>;
let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FEAT_FLAGM2", 30>;
def : FMVExtension<"fp", "FEAT_FP", 90>;
def : FMVExtension<"fp16", "FEAT_FP16", 170>;
def : FMVExtension<"fp16fml", "FEAT_FP16FML", 175>;
let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FEAT_FRINTTS", 250>;
def : FMVExtension<"i8mm", "FEAT_I8MM", 270>;
def : FMVExtension<"jscvt", "FEAT_JSCVT", 210>;
def : FMVExtension<"ls64", "FEAT_LS64_ACCDATA", 520>;
def : FMVExtension<"lse", "FEAT_LSE", 80>;
def : FMVExtension<"memtag", "FEAT_MEMTAG2", 440>;
def : FMVExtension<"mops", "FEAT_MOPS", 650>;
def : FMVExtension<"predres", "FEAT_PREDRES", 480>;
def : FMVExtension<"rcpc", "FEAT_RCPC", 230>;
let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "FEAT_RCPC2", 240>;
def : FMVExtension<"rcpc3", "FEAT_RCPC3", 241>;
def : FMVExtension<"rdm", "FEAT_RDM", 108>;
def : FMVExtension<"rng", "FEAT_RNG", 10>;
def : FMVExtension<"sb", "FEAT_SB", 470>;
def : FMVExtension<"sha2", "FEAT_SHA2", 130>;
def : FMVExtension<"sha3", "FEAT_SHA3", 140>;
def : FMVExtension<"simd", "FEAT_SIMD", 100>;
def : FMVExtension<"sm4", "FEAT_SM4", 106>;
def : FMVExtension<"sme", "FEAT_SME", 430>;
def : FMVExtension<"sme-f64f64", "FEAT_SME_F64", 560>;
def : FMVExtension<"sme-i16i64", "FEAT_SME_I64", 570>;
def : FMVExtension<"sme2", "FEAT_SME2", 580>;
def : FMVExtension<"ssbs", "FEAT_SSBS2", 490>;
def : FMVExtension<"sve", "FEAT_SVE", 310>;
def : FMVExtension<"sve2", "FEAT_SVE2", 370>;
def : FMVExtension<"sve2-aes", "FEAT_SVE_PMULL128", 380>;
def : FMVExtension<"sve2-bitperm", "FEAT_SVE_BITPERM", 400>;
def : FMVExtension<"sve2-sha3", "FEAT_SVE_SHA3", 410>;
def : FMVExtension<"sve2-sm4", "FEAT_SVE_SM4", 420>;
def : FMVExtension<"wfxt", "FEAT_WFXT", 550>;
def : FMVExtension<"aes", "PMULL">;
def : FMVExtension<"bf16", "BF16">;
def : FMVExtension<"bti", "BTI">;
def : FMVExtension<"crc", "CRC">;
def : FMVExtension<"dit", "DIT">;
def : FMVExtension<"dotprod", "DOTPROD">;
let BackendFeature = "ccpp" in def : FMVExtension<"dpb", "DPB">;
let BackendFeature = "ccdp" in def : FMVExtension<"dpb2", "DPB2">;
def : FMVExtension<"f32mm", "SVE_F32MM">;
def : FMVExtension<"f64mm", "SVE_F64MM">;
def : FMVExtension<"fcma", "FCMA">;
def : FMVExtension<"flagm", "FLAGM">;
let BackendFeature = "altnzcv" in def : FMVExtension<"flagm2", "FLAGM2">;
def : FMVExtension<"fp", "FP">;
def : FMVExtension<"fp16", "FP16">;
def : FMVExtension<"fp16fml", "FP16FML">;
let BackendFeature = "fptoint" in def : FMVExtension<"frintts", "FRINTTS">;
def : FMVExtension<"i8mm", "I8MM">;
def : FMVExtension<"jscvt", "JSCVT">;
def : FMVExtension<"ls64", "LS64_ACCDATA">;
def : FMVExtension<"lse", "LSE">;
def : FMVExtension<"memtag", "MEMTAG2">;
def : FMVExtension<"mops", "MOPS">;
def : FMVExtension<"predres", "PREDRES">;
def : FMVExtension<"rcpc", "RCPC">;
let BackendFeature = "rcpc-immo" in def : FMVExtension<"rcpc2", "RCPC2">;
def : FMVExtension<"rcpc3", "RCPC3">;
def : FMVExtension<"rdm", "RDM">;
def : FMVExtension<"rng", "RNG">;
def : FMVExtension<"sb", "SB">;
def : FMVExtension<"sha2", "SHA2">;
def : FMVExtension<"sha3", "SHA3">;
def : FMVExtension<"simd", "SIMD">;
def : FMVExtension<"sm4", "SM4">;
def : FMVExtension<"sme", "SME">;
def : FMVExtension<"sme-f64f64", "SME_F64">;
def : FMVExtension<"sme-i16i64", "SME_I64">;
def : FMVExtension<"sme2", "SME2">;
def : FMVExtension<"ssbs", "SSBS2">;
def : FMVExtension<"sve", "SVE">;
def : FMVExtension<"sve2", "SVE2">;
def : FMVExtension<"sve2-aes", "SVE_PMULL128">;
def : FMVExtension<"sve2-bitperm", "SVE_BITPERM">;
def : FMVExtension<"sve2-sha3", "SVE_SHA3">;
def : FMVExtension<"sve2-sm4", "SVE_SM4">;
def : FMVExtension<"wfxt", "WFXT">;
15 changes: 15 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/TargetParser/AArch64TargetParser.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
Expand Down Expand Up @@ -245,6 +246,20 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
return false;
}

/// Builds the FMV priority mask for \p F: the comma-separated entries of its
/// "target-features" attribute are handed to AArch64::getPriorityMask.
uint64_t AArch64TTIImpl::getFeatureMask(Function &F) const {
  SmallVector<StringRef, 8> SplitFeatures;
  F.getFnAttribute("target-features")
      .getValueAsString()
      .split(SplitFeatures, ",");
  return AArch64::getPriorityMask(SplitFeatures);
}

/// Reports whether the comma-separated "target-features" attribute of \p F
/// contains an entry that is exactly "+fmv".
bool AArch64TTIImpl::isMultiversionedFunction(Function &F) const {
  SmallVector<StringRef, 8> SplitFeatures;
  F.getFnAttribute("target-features")
      .getValueAsString()
      .split(SplitFeatures, ",");
  for (StringRef Entry : SplitFeatures)
    if (Entry == "+fmv")
      return true;
  return false;
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
const Function *Callee) const {
SMEAttrs CallerAttrs(*Caller), CalleeAttrs(*Callee);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
unsigned getInlineCallPenalty(const Function *F, const CallBase &Call,
unsigned DefaultCallPenalty) const;

/// Returns the bitmask that AArch64::getPriorityMask computes for the
/// entries of \p F's "target-features" attribute.
uint64_t getFeatureMask(Function &F) const;

/// Returns true if \p F's "target-features" attribute contains "+fmv".
bool isMultiversionedFunction(Function &F) const;

/// \name Scalar TTI Implementations
/// @{

Expand Down
36 changes: 28 additions & 8 deletions llvm/lib/TargetParser/AArch64TargetParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,44 @@ std::optional<AArch64::ArchInfo> AArch64::ArchInfo::findBySubArch(StringRef SubA
return {};
}

/// Searches the FMV table for the entry whose architecture-extension ID equals
/// \p ExtID. Entries that carry no ID never match.
///
/// \returns a copy of the first matching entry, or std::nullopt if no entry
/// has that ID.
///
/// NOTE(review): declared static because no declaration of this helper is
/// visible in AArch64TargetParser.h; confirm that no other translation unit
/// refers to it.
static std::optional<AArch64::FMVInfo>
lookupFMVByID(llvm::AArch64::ArchExtKind ExtID) {
  for (const AArch64::FMVInfo &Info : llvm::AArch64::getFMVInfo())
    if (Info.ID && *Info.ID == ExtID)
      return Info;
  return std::nullopt;
}

unsigned AArch64::getFMVPriority(ArrayRef<StringRef> Features) {
constexpr unsigned MaxFMVPriority = 1000;
unsigned Priority = 0;
unsigned NumFeatures = 0;
constexpr unsigned MaxFMVPriority = 100;
uint64_t Priority = 0;
FeatPriorities TopBit = static_cast<FeatPriorities>(0);
for (StringRef Feature : Features) {
if (auto Ext = parseFMVExtension(Feature)) {
Priority = std::max(Priority, Ext->Priority);
NumFeatures++;
if (auto FMVExt = parseFMVExtension(Feature)) {
TopBit = std::max(TopBit, FMVExt->PriorityBit);
Priority |= (1ULL << FMVExt->PriorityBit);
}
}
return Priority + MaxFMVPriority * NumFeatures;
return TopBit + MaxFMVPriority * popcount(Priority);
}

// Accumulates one bit per recognised feature, at the position given by the
// feature's FeatPriorities value. A name is recognised either directly by
// parseFMVExtension, or indirectly when targetFeatureToExtension yields an
// extension for which lookupFMVByID finds an FMV entry. Unrecognised names are
// skipped.
uint64_t AArch64::getPriorityMask(ArrayRef<StringRef> Features) {
  uint64_t Mask = 0;
  for (StringRef Name : Features) {
    // First choice: the name parses as an FMV extension on its own.
    if (auto Direct = parseFMVExtension(Name)) {
      Mask |= (1ULL << Direct->PriorityBit);
      continue;
    }

    // Fallback: go through the architecture extension and its ID.
    auto ArchExt = targetFeatureToExtension(Name);
    if (!ArchExt)
      continue;
    if (auto Mapped = lookupFMVByID(ArchExt->ID))
      Mask |= (1ULL << Mapped->PriorityBit);
  }
  return Mask;
}

uint64_t AArch64::getCpuSupportsMask(ArrayRef<StringRef> FeatureStrs) {
uint64_t FeaturesMask = 0;
for (const StringRef &FeatureStr : FeatureStrs) {
if (auto Ext = parseFMVExtension(FeatureStr))
FeaturesMask |= (1ULL << Ext->Bit);
FeaturesMask |= (1ULL << Ext->FeatureBit);
}
return FeaturesMask;
}
Expand Down
Loading