[AArch64] SME implementation for agnostic-ZA functions #120150
Conversation
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`.

This implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264
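As an illustration (not part of the patch; the function names here are hypothetical), a minimal LLVM IR sketch of the case this lowering targets is an agnostic-ZA caller invoking a private-ZA callee. Under this patch the expectation is that the caller's prologue allocates a buffer of `__arm_sme_state_size()` bytes and the call is bracketed by `__arm_sme_save` and `__arm_sme_restore`:

```llvm
; Hypothetical sketch, not taken from the patch or its tests.
; The caller is agnostic-ZA; the callee has a private-ZA interface, so the
; caller must preserve all ZA state across the call via the SME ABI routines.
declare void @private_za_callee()

define void @agnostic_za_caller() "aarch64_za_state_agnostic" {
entry:
  call void @private_za_callee()
  ret void
}
```

Compiling something like this with `llc -mtriple=aarch64 -mattr=+sme` should show calls to `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore` being emitted around the call.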
Patch is 26.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120150.diff

11 Files Affected:
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 55de486e90e190..d216b870281cc3 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2264,19 +2264,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
Attrs.hasFnAttr("aarch64_inout_za") +
Attrs.hasFnAttr("aarch64_out_za") +
- Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+ Attrs.hasFnAttr("aarch64_preserves_za") +
+ Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
"Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
- "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+ "'aarch64_inout_za', 'aarch64_preserves_za' and "
+ "'aarch64_za_state_agnostic' are mutually exclusive",
V);
- Check(
- (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
- Attrs.hasFnAttr("aarch64_inout_zt0") +
- Attrs.hasFnAttr("aarch64_out_zt0") +
- Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
- "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
- "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
- V);
+ Check((Attrs.hasFnAttr("aarch64_new_zt0") +
+ Attrs.hasFnAttr("aarch64_in_zt0") +
+ Attrs.hasFnAttr("aarch64_inout_zt0") +
+ Attrs.hasFnAttr("aarch64_out_zt0") +
+ Attrs.hasFnAttr("aarch64_preserves_zt0") +
+ Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
+ "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
+ "'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
+ "'aarch64_za_state_agnostic' are mutually exclusive",
+ V);
if (Attrs.hasFnAttr(Attribute::JumpTable)) {
const GlobalValue *GV = cast<GlobalValue>(V);
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 9f0f23b6e6a658..738895998c1195 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
SMEAttrs CallerAttrs(*FuncInfo.Fn);
if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
CallerAttrs.hasStreamingInterfaceOrBody() ||
- CallerAttrs.hasStreamingCompatibleInterface())
+ CallerAttrs.hasStreamingCompatibleInterface() ||
+ CallerAttrs.hasAgnosticZAInterface())
return nullptr;
return new AArch64FastISel(FuncInfo, LibInfo);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a86ee5a6b64d27..a70047a6cb2124 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
break;
MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+ MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+ MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
MAKE_CASE(AArch64ISD::VG_SAVE)
MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
return BB;
}
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+ MachineBasicBlock *BB) const {
+ MachineFunction *MF = BB->getParent();
+ MachineFrameInfo &MFI = MF->getFrameInfo();
+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+ assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+ "Lazy ZA save is not yet supported on Windows");
+
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+ if (FuncInfo->getSMESaveBufferUsed()) {
+ // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
+ auto Size = MI.getOperand(1).getReg();
+ auto Dest = MI.getOperand(0).getReg();
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+ .addReg(AArch64::SP)
+ .addReg(Size)
+ .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+ AArch64::SP)
+ .addReg(Dest);
+
+ // We have just allocated a variable sized object, tell this to PEI.
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
+ } else
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+ MI.getOperand(0).getReg());
+
+ BB->remove_instr(&MI);
+ return BB;
+}
+
MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
MachineInstr &MI, MachineBasicBlock *BB) const {
@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
return EmitInitTPIDR2Object(MI, BB);
case AArch64::AllocateZABuffer:
return EmitAllocateZABuffer(MI, BB);
+ case AArch64::AllocateSMESaveBuffer:
+ return EmitAllocateSMESaveBuffer(MI, BB);
+ case AArch64::GetSMESaveSize: {
+ // If the buffer is used, emit a call to __arm_sme_state_size()
+ MachineFunction *MF = BB->getParent();
+ AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+ const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+ if (FuncInfo->getSMESaveBufferUsed()) {
+ const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+ .addExternalSymbol("__arm_sme_state_size")
+ .addReg(AArch64::X0, RegState::ImplicitDefine)
+ .addRegMask(TRI->getCallPreservedMask(
+ *MF, CallingConv::
+ AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+ MI.getOperand(0).getReg())
+ .addReg(AArch64::X0);
+ } else
+ BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+ MI.getOperand(0).getReg())
+ .addReg(AArch64::XZR);
+ BB->remove_instr(&MI);
+ return BB;
+ }
case AArch64::F128CSEL:
return EmitF128CSEL(MI, BB);
case TargetOpcode::STATEPOINT:
@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
case CallingConv::AArch64_VectorCall:
case CallingConv::AArch64_SVE_VectorCall:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+ case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
return CC_AArch64_AAPCS;
case CallingConv::ARM64EC_Thunk_X64:
@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
Chain = DAG.getNode(
AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
{/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+ } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+ // Call __arm_sme_state_size().
+ SDValue BufferSize =
+ DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+ DAG.getVTList(MVT::i64, MVT::Other), Chain);
+ Chain = BufferSize.getValue(1);
+
+ SDValue Buffer;
+ if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+ Buffer =
+ DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+ DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+ } else {
+ // Allocate space dynamically.
+ Buffer = DAG.getNode(
+ ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+ {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+ MFI.CreateVariableSizedObject(Align(16), nullptr);
+ }
+
+ // Copy the value to a virtual register, and save that in FuncInfo.
+ Register BufferPtr =
+ MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+ FuncInfo->setSMESaveBufferAddr(BufferPtr);
+ Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
}
if (CallConv == CallingConv::PreserveNone) {
@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
CallerAttrs.hasStreamingBody())
return false;
@@ -8722,6 +8809,32 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
}
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+ SelectionDAG &DAG,
+ AArch64FunctionInfo *Info, SDLoc DL,
+ SDValue Chain, bool IsSave) {
+ TargetLowering::ArgListTy Args;
+ TargetLowering::ArgListEntry Entry;
+ Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+ Entry.Node =
+ DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+ Args.push_back(Entry);
+
+ SDValue Callee =
+ DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+ TLI.getPointerTy(DAG.getDataLayout()));
+ auto *RetTy = Type::getVoidTy(*DAG.getContext());
+ TargetLowering::CallLoweringInfo CLI(DAG);
+ CLI.setDebugLoc(DL)
+ .setChain(Chain)
+ .setLibCallee(
+ CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1,
+ RetTy, Callee, std::move(Args));
+
+ return TLI.LowerCallTo(CLI).second;
+}
+
static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
const SMEAttrs &CalleeAttrs) {
if (!CallerAttrs.hasStreamingCompatibleInterface() ||
@@ -8882,6 +8995,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
};
bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+ bool RequiresSaveAllZA =
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+ SDValue ZAStateBuffer;
if (RequiresLazySave) {
const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
MachinePointerInfo MPI =
@@ -8908,6 +9024,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
&MF.getFunction());
return DescribeCallsite(R) << " sets up a lazy save for ZA";
});
+ } else if (RequiresSaveAllZA) {
+ assert(!CalleeAttrs.hasSharedZAInterface() &&
+ "Cannot share state that may not exist");
+ Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+ /*IsSave=*/true);
}
SDValue PStateSM;
@@ -9455,9 +9576,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
DAG.getConstant(0, DL, MVT::i64));
TPIDR2.Uses++;
+ } else if (RequiresSaveAllZA) {
+ Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+ /*IsSave=*/false);
+ FuncInfo->setSMESaveBufferUsed();
}
- if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+ if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+ RequiresSaveAllZA) {
for (unsigned I = 0; I < InVals.size(); ++I) {
// The smstart/smstop is chained as part of the call, but when the
// resulting chain is discarded (which happens when the call is not part
@@ -28063,7 +28189,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
auto CalleeAttrs = SMEAttrs(*Base);
if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
CallerAttrs.requiresLazySave(CalleeAttrs) ||
- CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+ CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
return true;
}
return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d51b36f7e49946..8621aa81edfb2f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -466,6 +466,10 @@ enum NodeType : unsigned {
ALLOCATE_ZA_BUFFER,
INIT_TPIDR2OBJ,
+ // Needed for __arm_agnostic("sme_za_state")
+ GET_SME_SAVE_SIZE,
+ ALLOC_SME_SAVE_BUFFER,
+
// Asserts that a function argument (i32) is zero-extended to i8 by
// the caller
ASSERT_ZEXT_BOOL,
@@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering {
MachineBasicBlock *BB) const;
MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
MachineBasicBlock *BB) const;
+ MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
+ MachineBasicBlock *BB) const;
MachineBasicBlock *
EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index a77fdaf19bcf5f..7fd3a6c560329c 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;
+ // Holds a pointer to a buffer that is large enough to represent
+ // all SME ZA state and any additional state required by the
+ // __arm_sme_save/restore support routines.
+ Register SMESaveBufferAddr = MCRegister::NoRegister;
+
+ // true if SMESaveBufferAddr is used.
+ bool SMESaveBufferUsed = false;
+
// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;
@@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
return PredicateRegForFillSpill;
}
+ Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+ void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+
+ unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+ void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+
Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index fa577cf92e99d1..ac11b048340498 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
}
+// Nodes to allocate a save buffer for SME.
+def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
+ [SDTCisInt<0>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [X0] in {
+ def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;
+
+def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
+ [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [SP] in {
+ def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
+}
+def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
+ (AllocateSMESaveBuffer $size)>;
+
//===----------------------------------------------------------------------===//
// Instruction naming conventions.
//===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 6c2e04c3f8a7c1..ddb57090cf718a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -240,6 +240,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
(cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
isSMEABIRoutineCall(cast<CallInst>(I))))
return true;
+
+ if (auto *CB = dyn_cast<CallBase>(&I)) {
+ SMEAttrs CallerAttrs(*CB->getCaller()),
+ CalleeAttrs(*CB->getCalledFunction());
+ // When trying to determine if we can inline callees, we must check
+ // that for agnostic-ZA functions, they don't call any functions
+ // that are not agnostic-ZA, as that would require inserting of
+ // save/restore code.
+ if (CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
+ return true;
+ }
}
}
return false;
@@ -261,7 +272,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
CallerAttrs.requiresSMChange(CalleeAttrs) ||
- CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
+ CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+ CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+ if (hasPossibleIncompatibleOps(Callee))
+ return false;
+ }
+
+ if (CalleeAttrs.hasAgnosticZAInterface()) {
if (hasPossibleIncompatibleOps(Callee))
return false;
}
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 015ca4cb92b25e..bf16acd7f8f7e1 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
isPreservesZT0())) &&
"Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
"'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
+
+ assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
+ "Function cannot have a shared-ZA interface and an agnostic-ZA "
+ "interface");
}
SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
Bitmask |= SMEAttrs::SM_Compatible;
+ if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+ FuncName == "__arm_sme_state_size")
+ Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
}
SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
Bitmask |= SM_Compatible;
if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
Bitmask |= SM_Body;
+ if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
+ Bitmask |= ZA_State_Agnostic;
if (Attrs.hasFnAttr("aarch64_in_za"))
Bitmask |= encodeZAState(StateValue::In);
if (Attrs.hasFnAttr("aarch64_out_za"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 4c7c1c9b079538..fb093da70c46b6 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -42,9 +42,10 @@ class SMEAttrs {
SM_Compatible = 1 << 1, // aarch64_pstate_sm_compatible
SM_Body = 1 << 2, // aarch64_pstate_sm_body
SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
- ZA_Shift = 4,
+ ZA_State_Agnostic = 1 << 4,
+ ZA_Shift = 5,
ZA_Mask = 0b111 << ZA_Shift,
- ZT0_Shift = 7,
+ ZT0_Shift = 8,
ZT0_Mask = 0b111 << ZT0_Shift
};
@@ -96,8 +97,11 @@ class SMEAttrs {
ret...
[truncated]
@llvm/pr-subscribers-llvm-ir

Author: Sander de Smalen (sdesmalen-arm)

(Same patch summary and diff as posted for the other subscriber groups above.)
✅ With the latest revision this PR passed the C/C++ code formatter.
// When trying to determine if we can inline callees, we must check
// that for agnostic-ZA functions, they don't call any functions
// that are not agnostic-ZA, as that would require inserting of
// save/restore code.
What problem is this solving? If we inline calls that might clobber za in a context where za is not live, it doesn't matter. If we inline calls that might clobber za into a context where za might be live, we'll insert save/restore calls.
If we care about reducing the number of save/restore calls, we can change code generation to remove redundant restore/save pairs. (We can use a dataflow analysis after isel.)
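As a hypothetical illustration of that redundancy (function names invented for the example), two back-to-back calls to a private-ZA function from an agnostic-ZA function would each be bracketed by their own save/restore, and the restore/save pair between the two calls is the kind of thing such an analysis could remove:

```llvm
; Hypothetical sketch: each call below gets its own __arm_sme_save /
; __arm_sme_restore bracket, so the restore after the first call
; immediately followed by the save before the second call is redundant.
declare void @private_za_callee()

define void @agnostic_za_two_calls() "aarch64_za_state_agnostic" {
entry:
  call void @private_za_callee()
  call void @private_za_callee()
  ret void
}
```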
> What problem is this solving? If we inline calls that might clobber za in a context where za is not live, it doesn't matter. If we inline calls that might clobber za into a context where za might be live, we'll insert save/restore calls.

Thanks for pointing that out. You're right that this change made no sense and can be removed.

> If we care about reducing the number of save/restore calls, we can change code generation to remove redundant restore/save pairs. (We can use a dataflow analysis after isel.)

The expectation is that this is mostly used for leaf functions, or otherwise for functions where the transitive closure of callees is also agnostic-ZA. Calls to private-ZA functions would be expensive and therefore limited to debug routines in practice, which would probably only be executed conditionally.
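For illustration (hypothetical names, not from the patch), the expected common case would be an agnostic-ZA function that only calls other agnostic-ZA functions, where no save buffer or `__arm_sme_save`/`__arm_sme_restore` calls should be needed:

```llvm
; Hypothetical sketch: caller and callee are both agnostic-ZA, so the
; caller does not need to preserve ZA state around this call.
declare void @agnostic_za_leaf() "aarch64_za_state_agnostic"

define void @calls_only_agnostic_za() "aarch64_za_state_agnostic" {
entry:
  call void @agnostic_za_leaf()
  ret void
}
```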
5d78ebf to 4510b01 (Compare)
b5a6d3e to 9b847a1 (Compare)
; [x] A -> Z
; [ ] A -> S
; [ ] A -> A
define void @agnostic_za_caller_now_za_callee_dont_inline() "aarch64_za_state_agnostic" { |
s/now/new/
This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`. This implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264
9b847a1 to 2fd87f7 (Compare)
This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines `__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`.

This implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264