Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 131 additions & 36 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
switch (I->getOpcode()) {
default:
return false;
case AArch64::PTRUE_C_B:
case AArch64::LD1B_2Z_IMM:
case AArch64::ST1B_2Z_IMM:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As future work, I wonder if we can extend this further to use the quad variants of these instructions as well.

case AArch64::STR_ZXI:
case AArch64::STR_PXI:
case AArch64::LDR_ZXI:
Expand Down Expand Up @@ -2859,7 +2862,11 @@ static void computeCalleeSaveRegisterPairs(
RPI.Reg2 = NextReg;
break;
case RegPairInfo::PPR:
break;
case RegPairInfo::ZPR:
if (AFI->getPredicateRegForFillSpill() != 0)
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
RPI.Reg2 = NextReg;
break;
}
}
Expand Down Expand Up @@ -2904,7 +2911,7 @@ static void computeCalleeSaveRegisterPairs(
assert(OffsetPre % Scale == 0);

if (RPI.isScalable())
ScalableByteOffset += StackFillDir * Scale;
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
else
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);

Expand All @@ -2915,9 +2922,6 @@ static void computeCalleeSaveRegisterPairs(
(IsWindows && RPI.Reg2 == AArch64::LR)))
ByteOffset += StackFillDir * 8;

assert(!(RPI.isScalable() && RPI.isPaired()) &&
"Paired spill/fill instructions don't exist for SVE vectors");

// Round up size of non-pair to pair size if we need to pad the
// callee-save area to ensure 16-byte alignment.
if (NeedGapToAlignStack && !NeedsWinCFI &&
Expand Down Expand Up @@ -3004,6 +3008,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
}
return true;
}
bool PtrueCreated = false;
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3038,10 +3043,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
StrOpc = AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
case RegPairInfo::PPR:
StrOpc = AArch64::STR_PXI;
Size = 2;
Expand All @@ -3065,33 +3070,64 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (RPI.isPaired()) {

if (RPI.isPaired() && RPI.isScalable()) {
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
unsigned PnReg = AFI->getPredicateRegForFillSpill();
Copy link
Collaborator

@momchil-velikov momchil-velikov Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good, the following are more like style remarks:

  • PnReg can be declared outside the loop. Then you can initialise it under if (!PtrueCreated) ..., which seems more logical place to do it.
  • PairRegs has only one use and it's on the following line, you don't save anything by naming the corresponding expression

if (!PtrueCreated) {
PtrueCreated = true;
BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameSetup);
}
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (!MRI.isReserved(Reg2))
MBB.addLiveIn(Reg2);
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOStore, Size, Alignment));
MIB.addReg(PnReg);
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
} else { // The code when the pair of ZReg is not present
MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (RPI.isPaired()) {
if (!MRI.isReserved(Reg2))
MBB.addLiveIn(Reg2);
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOStore, Size, Alignment));
}
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);
}
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOStore, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameSetup);

// Update the StackIDs of the SVE stack slots.
MachineFrameInfo &MFI = MF.getFrameInfo();
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);

if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}
}
return true;
}
Expand All @@ -3109,7 +3145,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
DL = MBBI->getDebugLoc();

computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));

if (homogeneousPrologEpilog(MF, &MBB)) {
auto MIB = BuildMI(MBB, MBBI, DL, TII.get(AArch64::HOM_Epilog))
.setMIFlag(MachineInstr::FrameDestroy);
Expand All @@ -3130,6 +3166,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
auto ZPREnd = std::find_if_not(ZPRBegin, RegPairs.end(), IsZPR);
std::reverse(ZPRBegin, ZPREnd);

bool PtrueCreated = false;
for (const RegPairInfo &RPI : RegPairs) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3162,7 +3199,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
LdrOpc = AArch64::LDR_ZXI;
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
Size = 16;
Alignment = Align(16);
break;
Expand All @@ -3187,15 +3224,41 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
if (RPI.isPaired()) {
MIB.addReg(Reg2, getDefRegState(true));

AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
if (RPI.isPaired() && RPI.isScalable()) {
unsigned PnReg = AFI->getPredicateRegForFillSpill();
if (!PtrueCreated) {
PtrueCreated = true;
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameDestroy);
}
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
MIB.addReg(/*PairRegs*/ AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0),
getDefRegState(true));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOLoad, Size, Alignment));
}
MIB.addReg(Reg1, getDefRegState(true))
.addReg(AArch64::SP)
MIB.addReg(PnReg);
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg1),
MachineMemOperand::MOLoad, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
} else {
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
if (RPI.isPaired()) {
MIB.addReg(Reg2, getDefRegState(true));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOLoad, Size, Alignment));
}
MIB.addReg(Reg1, getDefRegState(true));
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
Expand All @@ -3204,8 +3267,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
MachineMemOperand::MOLoad, Size, Alignment));
if (NeedsWinCFI)
InsertSEH(MIB, TII, MachineInstr::FrameDestroy);
}
}

return true;
}

Expand Down Expand Up @@ -3234,6 +3297,7 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,

unsigned ExtraCSSpill = 0;
bool HasUnpairedGPR64 = false;
bool HasPairZReg = false;
// Figure out which callee-saved registers to save/restore.
for (unsigned i = 0; CSRegs[i]; ++i) {
const unsigned Reg = CSRegs[i];
Expand Down Expand Up @@ -3287,8 +3351,39 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
!RegInfo->isReservedReg(MF, PairedReg))
ExtraCSSpill = PairedReg;
}

// Save PReg in FunctionInfo to build PTRUE instruction later. The PTRUE is
// being used in the function to save and restore the pair of ZReg
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
if (Subtarget.hasSVE2p1() || Subtarget.hasSME2()) {
if (AArch64::PPRRegClass.contains(Reg) &&
(Reg >= AArch64::P8 && Reg <= AArch64::P15) && SavedRegs.test(Reg) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When choosing a callee-saved register, there is the assumption that in the prologue this register will already be spilled before overwriting it with a new ptrue, and in the epilogue that it will be filled after defining it with a ptrue. This is dependent on the order in which the registers are specified in the AArch64CallingConvention.td file and the order in which they are iterated. To avoid this ever silently doing the wrong thing, can you add some asserts in restoreCalleeSavedRegisters and spillCalleeSavedRegisters to guard that?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't see where you've addressed this. Did you miss this comment?

AFI->getPredicateRegForFillSpill() == 0)
AFI->setPredicateRegForFillSpill((Reg - AArch64::P0) + AArch64::PN0);

// Check if there is a pair of ZRegs, so it can select P8 to create PTRUE,
// in case there is no PRege being saved(above)
HasPairZReg =
HasPairZReg || (AArch64::ZPRRegClass.contains(Reg, CSRegs[i ^ 1]) &&
SavedRegs.test(CSRegs[i ^ 1]));
}
}

// Make sure there is a PReg saved to be used in save and restore when there
// is ZReg pair.
if ((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) &&
(MF.getFunction().getCallingConv() ==
CallingConv::AArch64_SVE_VectorCall ||
MF.getFunction().getCallingConv() ==
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 ||
MF.getFunction().getCallingConv() ==
CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2))
if (AFI->getPredicateRegForFillSpill() == 0 && HasPairZReg) {
assert(!RegInfo->isReservedReg(MF, AArch64::P8) && "P8 is reserved");
SavedRegs.set(AArch64::P8);
AFI->setPredicateRegForFillSpill(AArch64::PN8);
}

if (MF.getFunction().getCallingConv() == CallingConv::Win64 &&
!Subtarget.isTargetWindows()) {
// For Windows calling convention on a non-windows OS, where X18 is treated
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
// on function entry to record the initial pstate of a function.
Register PStateSMReg = MCRegister::NoRegister;

// Has the PNReg used to build PTRUE instruction.
// The PTRUE is used for the LD/ST of ZReg pairs in save and restore.
unsigned PredicateRegForFillSpill = 0;

public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand All @@ -220,6 +224,13 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
const DenseMap<MachineBasicBlock *, MachineBasicBlock *> &Src2DstMBB)
const override;

void setPredicateRegForFillSpill(unsigned Reg) {
PredicateRegForFillSpill = Reg;
}
unsigned getPredicateRegForFillSpill() const {
return PredicateRegForFillSpill;
}

Register getPStateSMReg() const { return PStateSMReg; };
void setPStateSMReg(Register Reg) { PStateSMReg = Reg; };

Expand Down
Loading