-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[RISCV] Handle recurrences in RISCVVLOptimizer #151285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
9aba342
ea2b861
9f24fe7
9b61df6
a78cc60
257ed3c
4d80e45
5465920
a31269c
97a12b1
8f75db7
d550870
dc0ca0e
1dd36a4
ebce546
cadf393
0ef2baf
1068f3a
62635d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,9 +10,19 @@ | |
| // instructions are inserted. | ||
| // | ||
| // The purpose of this optimization is to make the VL argument, for instructions | ||
| // that have a VL argument, as small as possible. This is implemented by | ||
| // visiting each instruction in reverse order and checking that if it has a VL | ||
| // argument, whether the VL can be reduced. | ||
| // that have a VL argument, as small as possible. | ||
| // | ||
| // This is split into a sparse dataflow analysis where we determine what VL is | ||
| // demanded by each instruction first, and then afterwards try to reduce the VL | ||
| // of each instruction if it demands less than its VL operand. | ||
| // | ||
| // The analysis is explained in more detail in the 2025 EuroLLVM Developers' | ||
| // Meeting talk "Accidental Dataflow Analysis: Extending the RISC-V VL | ||
| // Optimizer", which is available on YouTube at | ||
| // https://www.youtube.com/watch?v=Mfb5fRSdJAc | ||
| // | ||
| // The slides for the talk are available at | ||
| // https://llvm.org/devmtg/2025-04/slides/technical_talk/lau_accidental_dataflow.pdf | ||
| // | ||
| //===---------------------------------------------------------------------===// | ||
|
|
||
|
|
@@ -30,6 +40,27 @@ using namespace llvm; | |
|
|
||
| namespace { | ||
|
|
||
| /// Wrapper around MachineOperand that defaults to immediate 0. | ||
| struct DemandedVL { | ||
| MachineOperand VL; | ||
| DemandedVL() : VL(MachineOperand::CreateImm(0)) {} | ||
| DemandedVL(MachineOperand VL) : VL(VL) {} | ||
| static DemandedVL vlmax() { | ||
| return DemandedVL(MachineOperand::CreateImm(RISCV::VLMaxSentinel)); | ||
| } | ||
| bool operator!=(const DemandedVL &Other) const { | ||
| return !VL.isIdenticalTo(Other.VL); | ||
| } | ||
| }; | ||
|
|
||
| static DemandedVL max(const DemandedVL &LHS, const DemandedVL &RHS) { | ||
| if (RISCV::isVLKnownLE(LHS.VL, RHS.VL)) | ||
| return RHS; | ||
| if (RISCV::isVLKnownLE(RHS.VL, LHS.VL)) | ||
| return LHS; | ||
| return DemandedVL::vlmax(); | ||
wangpc-pp marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| class RISCVVLOptimizer : public MachineFunctionPass { | ||
| const MachineRegisterInfo *MRI; | ||
| const MachineDominatorTree *MDT; | ||
|
|
@@ -51,17 +82,17 @@ class RISCVVLOptimizer : public MachineFunctionPass { | |
| StringRef getPassName() const override { return PASS_NAME; } | ||
|
|
||
| private: | ||
| std::optional<MachineOperand> | ||
| getMinimumVLForUser(const MachineOperand &UserOp) const; | ||
| /// Returns the largest common VL MachineOperand that may be used to optimize | ||
| /// MI. Returns std::nullopt if it failed to find a suitable VL. | ||
| std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const; | ||
| DemandedVL getMinimumVLForUser(const MachineOperand &UserOp) const; | ||
| /// Returns true if the users of \p MI have compatible EEWs and SEWs. | ||
| bool checkUsers(const MachineInstr &MI) const; | ||
| bool tryReduceVL(MachineInstr &MI) const; | ||
| bool isCandidate(const MachineInstr &MI) const; | ||
| void transfer(const MachineInstr &MI); | ||
|
|
||
| /// For a given instruction, records what elements of it are demanded by | ||
| /// downstream users. | ||
| DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs; | ||
| DenseMap<const MachineInstr *, DemandedVL> DemandedVLs; | ||
| SetVector<const MachineInstr *> Worklist; | ||
| }; | ||
|
|
||
| /// Represents the EMUL and EEW of a MachineOperand. | ||
|
|
@@ -813,6 +844,7 @@ static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) { | |
| const MachineInstr &MI = *MO.getParent(); | ||
| const RISCVVPseudosTable::PseudoInfo *RVV = | ||
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); | ||
| MI.dump(); | ||
lukel97 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| assert(RVV && "Could not find MI in PseudoTable"); | ||
|
|
||
| std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO); | ||
|
|
@@ -847,10 +879,15 @@ static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) { | |
| return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW); | ||
| } | ||
|
|
||
| static bool isTupleInsertInstr(const MachineInstr &MI); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Forward declared here to remove code motion from the diff, will remove in a follow-up commit |
||
|
|
||
| /// Return true if this optimization should consider MI for VL reduction. This | ||
| /// white-list approach simplifies this optimization for instructions that may | ||
| /// have more complex semantics with relation to how it uses VL. | ||
| static bool isSupportedInstr(const MachineInstr &MI) { | ||
| if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI)) | ||
| return true; | ||
|
|
||
| const RISCVVPseudosTable::PseudoInfo *RVV = | ||
| RISCVVPseudosTable::getPseudoInfo(MI.getOpcode()); | ||
|
|
||
|
|
@@ -1348,21 +1385,24 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const { | |
| return true; | ||
| } | ||
|
|
||
| std::optional<MachineOperand> | ||
| DemandedVL | ||
| RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { | ||
| const MachineInstr &UserMI = *UserOp.getParent(); | ||
| const MCInstrDesc &Desc = UserMI.getDesc(); | ||
|
|
||
| if (UserMI.isPHI() || UserMI.isFullCopy() || isTupleInsertInstr(UserMI)) | ||
| return DemandedVLs.lookup(&UserMI); | ||
|
|
||
| if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) { | ||
| LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that" | ||
| " use VLMAX\n"); | ||
| return std::nullopt; | ||
| return DemandedVL::vlmax(); | ||
| } | ||
|
|
||
| if (RISCVII::readsPastVL( | ||
| TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags)) { | ||
| LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n"); | ||
| return std::nullopt; | ||
| return DemandedVL::vlmax(); | ||
| } | ||
|
|
||
| unsigned VLOpNum = RISCVII::getVLOpNum(Desc); | ||
|
|
@@ -1376,11 +1416,10 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { | |
| if (UserOp.isTied()) { | ||
| assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() && | ||
| RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc())); | ||
| auto DemandedVL = DemandedVLs.lookup(&UserMI); | ||
| if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) { | ||
| if (!RISCV::isVLKnownLE(DemandedVLs.lookup(&UserMI).VL, VLOp)) { | ||
| LLVM_DEBUG(dbgs() << " Abort because user is passthru in " | ||
| "instruction with demanded tail\n"); | ||
| return std::nullopt; | ||
| return DemandedVL::vlmax(); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1393,11 +1432,8 @@ RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const { | |
|
|
||
| // If we know the demanded VL of UserMI, then we can reduce the VL it | ||
| // requires. | ||
| if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) { | ||
| assert(isCandidate(UserMI)); | ||
| if (RISCV::isVLKnownLE(*DemandedVL, VLOp)) | ||
| return DemandedVL; | ||
| } | ||
| if (RISCV::isVLKnownLE(DemandedVLs.lookup(&UserMI).VL, VLOp)) | ||
| return DemandedVLs.lookup(&UserMI); | ||
|
|
||
| return VLOp; | ||
| } | ||
|
|
@@ -1450,9 +1486,10 @@ static bool isSegmentedStoreInstr(const MachineInstr &MI) { | |
| } | ||
| } | ||
|
|
||
| std::optional<MachineOperand> | ||
| RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | ||
| std::optional<MachineOperand> CommonVL; | ||
| bool RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | ||
| if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI)) | ||
| return true; | ||
|
|
||
| SmallSetVector<MachineOperand *, 8> Worklist; | ||
|
||
| SmallPtrSet<const MachineInstr *, 4> PHISeen; | ||
| for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) | ||
|
|
@@ -1481,7 +1518,7 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | |
| // whole register group). | ||
| if (!isTupleInsertInstr(CandidateMI) && | ||
| !isSegmentedStoreInstr(CandidateMI)) | ||
| return std::nullopt; | ||
| return false; | ||
| Worklist.insert(&UseOp); | ||
| } | ||
| continue; | ||
|
|
@@ -1497,23 +1534,9 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | |
| continue; | ||
| } | ||
|
|
||
| auto VLOp = getMinimumVLForUser(UserOp); | ||
| if (!VLOp) | ||
| return std::nullopt; | ||
|
|
||
| // Use the largest VL among all the users. If we cannot determine this | ||
| // statically, then we cannot optimize the VL. | ||
| if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) { | ||
| CommonVL = *VLOp; | ||
| LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n"); | ||
| } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) { | ||
| LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n"); | ||
| return std::nullopt; | ||
| } | ||
|
|
||
| if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) { | ||
| LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n"); | ||
| return std::nullopt; | ||
| return false; | ||
| } | ||
|
|
||
| std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp); | ||
|
|
@@ -1522,7 +1545,7 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | |
| LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n"); | ||
| LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); | ||
| LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); | ||
| return std::nullopt; | ||
| return false; | ||
| } | ||
|
|
||
| if (!OperandInfo::areCompatible(*ProducerInfo, *ConsumerInfo)) { | ||
|
|
@@ -1531,11 +1554,11 @@ RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const { | |
| << " Abort due to incompatible information for EMUL or EEW.\n"); | ||
| LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n"); | ||
| LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n"); | ||
| return std::nullopt; | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| return CommonVL; | ||
| return true; | ||
| } | ||
|
|
||
| bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { | ||
|
|
@@ -1551,9 +1574,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { | |
| return false; | ||
| } | ||
|
|
||
| auto CommonVL = DemandedVLs.lookup(&MI); | ||
| if (!CommonVL) | ||
| return false; | ||
| auto *CommonVL = &DemandedVLs.at(&MI).VL; | ||
|
|
||
| assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) && | ||
| "Expected VL to be an Imm or virtual Reg"); | ||
|
|
@@ -1564,7 +1585,7 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { | |
| const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg()); | ||
| if (RISCVInstrInfo::isFaultOnlyFirstLoad(*VLMI) && | ||
| !MDT->dominates(VLMI, &MI)) | ||
| CommonVL = VLMI->getOperand(RISCVII::getVLOpNum(VLMI->getDesc())); | ||
| CommonVL = &VLMI->getOperand(RISCVII::getVLOpNum(VLMI->getDesc())); | ||
| } | ||
|
|
||
| if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) { | ||
|
|
@@ -1599,6 +1620,30 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const { | |
| return true; | ||
| } | ||
|
|
||
| static bool isPhysical(const MachineOperand &MO) { | ||
| return MO.isReg() && MO.getReg().isPhysical(); | ||
| } | ||
|
|
||
| static bool isVirtualVec(const MachineOperand &MO) { | ||
| return MO.isReg() && MO.getReg().isVirtual() && | ||
| RISCVRegisterInfo::isRVVRegClass( | ||
| MO.getParent()->getMF()->getRegInfo().getRegClass(MO.getReg())); | ||
|
||
| } | ||
|
|
||
| /// Look through \p MI's operands and propagate what it demands to its uses. | ||
| void RISCVVLOptimizer::transfer(const MachineInstr &MI) { | ||
| if (!isSupportedInstr(MI) || !checkUsers(MI) || any_of(MI.defs(), isPhysical)) | ||
| DemandedVLs[&MI] = DemandedVL::vlmax(); | ||
|
|
||
| for (const MachineOperand &MO : make_filter_range(MI.uses(), isVirtualVec)) { | ||
| const MachineInstr *Def = MRI->getVRegDef(MO.getReg()); | ||
| DemandedVL Prev = DemandedVLs[Def]; | ||
| DemandedVLs[Def] = max(DemandedVLs[Def], getMinimumVLForUser(MO)); | ||
| if (DemandedVLs[Def] != Prev) | ||
| Worklist.insert(Def); | ||
| } | ||
| } | ||
|
|
||
| bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { | ||
| if (skipFunction(MF.getFunction())) | ||
| return false; | ||
|
|
@@ -1614,15 +1659,18 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) { | |
|
|
||
| assert(DemandedVLs.empty()); | ||
|
|
||
| // For each instruction that defines a vector, compute what VL its | ||
| // downstream users demand. | ||
| // For each instruction that defines a vector, propagate the VL it | ||
| // uses to its inputs. | ||
| for (MachineBasicBlock *MBB : post_order(&MF)) { | ||
| assert(MDT->isReachableFromEntry(MBB)); | ||
| for (MachineInstr &MI : reverse(*MBB)) { | ||
| if (!isCandidate(MI)) | ||
| continue; | ||
| DemandedVLs.insert({&MI, checkUsers(MI)}); | ||
| } | ||
| for (MachineInstr &MI : reverse(*MBB)) | ||
| Worklist.insert(&MI); | ||
| } | ||
|
|
||
| while (!Worklist.empty()) { | ||
| const MachineInstr *MI = Worklist.front(); | ||
| Worklist.remove(MI); | ||
| transfer(*MI); | ||
| } | ||
|
|
||
| // Then go through and see if we can reduce the VL of any instructions to | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move this into the DemandedVL class so that it has to be called as
DemandedVL::max?