Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 70 additions & 18 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,12 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
switch (I->getOpcode()) {
default:
return false;
case AArch64::PTRUE_C_B:
case AArch64::LD1B_2Z_IMM:
case AArch64::ST1B_2Z_IMM:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As future work, I wonder if we can extend this further to use the quad variants of these instructions as well.

assert((I->getMF()->getSubtarget<AArch64Subtarget>().hasSVE2p1() ||
I->getMF()->getSubtarget<AArch64Subtarget>().hasSME2()) &&
"Expected SME2 or SVE2.1 Targer Architecture.");
case AArch64::STR_ZXI:
case AArch64::STR_PXI:
case AArch64::LDR_ZXI:
Expand Down Expand Up @@ -2791,6 +2797,7 @@ static void computeCalleeSaveRegisterPairs(

bool IsWindows = isTargetWindows(MF);
bool NeedsWinCFI = needsWinCFI(MF);
const auto &Subtarget = MF.getSubtarget<AArch64Subtarget>();
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
MachineFrameInfo &MFI = MF.getFrameInfo();
CallingConv::ID CC = MF.getFunction().getCallingConv();
Expand Down Expand Up @@ -2859,7 +2866,11 @@ static void computeCalleeSaveRegisterPairs(
RPI.Reg2 = NextReg;
break;
case RegPairInfo::PPR:
break;
case RegPairInfo::ZPR:
if (Subtarget.hasSVE2p1() || Subtarget.hasSME2())
if (((RPI.Reg1 - AArch64::Z0) & 1) == 0 && (NextReg == RPI.Reg1 + 1))
RPI.Reg2 = NextReg;
break;
}
}
Expand Down Expand Up @@ -2904,7 +2915,7 @@ static void computeCalleeSaveRegisterPairs(
assert(OffsetPre % Scale == 0);

if (RPI.isScalable())
ScalableByteOffset += StackFillDir * Scale;
ScalableByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);
else
ByteOffset += StackFillDir * (RPI.isPaired() ? 2 * Scale : Scale);

Expand All @@ -2915,9 +2926,6 @@ static void computeCalleeSaveRegisterPairs(
(IsWindows && RPI.Reg2 == AArch64::LR)))
ByteOffset += StackFillDir * 8;

assert(!(RPI.isScalable() && RPI.isPaired()) &&
"Paired spill/fill instructions don't exist for SVE vectors");

// Round up size of non-pair to pair size if we need to pad the
// callee-save area to ensure 16-byte alignment.
if (NeedGapToAlignStack && !NeedsWinCFI &&
Expand Down Expand Up @@ -3004,6 +3012,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
}
return true;
}
bool PtrueCreated = false;
for (const RegPairInfo &RPI : llvm::reverse(RegPairs)) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3038,10 +3047,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
StrOpc = AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
Size = 16;
Alignment = Align(16);
break;
case RegPairInfo::PPR:
StrOpc = AArch64::STR_PXI;
Size = 2;
Expand All @@ -3065,19 +3074,40 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}

unsigned PairRegs;
unsigned PnReg;
if (RPI.isPaired() && RPI.isScalable()) {
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
if (!PtrueCreated) {
PtrueCreated = true;
// Any one of predicate-as-count will be free to use
// This can be replaced in the future if needed
PnReg = AArch64::PN8;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not correct to blindly pick PN8 (P8) here. You can only clobber P8 if it is preserved by the preceding predicate callee-saves.

i.e.

define void @test_clobbers_3_z_regs(<vscale x 16 x i8> %v) {
  call void asm sideeffect "", "~{z8},~{z9}"()
  ret void
}

results in:

        str     x29, [sp, #-16]!
        addvl   sp, sp, #-2
        ptrue   pn8.b       ; pn8 is not preserved by the function, even though the AAPCS says that it should be.
        st1b    { z8.b, z9.b }, pn8, [sp]
        ld1b    { z8.b, z9.b }, pn8/z, [sp]
        addvl   sp, sp, #2
        ldr     x29, [sp], #16
        ret

One thing you could do is try to see if one of the argument registers is available (p0 - p3), so that you can reuse one of those. Alternatively, you could mark p8 as clobbered by the function so that the preceding callee-save spills will include p8.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not correct to blindly pick PN8 (P8) here. You can only clobber P8 if it is preserved by the preceding predicate callee-saves.

Good point. I guess I misread the AAPCS when I suggested just picking an arbitrary register as scratch.

Alternatively, you could mark p8 as clobbered by the function so that the preceding callee-save spills will include p8.

I would prefer this solution.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The P8 register is added to the list of SavedRegs now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little concerned about just blindly picking pn8 here, because it may not match the calling convention in use (someone could choose one other than the standard SVE calling convention defined in AArch64CallingConvention.td).

Can you create a function that finds a suitable caller-saved register instead?

BuildMI(MBB, MI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameSetup);
}
}

MachineInstrBuilder MIB = BuildMI(MBB, MI, DL, TII.get(StrOpc));
if (!MRI.isReserved(Reg1))
MBB.addLiveIn(Reg1);
if (RPI.isPaired()) {
if (!MRI.isReserved(Reg2))
MBB.addLiveIn(Reg2);
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
if (RPI.isScalable())
MIB.addReg(PairRegs);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should PairRegs also use getPrologueDeath?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure. I can see in getPrologueDeath that it kills the reg if it is not live.

else
MIB.addReg(Reg2, getPrologueDeath(MF, Reg2));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOStore, Size, Alignment));
}
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1))
.addReg(AArch64::SP)
if (RPI.isPaired() && RPI.isScalable())
MIB.addReg(PnReg);
else
MIB.addReg(Reg1, getPrologueDeath(MF, Reg1));
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale],
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameSetup);
Expand All @@ -3089,8 +3119,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(

// Update the StackIDs of the SVE stack slots.
MachineFrameInfo &MFI = MF.getFrameInfo();
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
MFI.setStackID(RPI.FrameIdx, TargetStackID::ScalableVector);
if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
MFI.setStackID(FrameIdxReg1, TargetStackID::ScalableVector);
if (RPI.isPaired())
MFI.setStackID(FrameIdxReg2, TargetStackID::ScalableVector);
}

}
return true;
Expand All @@ -3109,7 +3142,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
DL = MBBI->getDebugLoc();

computeCalleeSaveRegisterPairs(MF, CSI, TRI, RegPairs, hasFP(MF));

if (homogeneousPrologEpilog(MF, &MBB)) {
auto MIB = BuildMI(MBB, MBBI, DL, TII.get(AArch64::HOM_Epilog))
.setMIFlag(MachineInstr::FrameDestroy);
Expand All @@ -3130,6 +3163,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
auto ZPREnd = std::find_if_not(ZPRBegin, RegPairs.end(), IsZPR);
std::reverse(ZPRBegin, ZPREnd);

bool PtrueCreated = false;
for (const RegPairInfo &RPI : RegPairs) {
unsigned Reg1 = RPI.Reg1;
unsigned Reg2 = RPI.Reg2;
Expand Down Expand Up @@ -3162,7 +3196,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
Alignment = Align(16);
break;
case RegPairInfo::ZPR:
LdrOpc = AArch64::LDR_ZXI;
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
Size = 16;
Alignment = Align(16);
break;
Expand All @@ -3187,15 +3221,33 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
std::swap(Reg1, Reg2);
std::swap(FrameIdxReg1, FrameIdxReg2);
}

unsigned PnReg;
unsigned PairRegs;
if (RPI.isPaired() && RPI.isScalable()) {
PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
if (!PtrueCreated) {
PtrueCreated = true;
// Any one of predicate-as-count will be free to use
// This can be replaced in the future if needed
PnReg = AArch64::PN8;
BuildMI(MBB, MBBI, DL, TII.get(AArch64::PTRUE_C_B), PnReg)
.setMIFlags(MachineInstr::FrameDestroy);
}
}

MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII.get(LdrOpc));
if (RPI.isPaired()) {
MIB.addReg(Reg2, getDefRegState(true));
MIB.addReg(RPI.isScalable() ? PairRegs : Reg2, getDefRegState(true));
MIB.addMemOperand(MF.getMachineMemOperand(
MachinePointerInfo::getFixedStack(MF, FrameIdxReg2),
MachineMemOperand::MOLoad, Size, Alignment));
}
MIB.addReg(Reg1, getDefRegState(true))
.addReg(AArch64::SP)
if (RPI.isPaired() && RPI.isScalable())
MIB.addReg(PnReg);
else
MIB.addReg(Reg1, getDefRegState(true));
MIB.addReg(AArch64::SP)
.addImm(RPI.Offset) // [sp, #offset*scale]
// where factor*scale is implicit
.setMIFlag(MachineInstr::FrameDestroy);
Expand Down
Loading