
[AArch64] SME implementation for agnostic-ZA functions #120150

Merged
merged 5 commits into llvm:main on Dec 23, 2024

Conversation

sdesmalen-arm
Collaborator

This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines __arm_sme_state_size, __arm_sme_save and __arm_sme_restore.

This implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264
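As a minimal sketch of the case this lowering handles (hypothetical IR, not taken from the added sme-agnostic-za.ll test; the function names are invented and the commented call sequence is only a rough outline, so exact register usage and scheduling may differ):

```llvm
; Hypothetical agnostic-ZA caller invoking a private-ZA callee. The caller must
; preserve all ZA state around the call using the SME ABI support routines.
declare i64 @private_za_callee(i64)

define i64 @caller(i64 %x) "aarch64_za_state_agnostic" nounwind {
; Expected lowering, roughly:
;   bl  __arm_sme_state_size   ; query the save-area size (returned in x0)
;   sub sp, sp, x0             ; allocate the save buffer on the stack
;   mov x0, <buffer>           ; buffer pointer is the only argument
;   bl  __arm_sme_save         ; save ZA (and related) state before the call
;   bl  private_za_callee
;   mov x0, <buffer>
;   bl  __arm_sme_restore      ; restore the state afterwards
  %res = call i64 @private_za_callee(i64 %x)
  ret i64 %res
}
```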

@llvmbot
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-aarch64

Author: Sander de Smalen (sdesmalen-arm)

Changes

This implements the lowering of calls from agnostic-ZA functions to non-agnostic-ZA functions, using the ABI routines __arm_sme_state_size, __arm_sme_save and __arm_sme_restore.

This implements the proposal described in the following PRs:

* ARM-software/acle#336
* ARM-software/abi-aa#264


Patch is 26.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120150.diff

11 Files Affected:

  • (modified) llvm/lib/IR/Verifier.cpp (+14-10)
  • (modified) llvm/lib/Target/AArch64/AArch64FastISel.cpp (+2-1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+129-2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+6)
  • (modified) llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h (+14)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+16)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+18-1)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp (+9)
  • (modified) llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h (+13-4)
  • (added) llvm/test/CodeGen/AArch64/sme-agnostic-za.ll (+84)
  • (modified) llvm/test/CodeGen/AArch64/sme-disable-gisel-fisel.ll (+24)
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 55de486e90e190..d216b870281cc3 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -2264,19 +2264,23 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs,
   Check((Attrs.hasFnAttr("aarch64_new_za") + Attrs.hasFnAttr("aarch64_in_za") +
          Attrs.hasFnAttr("aarch64_inout_za") +
          Attrs.hasFnAttr("aarch64_out_za") +
-         Attrs.hasFnAttr("aarch64_preserves_za")) <= 1,
+         Attrs.hasFnAttr("aarch64_preserves_za") +
+         Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
         "Attributes 'aarch64_new_za', 'aarch64_in_za', 'aarch64_out_za', "
-        "'aarch64_inout_za' and 'aarch64_preserves_za' are mutually exclusive",
+        "'aarch64_inout_za', 'aarch64_preserves_za' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
         V);
 
-  Check(
-      (Attrs.hasFnAttr("aarch64_new_zt0") + Attrs.hasFnAttr("aarch64_in_zt0") +
-       Attrs.hasFnAttr("aarch64_inout_zt0") +
-       Attrs.hasFnAttr("aarch64_out_zt0") +
-       Attrs.hasFnAttr("aarch64_preserves_zt0")) <= 1,
-      "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
-      "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive",
-      V);
+  Check((Attrs.hasFnAttr("aarch64_new_zt0") +
+         Attrs.hasFnAttr("aarch64_in_zt0") +
+         Attrs.hasFnAttr("aarch64_inout_zt0") +
+         Attrs.hasFnAttr("aarch64_out_zt0") +
+         Attrs.hasFnAttr("aarch64_preserves_zt0") +
+         Attrs.hasFnAttr("aarch64_za_state_agnostic")) <= 1,
+        "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
+        "'aarch64_inout_zt0', 'aarch64_preserves_zt0' and "
+        "'aarch64_za_state_agnostic' are mutually exclusive",
+        V);
 
   if (Attrs.hasFnAttr(Attribute::JumpTable)) {
     const GlobalValue *GV = cast<GlobalValue>(V);
diff --git a/llvm/lib/Target/AArch64/AArch64FastISel.cpp b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
index 9f0f23b6e6a658..738895998c1195 100644
--- a/llvm/lib/Target/AArch64/AArch64FastISel.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FastISel.cpp
@@ -5197,7 +5197,8 @@ FastISel *AArch64::createFastISel(FunctionLoweringInfo &FuncInfo,
   SMEAttrs CallerAttrs(*FuncInfo.Fn);
   if (CallerAttrs.hasZAState() || CallerAttrs.hasZT0State() ||
       CallerAttrs.hasStreamingInterfaceOrBody() ||
-      CallerAttrs.hasStreamingCompatibleInterface())
+      CallerAttrs.hasStreamingCompatibleInterface() ||
+      CallerAttrs.hasAgnosticZAInterface())
     return nullptr;
   return new AArch64FastISel(FuncInfo, LibInfo);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a86ee5a6b64d27..a70047a6cb2124 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2631,6 +2631,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     break;
     MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
     MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
+    MAKE_CASE(AArch64ISD::GET_SME_SAVE_SIZE)
+    MAKE_CASE(AArch64ISD::ALLOC_SME_SAVE_BUFFER)
     MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
     MAKE_CASE(AArch64ISD::VG_SAVE)
     MAKE_CASE(AArch64ISD::VG_RESTORE)
@@ -3218,6 +3220,39 @@ AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
   return BB;
 }
 
+// TODO: Find a way to merge this with EmitAllocateZABuffer.
+MachineBasicBlock *
+AArch64TargetLowering::EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                                 MachineBasicBlock *BB) const {
+  MachineFunction *MF = BB->getParent();
+  MachineFrameInfo &MFI = MF->getFrameInfo();
+  AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+  assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
+         "Lazy ZA save is not yet supported on Windows");
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  if (FuncInfo->getSMESaveBufferUsed()) {
+    // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
+    auto Size = MI.getOperand(1).getReg();
+    auto Dest = MI.getOperand(0).getReg();
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::SUBXrx64), Dest)
+        .addReg(AArch64::SP)
+        .addReg(Size)
+        .addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0));
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+            AArch64::SP)
+        .addReg(Dest);
+
+    // We have just allocated a variable sized object, tell this to PEI.
+    MFI.CreateVariableSizedObject(Align(16), nullptr);
+  } else
+    BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::IMPLICIT_DEF),
+            MI.getOperand(0).getReg());
+
+  BB->remove_instr(&MI);
+  return BB;
+}
+
 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     MachineInstr &MI, MachineBasicBlock *BB) const {
 
@@ -3252,6 +3287,31 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
     return EmitInitTPIDR2Object(MI, BB);
   case AArch64::AllocateZABuffer:
     return EmitAllocateZABuffer(MI, BB);
+  case AArch64::AllocateSMESaveBuffer:
+    return EmitAllocateSMESaveBuffer(MI, BB);
+  case AArch64::GetSMESaveSize: {
+    // If the buffer is used, emit a call to __arm_sme_state_size()
+    MachineFunction *MF = BB->getParent();
+    AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
+    const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+    if (FuncInfo->getSMESaveBufferUsed()) {
+      const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::BL))
+          .addExternalSymbol("__arm_sme_state_size")
+          .addReg(AArch64::X0, RegState::ImplicitDefine)
+          .addRegMask(TRI->getCallPreservedMask(
+              *MF, CallingConv::
+                       AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1));
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::X0);
+    } else
+      BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
+              MI.getOperand(0).getReg())
+          .addReg(AArch64::XZR);
+    BB->remove_instr(&MI);
+    return BB;
+  }
   case AArch64::F128CSEL:
     return EmitF128CSEL(MI, BB);
   case TargetOpcode::STATEPOINT:
@@ -7651,6 +7711,7 @@ CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
   case CallingConv::AArch64_VectorCall:
   case CallingConv::AArch64_SVE_VectorCall:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
+  case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1:
   case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
     return CC_AArch64_AAPCS;
   case CallingConv::ARM64EC_Thunk_X64:
@@ -8110,6 +8171,31 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     Chain = DAG.getNode(
         AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
         {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
+  } else if (SMEAttrs(MF.getFunction()).hasAgnosticZAInterface()) {
+    // Call __arm_sme_state_size().
+    SDValue BufferSize =
+        DAG.getNode(AArch64ISD::GET_SME_SAVE_SIZE, DL,
+                    DAG.getVTList(MVT::i64, MVT::Other), Chain);
+    Chain = BufferSize.getValue(1);
+
+    SDValue Buffer;
+    if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
+      Buffer =
+          DAG.getNode(AArch64ISD::ALLOC_SME_SAVE_BUFFER, DL,
+                      DAG.getVTList(MVT::i64, MVT::Other), {Chain, BufferSize});
+    } else {
+      // Allocate space dynamically.
+      Buffer = DAG.getNode(
+          ISD::DYNAMIC_STACKALLOC, DL, DAG.getVTList(MVT::i64, MVT::Other),
+          {Chain, BufferSize, DAG.getConstant(1, DL, MVT::i64)});
+      MFI.CreateVariableSizedObject(Align(16), nullptr);
+    }
+
+    // Copy the value to a virtual register, and save that in FuncInfo.
+    Register BufferPtr =
+        MF.getRegInfo().createVirtualRegister(&AArch64::GPR64RegClass);
+    FuncInfo->setSMESaveBufferAddr(BufferPtr);
+    Chain = DAG.getCopyToReg(Chain, DL, BufferPtr, Buffer);
   }
 
   if (CallConv == CallingConv::PreserveNone) {
@@ -8398,6 +8484,7 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
   auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
   if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
       CallerAttrs.requiresLazySave(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs) ||
       CallerAttrs.hasStreamingBody())
     return false;
 
@@ -8722,6 +8809,32 @@ SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
   return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
 }
 
+// Emit a call to __arm_sme_save or __arm_sme_restore.
+static SDValue emitSMEStateSaveRestore(const AArch64TargetLowering &TLI,
+                                       SelectionDAG &DAG,
+                                       AArch64FunctionInfo *Info, SDLoc DL,
+                                       SDValue Chain, bool IsSave) {
+  TargetLowering::ArgListTy Args;
+  TargetLowering::ArgListEntry Entry;
+  Entry.Ty = PointerType::getUnqual(*DAG.getContext());
+  Entry.Node =
+      DAG.getCopyFromReg(Chain, DL, Info->getSMESaveBufferAddr(), MVT::i64);
+  Args.push_back(Entry);
+
+  SDValue Callee =
+      DAG.getExternalSymbol(IsSave ? "__arm_sme_save" : "__arm_sme_restore",
+                            TLI.getPointerTy(DAG.getDataLayout()));
+  auto *RetTy = Type::getVoidTy(*DAG.getContext());
+  TargetLowering::CallLoweringInfo CLI(DAG);
+  CLI.setDebugLoc(DL)
+      .setChain(Chain)
+      .setLibCallee(
+          CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1,
+          RetTy, Callee, std::move(Args));
+
+  return TLI.LowerCallTo(CLI).second;
+}
+
 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
                                const SMEAttrs &CalleeAttrs) {
   if (!CallerAttrs.hasStreamingCompatibleInterface() ||
@@ -8882,6 +8995,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   };
 
   bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
+  bool RequiresSaveAllZA =
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs);
+  SDValue ZAStateBuffer;
   if (RequiresLazySave) {
     const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
     MachinePointerInfo MPI =
@@ -8908,6 +9024,11 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
                                                    &MF.getFunction());
       return DescribeCallsite(R) << " sets up a lazy save for ZA";
     });
+  } else if (RequiresSaveAllZA) {
+    assert(!CalleeAttrs.hasSharedZAInterface() &&
+           "Cannot share state that may not exist");
+    Chain = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                    /*IsSave=*/true);
   }
 
   SDValue PStateSM;
@@ -9455,9 +9576,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
         DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
         DAG.getConstant(0, DL, MVT::i64));
     TPIDR2.Uses++;
+  } else if (RequiresSaveAllZA) {
+    Result = emitSMEStateSaveRestore(*this, DAG, FuncInfo, DL, Chain,
+                                     /*IsSave=*/false);
+    FuncInfo->setSMESaveBufferUsed();
   }
 
-  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
+  if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0 ||
+      RequiresSaveAllZA) {
     for (unsigned I = 0; I < InVals.size(); ++I) {
       // The smstart/smstop is chained as part of the call, but when the
       // resulting chain is discarded (which happens when the call is not part
@@ -28063,7 +28189,8 @@ bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
     auto CalleeAttrs = SMEAttrs(*Base);
     if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
         CallerAttrs.requiresLazySave(CalleeAttrs) ||
-        CallerAttrs.requiresPreservingZT0(CalleeAttrs))
+        CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+        CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
       return true;
   }
   return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d51b36f7e49946..8621aa81edfb2f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -466,6 +466,10 @@ enum NodeType : unsigned {
   ALLOCATE_ZA_BUFFER,
   INIT_TPIDR2OBJ,
 
+  // Needed for __arm_agnostic("sme_za_state")
+  GET_SME_SAVE_SIZE,
+  ALLOC_SME_SAVE_BUFFER,
+
   // Asserts that a function argument (i32) is zero-extended to i8 by
   // the caller
   ASSERT_ZEXT_BOOL,
@@ -663,6 +667,8 @@ class AArch64TargetLowering : public TargetLowering {
                                           MachineBasicBlock *BB) const;
   MachineBasicBlock *EmitAllocateZABuffer(MachineInstr &MI,
                                           MachineBasicBlock *BB) const;
+  MachineBasicBlock *EmitAllocateSMESaveBuffer(MachineInstr &MI,
+                                               MachineBasicBlock *BB) const;
 
   MachineBasicBlock *
   EmitInstrWithCustomInserter(MachineInstr &MI,
diff --git a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
index a77fdaf19bcf5f..7fd3a6c560329c 100644
--- a/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
@@ -229,6 +229,14 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   // on function entry to record the initial pstate of a function.
   Register PStateSMReg = MCRegister::NoRegister;
 
+  // Holds a pointer to a buffer that is large enough to represent
+  // all SME ZA state and any additional state required by the
+  // __arm_sme_save/restore support routines.
+  Register SMESaveBufferAddr = MCRegister::NoRegister;
+
+  // true if SMESaveBufferAddr is used.
+  bool SMESaveBufferUsed = false;
+
   // Has the PNReg used to build PTRUE instruction.
   // The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
   unsigned PredicateRegForFillSpill = 0;
@@ -252,6 +260,12 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
     return PredicateRegForFillSpill;
   }
 
+  Register getSMESaveBufferAddr() const { return SMESaveBufferAddr; };
+  void setSMESaveBufferAddr(Register Reg) { SMESaveBufferAddr = Reg; };
+
+  unsigned getSMESaveBufferUsed() const { return SMESaveBufferUsed; };
+  void setSMESaveBufferUsed(bool Used = true) { SMESaveBufferUsed = Used; };
+
   Register getPStateSMReg() const { return PStateSMReg; };
   void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };
 
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index fa577cf92e99d1..ac11b048340498 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -52,6 +52,22 @@ let usesCustomInserter = 1 in {
   def InitTPIDR2Obj : Pseudo<(outs), (ins GPR64:$buffer), [(AArch64InitTPIDR2Obj GPR64:$buffer)]>, Sched<[WriteI]> {}
 }
 
+// Nodes to allocate a save buffer for SME.
+def AArch64SMESaveSize : SDNode<"AArch64ISD::GET_SME_SAVE_SIZE", SDTypeProfile<1, 0,
+                               [SDTCisInt<0>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [X0] in {
+  def GetSMESaveSize : Pseudo<(outs GPR64:$dst), (ins), []>, Sched<[]> {}
+}
+def : Pat<(i64 AArch64SMESaveSize), (GetSMESaveSize)>;
+
+def AArch64AllocateSMESaveBuffer : SDNode<"AArch64ISD::ALLOC_SME_SAVE_BUFFER", SDTypeProfile<1, 1,
+                                          [SDTCisInt<0>, SDTCisInt<1>]>, [SDNPHasChain]>;
+let usesCustomInserter = 1, Defs = [SP] in {
+  def AllocateSMESaveBuffer : Pseudo<(outs GPR64sp:$dst), (ins GPR64:$size), []>, Sched<[WriteI]> {}
+}
+def : Pat<(i64 (AArch64AllocateSMESaveBuffer GPR64:$size)),
+          (AllocateSMESaveBuffer $size)>;
+
 //===----------------------------------------------------------------------===//
 // Instruction naming conventions.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 6c2e04c3f8a7c1..ddb57090cf718a 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -240,6 +240,17 @@ static bool hasPossibleIncompatibleOps(const Function *F) {
           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
            isSMEABIRoutineCall(cast<CallInst>(I))))
         return true;
+
+      if (auto *CB = dyn_cast<CallBase>(&I)) {
+        SMEAttrs CallerAttrs(*CB->getCaller()),
+            CalleeAttrs(*CB->getCalledFunction());
+        // When trying to determine if we can inline callees, we must check
+        // that for agnostic-ZA functions, they don't call any functions
+        // that are not agnostic-ZA, as that would require inserting of
+        // save/restore code.
+        if (CallerAttrs.requiresPreservingAllZAState(CalleeAttrs))
+          return true;
+      }
     }
   }
   return false;
@@ -261,7 +272,13 @@ bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
 
   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
       CallerAttrs.requiresSMChange(CalleeAttrs) ||
-      CallerAttrs.requiresPreservingZT0(CalleeAttrs)) {
+      CallerAttrs.requiresPreservingZT0(CalleeAttrs) ||
+      CallerAttrs.requiresPreservingAllZAState(CalleeAttrs)) {
+    if (hasPossibleIncompatibleOps(Callee))
+      return false;
+  }
+
+  if (CalleeAttrs.hasAgnosticZAInterface()) {
     if (hasPossibleIncompatibleOps(Callee))
       return false;
   }
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
index 015ca4cb92b25e..bf16acd7f8f7e1 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.cpp
@@ -38,6 +38,10 @@ void SMEAttrs::set(unsigned M, bool Enable) {
                         isPreservesZT0())) &&
       "Attributes 'aarch64_new_zt0', 'aarch64_in_zt0', 'aarch64_out_zt0', "
       "'aarch64_inout_zt0' and 'aarch64_preserves_zt0' are mutually exclusive");
+
+  assert(!(hasAgnosticZAInterface() && hasSharedZAInterface()) &&
+         "Function cannot have a shared-ZA interface and an agnostic-ZA "
+         "interface");
 }
 
 SMEAttrs::SMEAttrs(const CallBase &CB) {
@@ -56,6 +60,9 @@ SMEAttrs::SMEAttrs(StringRef FuncName) : Bitmask(0) {
   if (FuncName == "__arm_sc_memcpy" || FuncName == "__arm_sc_memset" ||
       FuncName == "__arm_sc_memmove" || FuncName == "__arm_sc_memchr")
     Bitmask |= SMEAttrs::SM_Compatible;
+  if (FuncName == "__arm_sme_save" || FuncName == "__arm_sme_restore" ||
+      FuncName == "__arm_sme_state_size")
+    Bitmask |= SMEAttrs::SM_Compatible | SMEAttrs::SME_ABI_Routine;
 }
 
 SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
@@ -66,6 +73,8 @@ SMEAttrs::SMEAttrs(const AttributeList &Attrs) {
     Bitmask |= SM_Compatible;
   if (Attrs.hasFnAttr("aarch64_pstate_sm_body"))
     Bitmask |= SM_Body;
+  if (Attrs.hasFnAttr("aarch64_za_state_agnostic"))
+    Bitmask |= ZA_State_Agnostic;
   if (Attrs.hasFnAttr("aarch64_in_za"))
     Bitmask |= encodeZAState(StateValue::In);
   if (Attrs.hasFnAttr("aarch64_out_za"))
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
index 4c7c1c9b079538..fb093da70c46b6 100644
--- a/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64SMEAttributes.h
@@ -42,9 +42,10 @@ class SMEAttrs {
     SM_Compatible = 1 << 1,   // aarch64_pstate_sm_compatible
     SM_Body = 1 << 2,         // aarch64_pstate_sm_body
     SME_ABI_Routine = 1 << 3, // Used for SME ABI routines to avoid lazy saves
-    ZA_Shift = 4,
+    ZA_State_Agnostic = 1 << 4,
+    ZA_Shift = 5,
     ZA_Mask = 0b111 << ZA_Shift,
-    ZT0_Shift = 7,
+    ZT0_Shift = 8,
     ZT0_Mask = 0b111 << ZT0_Shift
   };
 
@@ -96,8 +97,11 @@ class SMEAttrs {
     ret...
[truncated]

@llvmbot
Member

llvmbot commented Dec 16, 2024

@llvm/pr-subscribers-llvm-ir



github-actions bot commented Dec 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

// When trying to determine if we can inline callees, we must check
// that for agnostic-ZA functions, they don't call any functions
// that are not agnostic-ZA, as that would require inserting of
// save/restore code.
Collaborator


What problem is this solving? If we inline calls that might clobber za in a context where za is not live, it doesn't matter. If we inline calls that might clobber za into a context where za might be live, we'll insert save/restore calls.

If we care about reducing the number of save/restore calls, we can change code generation to remove redundant restore/save pairs. (We can use a dataflow analysis after isel.)

Collaborator Author


What problem is this solving? If we inline calls that might clobber za in a context where za is not live, it doesn't matter. If we inline calls that might clobber za into a context where za might be live, we'll insert save/restore calls.

Thanks for pointing that out. You're right that this change made no sense and can be removed.

If we care about reducing the number of save/restore calls, we can change code generation to remove redundant restore/save pairs. (We can use a dataflow analysis after isel.)

The expectation is that this is mostly used for leaf functions, or otherwise for functions where the transitive closure of callees is also agnostic-ZA. Calls to private-ZA functions would be expensive and therefore limited to debug routines in practice, which would probably be executed under a condition.
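To illustrate that expectation, here is a hypothetical example (names invented) where the only private-ZA call sits on a cold, conditional path. With this patch the save buffer is still set up in the entry block (the size query and stack adjustment are emitted in the prologue once the buffer is known to be used), but __arm_sme_save/__arm_sme_restore only execute when the condition is taken:

```llvm
; Hypothetical: agnostic-ZA function whose only private-ZA call is conditional,
; e.g. a debug routine. The save/restore calls are emitted around that call,
; so on the common path (%debug false) no state is saved or restored.
declare void @private_za_debug_routine()

define void @agnostic_za_mostly_leaf(i1 %debug) "aarch64_za_state_agnostic" nounwind {
entry:
  br i1 %debug, label %cold, label %exit

cold:
  call void @private_za_debug_routine()   ; __arm_sme_save / __arm_sme_restore wrap this call
  br label %exit

exit:
  ret void
}
```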

; [x] A -> Z
; [ ] A -> S
; [ ] A -> A
define void @agnostic_za_caller_now_za_callee_dont_inline() "aarch64_za_state_agnostic" {
Collaborator


s/now/new/

This implements the lowering of calls from agnostic-ZA
functions to non-agnostic-ZA functions, using the ABI routines
`__arm_sme_state_size`, `__arm_sme_save` and `__arm_sme_restore`.

This implements the proposal described in the following PRs:
* ARM-software/acle#336
* ARM-software/abi-aa#264
sdesmalen-arm merged commit 2ce168b into llvm:main on Dec 23, 2024
8 checks passed