Skip to content

Commit adb217d

Browse files
committed
Initial commit from #142941
1 parent 957ae8a commit adb217d

File tree

9 files changed

+988
-340
lines changed

9 files changed

+988
-340
lines changed

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "Utils/AArch64BaseInfo.h"
2121
#include "llvm/ADT/ArrayRef.h"
2222
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/ADT/SmallSet.h"
2324
#include "llvm/ADT/SmallVector.h"
2425
#include "llvm/CodeGen/CFIInstBuilder.h"
2526
#include "llvm/CodeGen/LivePhysRegs.h"
@@ -35,6 +36,7 @@
3536
#include "llvm/CodeGen/MachineRegisterInfo.h"
3637
#include "llvm/CodeGen/RegisterScavenging.h"
3738
#include "llvm/CodeGen/StackMaps.h"
39+
#include "llvm/CodeGen/TargetOpcodes.h"
3840
#include "llvm/CodeGen/TargetRegisterInfo.h"
3941
#include "llvm/CodeGen/TargetSubtargetInfo.h"
4042
#include "llvm/IR/DebugInfoMetadata.h"
@@ -7349,6 +7351,9 @@ bool AArch64InstrInfo::isThroughputPattern(unsigned Pattern) const {
73497351
case AArch64MachineCombinerPattern::MULSUBv2i32_indexed_OP2:
73507352
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP1:
73517353
case AArch64MachineCombinerPattern::MULSUBv4i32_indexed_OP2:
7354+
case AArch64MachineCombinerPattern::GATHER_LANE_i32:
7355+
case AArch64MachineCombinerPattern::GATHER_LANE_i16:
7356+
case AArch64MachineCombinerPattern::GATHER_LANE_i8:
73527357
return true;
73537358
} // end switch (Pattern)
73547359
return false;
@@ -7389,11 +7394,252 @@ static bool getMiscPatterns(MachineInstr &Root,
73897394
return false;
73907395
}
73917396

7397+
/// Match a chain of single-lane loads that fills a vector register one lane at
/// a time, ending at \p Root.
///
/// The expected shape, walking backwards from \p Root through operand 1 of
/// each instruction, is:
///   - \p Root writes lane NumLanes - 1,
///   - every lane in [1, NumLanes - 2] is written by a \p LoadLaneOpCode
///     instruction whose result has a single non-debug use and which has
///     exactly four operands,
///   - the chain starts at a SUBREG_TO_REG whose source register is
///     128 / NumLanes bits wide and has a single non-debug use.
///
/// \param Root           Candidate instruction; the caller guarantees its
///                       opcode is \p LoadLaneOpCode.
/// \param Patterns       Receives GATHER_LANE_i32 / _i16 / _i8 on a match.
/// \param LoadLaneOpCode Opcode every lane load in the chain must have.
/// \param NumLanes       Number of lanes in the vector; must be 4, 8 or 16.
/// \returns true if a pattern was appended to \p Patterns.
static bool getGatherPattern(MachineInstr &Root,
                             SmallVectorImpl<unsigned> &Patterns,
                             unsigned LoadLaneOpCode, unsigned NumLanes) {
  const MachineFunction *MF = Root.getMF();

  // Early exit if optimizing for size.
  if (MF->getFunction().hasMinSize())
    return false;

  const MachineRegisterInfo &MRI = MF->getRegInfo();
  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

  // The root of the pattern must load into the last lane of the vector.
  if (Root.getOperand(2).getImm() != NumLanes - 1)
    return false;

  // Check that we have a load into every lane except lane 0.
  // For each load we also want to check that:
  // 1. It has a single non-debug use (since we will be replacing the virtual
  //    register).
  // 2. The addressing mode only uses a single offset register (i.e. the
  //    instruction has exactly four operands).
  auto *CurrInstr = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
  // llvm::seq is half-open: this yields lanes 1 .. NumLanes - 2. Lane
  // NumLanes - 1 is Root itself and lane 0 is matched separately below.
  auto Range = llvm::seq<unsigned>(1, NumLanes - 1);
  // Up to 14 lanes are tracked (NumLanes == 16), so keep them all inline.
  SmallSet<unsigned, 16> RemainingLanes(Range.begin(), Range.end());
  while (!RemainingLanes.empty() && CurrInstr &&
         CurrInstr->getOpcode() == LoadLaneOpCode &&
         MRI.hasOneNonDBGUse(CurrInstr->getOperand(0).getReg()) &&
         CurrInstr->getNumOperands() == 4) {
    RemainingLanes.erase(CurrInstr->getOperand(2).getImm());
    CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
  }

  // Every lane must have been seen. The loop can also stop because the lane
  // set became empty while the last lookup produced no defining instruction,
  // so CurrInstr has to be re-checked before it is dereferenced.
  if (!RemainingLanes.empty() || !CurrInstr)
    return false;

  // Match the SUBREG_TO_REG sequence.
  if (CurrInstr->getOpcode() != TargetOpcode::SUBREG_TO_REG)
    return false;

  // Verify that the subreg to reg loads an integer into the first lane.
  const Register Lane0LoadReg = CurrInstr->getOperand(2).getReg();
  unsigned SingleLaneSizeInBits = 128 / NumLanes;
  if (TRI->getRegSizeInBits(Lane0LoadReg, MRI) != SingleLaneSizeInBits)
    return false;

  // The code generator looks up the instruction defining this register and
  // queries its register class, so it must be a virtual register with a
  // unique definition.
  if (!Lane0LoadReg.isVirtual() || !MRI.getUniqueVRegDef(Lane0LoadReg))
    return false;

  // Verify that it also has a single non debug use.
  if (!MRI.hasOneNonDBGUse(Lane0LoadReg))
    return false;

  switch (NumLanes) {
  case 4:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i32);
    break;
  case 8:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i16);
    break;
  case 16:
    Patterns.push_back(AArch64MachineCombinerPattern::GATHER_LANE_i8);
    break;
  default:
    llvm_unreachable("Got bad number of lanes for gather pattern.");
  }

  return true;
}
7462+
7463+
/// Search for patterns where we use LD1 instructions to load into
7464+
/// separate lanes of an 128 bit Neon register. We can increase Memory Level
7465+
/// Parallelism by loading into 2 Neon registers instead.
7466+
static bool getLoadPatterns(MachineInstr &Root,
7467+
SmallVectorImpl<unsigned> &Patterns) {
7468+
7469+
// The pattern searches for loads into single lanes.
7470+
switch (Root.getOpcode()) {
7471+
case AArch64::LD1i32:
7472+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 4);
7473+
case AArch64::LD1i16:
7474+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 8);
7475+
case AArch64::LD1i8:
7476+
return getGatherPattern(Root, Patterns, Root.getOpcode(), 16);
7477+
default:
7478+
return false;
7479+
}
7480+
}
7481+
7482+
static void
7483+
generateGatherPattern(MachineInstr &Root,
7484+
SmallVectorImpl<MachineInstr *> &InsInstrs,
7485+
SmallVectorImpl<MachineInstr *> &DelInstrs,
7486+
DenseMap<Register, unsigned> &InstrIdxForVirtReg,
7487+
unsigned Pattern, unsigned NumLanes) {
7488+
7489+
MachineFunction &MF = *Root.getParent()->getParent();
7490+
MachineRegisterInfo &MRI = MF.getRegInfo();
7491+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
7492+
7493+
// Gather the initial load instructions to build the pattern
7494+
SmallVector<MachineInstr *, 16> LoadToLaneInstrs;
7495+
MachineInstr *CurrInstr = &Root;
7496+
for (unsigned i = 0; i < NumLanes - 1; ++i) {
7497+
LoadToLaneInstrs.push_back(CurrInstr);
7498+
CurrInstr = MRI.getUniqueVRegDef(CurrInstr->getOperand(1).getReg());
7499+
}
7500+
7501+
// Sort the load instructions according to the lane.
7502+
llvm::sort(LoadToLaneInstrs,
7503+
[](const MachineInstr *A, const MachineInstr *B) {
7504+
return A->getOperand(2).getImm() > B->getOperand(2).getImm();
7505+
});
7506+
7507+
MachineInstr *SubregToReg = CurrInstr;
7508+
LoadToLaneInstrs.push_back(
7509+
MRI.getUniqueVRegDef(SubregToReg->getOperand(2).getReg()));
7510+
auto LoadToLaneInstrsAscending = llvm::reverse(LoadToLaneInstrs);
7511+
7512+
const TargetRegisterClass *FPR128RegClass =
7513+
MRI.getRegClass(Root.getOperand(0).getReg());
7514+
7515+
auto LoadLaneToRegister = [&](MachineInstr *OriginalInstr,
7516+
Register SrcRegister, unsigned Lane,
7517+
Register OffsetRegister) {
7518+
auto NewRegister = MRI.createVirtualRegister(FPR128RegClass);
7519+
MachineInstrBuilder LoadIndexIntoRegister =
7520+
BuildMI(MF, MIMetadata(*OriginalInstr), TII->get(Root.getOpcode()),
7521+
NewRegister)
7522+
.addReg(SrcRegister)
7523+
.addImm(Lane)
7524+
.addReg(OffsetRegister, getKillRegState(true));
7525+
InstrIdxForVirtReg.insert(std::make_pair(NewRegister, InsInstrs.size()));
7526+
InsInstrs.push_back(LoadIndexIntoRegister);
7527+
return NewRegister;
7528+
};
7529+
7530+
// Helper to create load instruction based on opcode
7531+
auto CreateLoadInstruction = [&](unsigned NumLanes, Register DestReg,
7532+
Register OffsetReg) -> MachineInstrBuilder {
7533+
unsigned Opcode;
7534+
switch (NumLanes) {
7535+
case 4:
7536+
Opcode = AArch64::LDRSui;
7537+
break;
7538+
case 8:
7539+
Opcode = AArch64::LDRHui;
7540+
break;
7541+
case 16:
7542+
Opcode = AArch64::LDRBui;
7543+
break;
7544+
default:
7545+
llvm_unreachable(
7546+
"Got unsupported number of lanes in machine-combiner gather pattern");
7547+
}
7548+
// Immediate offset load
7549+
return BuildMI(MF, MIMetadata(Root), TII->get(Opcode), DestReg)
7550+
.addReg(OffsetReg)
7551+
.addImm(0); // immediate offset
7552+
};
7553+
7554+
// Load the remaining lanes into register 0.
7555+
auto LanesToLoadToReg0 =
7556+
llvm::make_range(LoadToLaneInstrsAscending.begin() + 1,
7557+
LoadToLaneInstrsAscending.begin() + NumLanes / 2);
7558+
auto PrevReg = SubregToReg->getOperand(0).getReg();
7559+
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg0)) {
7560+
PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
7561+
LoadInstr->getOperand(3).getReg());
7562+
DelInstrs.push_back(LoadInstr);
7563+
}
7564+
auto LastLoadReg0 = PrevReg;
7565+
7566+
// First load into register 1. Perform a LDRSui to zero out the upper lanes in
7567+
// a single instruction.
7568+
auto Lane0Load = *LoadToLaneInstrsAscending.begin();
7569+
auto OriginalSplitLoad =
7570+
*std::next(LoadToLaneInstrsAscending.begin(), NumLanes / 2);
7571+
auto DestRegForMiddleIndex = MRI.createVirtualRegister(
7572+
MRI.getRegClass(Lane0Load->getOperand(0).getReg()));
7573+
7574+
MachineInstrBuilder MiddleIndexLoadInstr =
7575+
CreateLoadInstruction(NumLanes, DestRegForMiddleIndex,
7576+
OriginalSplitLoad->getOperand(3).getReg());
7577+
7578+
InstrIdxForVirtReg.insert(
7579+
std::make_pair(DestRegForMiddleIndex, InsInstrs.size()));
7580+
InsInstrs.push_back(MiddleIndexLoadInstr);
7581+
DelInstrs.push_back(OriginalSplitLoad);
7582+
7583+
// Subreg To Reg instruction for register 1.
7584+
auto DestRegForSubregToReg = MRI.createVirtualRegister(FPR128RegClass);
7585+
unsigned SubregType;
7586+
switch (NumLanes) {
7587+
case 4:
7588+
SubregType = AArch64::ssub;
7589+
break;
7590+
case 8:
7591+
SubregType = AArch64::hsub;
7592+
break;
7593+
case 16:
7594+
SubregType = AArch64::bsub;
7595+
break;
7596+
default:
7597+
llvm_unreachable(
7598+
"Got invalid NumLanes for machine-combiner gather pattern");
7599+
}
7600+
7601+
auto SubRegToRegInstr =
7602+
BuildMI(MF, MIMetadata(Root), TII->get(SubregToReg->getOpcode()),
7603+
DestRegForSubregToReg)
7604+
.addImm(0)
7605+
.addReg(DestRegForMiddleIndex, getKillRegState(true))
7606+
.addImm(SubregType);
7607+
InstrIdxForVirtReg.insert(
7608+
std::make_pair(DestRegForSubregToReg, InsInstrs.size()));
7609+
InsInstrs.push_back(SubRegToRegInstr);
7610+
7611+
// Load remaining lanes into register 1.
7612+
auto LanesToLoadToReg1 =
7613+
llvm::make_range(LoadToLaneInstrsAscending.begin() + NumLanes / 2 + 1,
7614+
LoadToLaneInstrsAscending.end());
7615+
PrevReg = SubRegToRegInstr->getOperand(0).getReg();
7616+
for (auto [Index, LoadInstr] : llvm::enumerate(LanesToLoadToReg1)) {
7617+
PrevReg = LoadLaneToRegister(LoadInstr, PrevReg, Index + 1,
7618+
LoadInstr->getOperand(3).getReg());
7619+
if (Index == NumLanes / 2 - 2) {
7620+
break;
7621+
}
7622+
DelInstrs.push_back(LoadInstr);
7623+
}
7624+
auto LastLoadReg1 = PrevReg;
7625+
7626+
// Create the final zip instruction to combine the results.
7627+
MachineInstrBuilder ZipInstr =
7628+
BuildMI(MF, MIMetadata(Root), TII->get(AArch64::ZIP1v2i64),
7629+
Root.getOperand(0).getReg())
7630+
.addReg(LastLoadReg0)
7631+
.addReg(LastLoadReg1);
7632+
InsInstrs.push_back(ZipInstr);
7633+
}
7634+
73927635
CombinerObjective
73937636
AArch64InstrInfo::getCombinerObjective(unsigned Pattern) const {
73947637
switch (Pattern) {
73957638
case AArch64MachineCombinerPattern::SUBADD_OP1:
73967639
case AArch64MachineCombinerPattern::SUBADD_OP2:
7640+
case AArch64MachineCombinerPattern::GATHER_LANE_i32:
7641+
case AArch64MachineCombinerPattern::GATHER_LANE_i16:
7642+
case AArch64MachineCombinerPattern::GATHER_LANE_i8:
73977643
return CombinerObjective::MustReduceDepth;
73987644
default:
73997645
return TargetInstrInfo::getCombinerObjective(Pattern);
@@ -7423,6 +7669,10 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
74237669
if (getMiscPatterns(Root, Patterns))
74247670
return true;
74257671

7672+
// Load patterns
7673+
if (getLoadPatterns(Root, Patterns))
7674+
return true;
7675+
74267676
return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
74277677
DoRegPressureReduce);
74287678
}
@@ -8678,6 +8928,21 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
86788928
MUL = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs);
86798929
break;
86808930
}
8931+
case AArch64MachineCombinerPattern::GATHER_LANE_i32: {
8932+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
8933+
Pattern, 4);
8934+
break;
8935+
}
8936+
case AArch64MachineCombinerPattern::GATHER_LANE_i16: {
8937+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
8938+
Pattern, 8);
8939+
break;
8940+
}
8941+
case AArch64MachineCombinerPattern::GATHER_LANE_i8: {
8942+
generateGatherPattern(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg,
8943+
Pattern, 16);
8944+
break;
8945+
}
86818946

86828947
} // end switch (Pattern)
86838948
// Record MUL and ADD/SUB for deletion

llvm/lib/Target/AArch64/AArch64InstrInfo.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ enum AArch64MachineCombinerPattern : unsigned {
172172
FMULv8i16_indexed_OP2,
173173

174174
FNMADD,
175+
176+
GATHER_LANE_i32,
177+
GATHER_LANE_i16,
178+
GATHER_LANE_i8
175179
};
176180
class AArch64InstrInfo final : public AArch64GenInstrInfo {
177181
const AArch64RegisterInfo RI;

0 commit comments

Comments
 (0)