From 55d1a793e67e6d472642a739ac5d6fb9eccf5229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolai=20H=C3=A4hnle?= Date: Mon, 12 Jan 2026 09:42:04 -0800 Subject: [PATCH] [CodeGen] Add getTgtMemIntrinsic overload for multiple memory operands (NFC) There are target intrinsics that logically require two MMOs, such as llvm.amdgcn.global.load.lds, which is a copy from global memory to LDS, so there's both a load and a store to different addresses. Add an overload of getTgtMemIntrinsic that produces intrinsic info in a vector, and implement it in terms of the existing (now protected) overload. GlobalISel and SelectionDAG paths are updated to support multiple MMOs. The main part of this change is supporting multiple MMOs in MemIntrinsicNodes. Converting the backends to use the new overload is a fairly mechanical step that is done in a separate change, in the hope that this reduces merge pain during review and for downstream users. A later change will then enable using multiple MMOs in AMDGPU. 
commit-id:b4a924aa --- .../llvm/CodeGen/GlobalISel/IRTranslator.h | 2 +- llvm/include/llvm/CodeGen/SelectionDAG.h | 20 ++- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 118 +++++++++++++----- llvm/include/llvm/CodeGen/TargetLowering.h | 23 +++- llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp | 47 +++---- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 2 +- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 91 ++++++++++---- .../SelectionDAG/SelectionDAGBuilder.cpp | 67 +++++----- .../SelectionDAG/SelectionDAGDumper.cpp | 4 +- .../CodeGen/SelectionDAG/SelectionDAGISel.cpp | 2 +- 10 files changed, 251 insertions(+), 125 deletions(-) diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h index 3828d859212cb..2f3f55a58a517 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h @@ -299,7 +299,7 @@ class IRTranslator : public MachineFunctionPass { bool translateIntrinsic( const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder, - const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo = nullptr); + ArrayRef TgtMemIntrinsicInfos = {}); /// When an invoke or a cleanupret unwinds to the next EH pad, there are /// many places it could ultimately go. 
In the IR, we have a single unwind diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 89619da7c9f50..ed695dc990bae 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -438,10 +438,18 @@ class SelectionDAG { template static uint16_t getSyntheticNodeSubclassData(unsigned Opc, unsigned Order, - SDVTList VTs, EVT MemoryVT, - MachineMemOperand *MMO) { + SDVTList VTs, EVT MemoryVT, + MachineMemOperand *MMO) { return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MMO) - .getRawSubclassData(); + .getRawSubclassData(); + } + + template + static uint16_t getSyntheticNodeSubclassData( + unsigned Opc, unsigned Order, SDVTList VTs, EVT MemoryVT, + PointerUnion MemRefs) { + return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MemRefs) + .getRawSubclassData(); } void createOperands(SDNode *Node, ArrayRef Vals); @@ -1481,6 +1489,12 @@ class SelectionDAG { SDVTList VTList, ArrayRef Ops, EVT MemVT, MachineMemOperand *MMO); + /// getMemIntrinsicNode - Creates a MemIntrinsicNode with multiple MMOs. + LLVM_ABI SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, + SDVTList VTList, ArrayRef Ops, + EVT MemVT, + ArrayRef MMOs); + /// Creates a LifetimeSDNode that starts (`IsStart==true`) or ends /// (`IsStart==false`) the lifetime of the `FrameIndex`. LLVM_ABI SDValue getLifetimeNode(bool IsStart, const SDLoc &dl, SDValue Chain, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 536dca4602c03..a50bd8dab5407 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1411,19 +1411,26 @@ class MemSDNode : public SDNode { EVT MemoryVT; protected: - /// Memory reference information. - MachineMemOperand *MMO; + /// Memory reference information. Must always have at least one MMO. 
+ /// - MachineMemOperand*: exactly 1 MMO (common case) + /// - MachineMemOperand**: pointer to array, size at offset -1 + PointerUnion MemRefs; public: - LLVM_ABI MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT memvt, MachineMemOperand *MMO); + /// Constructor that supports single or multiple MMOs. For single MMO, pass + /// the MMO pointer directly. For multiple MMOs, pre-allocate storage with + /// count at offset -1 and pass pointer to array. + LLVM_ABI + MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT memvt, + PointerUnion memrefs); - bool readMem() const { return MMO->isLoad(); } - bool writeMem() const { return MMO->isStore(); } + bool readMem() const { return getMemOperand()->isLoad(); } + bool writeMem() const { return getMemOperand()->isStore(); } /// Returns alignment and volatility of the memory access - Align getBaseAlign() const { return MMO->getBaseAlign(); } - Align getAlign() const { return MMO->getAlign(); } + Align getBaseAlign() const { return getMemOperand()->getBaseAlign(); } + Align getAlign() const { return getMemOperand()->getAlign(); } /// Return the SubclassData value, without HasDebugValue. This contains an /// encoding of the volatile flag, as well as bits used by subclasses. This @@ -1450,36 +1457,40 @@ class MemSDNode : public SDNode { bool isInvariant() const { return MemSDNodeBits.IsInvariant; } // Returns the offset from the location of the access. - int64_t getSrcValueOffset() const { return MMO->getOffset(); } + int64_t getSrcValueOffset() const { return getMemOperand()->getOffset(); } /// Returns the AA info that describes the dereference. - AAMDNodes getAAInfo() const { return MMO->getAAInfo(); } + AAMDNodes getAAInfo() const { return getMemOperand()->getAAInfo(); } /// Returns the Ranges that describes the dereference. 
- const MDNode *getRanges() const { return MMO->getRanges(); } + const MDNode *getRanges() const { return getMemOperand()->getRanges(); } /// Returns the synchronization scope ID for this memory operation. - SyncScope::ID getSyncScopeID() const { return MMO->getSyncScopeID(); } + SyncScope::ID getSyncScopeID() const { + return getMemOperand()->getSyncScopeID(); + } /// Return the atomic ordering requirements for this memory operation. For /// cmpxchg atomic operations, return the atomic ordering requirements when /// store occurs. AtomicOrdering getSuccessOrdering() const { - return MMO->getSuccessOrdering(); + return getMemOperand()->getSuccessOrdering(); } /// Return a single atomic ordering that is at least as strong as both the /// success and failure orderings for an atomic operation. (For operations /// other than cmpxchg, this is equivalent to getSuccessOrdering().) - AtomicOrdering getMergedOrdering() const { return MMO->getMergedOrdering(); } + AtomicOrdering getMergedOrdering() const { + return getMemOperand()->getMergedOrdering(); + } /// Return true if the memory operation ordering is Unordered or higher. - bool isAtomic() const { return MMO->isAtomic(); } + bool isAtomic() const { return getMemOperand()->isAtomic(); } /// Returns true if the memory operation doesn't imply any ordering /// constraints on surrounding memory operations beyond the normal memory /// aliasing rules. - bool isUnordered() const { return MMO->isUnordered(); } + bool isUnordered() const { return getMemOperand()->isUnordered(); } /// Returns true if the memory operation is neither atomic or volatile. bool isSimple() const { return !isAtomic() && !isVolatile(); } @@ -1487,12 +1498,37 @@ class MemSDNode : public SDNode { /// Return the type of the in-memory value. EVT getMemoryVT() const { return MemoryVT; } - /// Return a MachineMemOperand object describing the memory + /// Return the unique MachineMemOperand object describing the memory /// reference performed by operation. 
- MachineMemOperand *getMemOperand() const { return MMO; } + /// Asserts if multiple MMOs are present - use memoperands() instead. + MachineMemOperand *getMemOperand() const { + assert(!isa(MemRefs) && + "Use memoperands() for nodes with multiple memory operands"); + return cast(MemRefs); + } + + /// Return the number of memory operands. + size_t getNumMemOperands() const { + if (isa(MemRefs)) + return 1; + MachineMemOperand **Array = cast(MemRefs); + return reinterpret_cast(Array)[-1]; + } + + /// Return true if this node has exactly one memory operand. + bool hasUniqueMemOperand() const { return isa(MemRefs); } + + /// Return the memory operands for this node. + ArrayRef memoperands() const { + if (isa(MemRefs)) + return ArrayRef(MemRefs.getAddrOfPtr1(), 1); + MachineMemOperand **Array = cast(MemRefs); + size_t Count = reinterpret_cast(Array)[-1]; + return ArrayRef(Array, Count); + } const MachinePointerInfo &getPointerInfo() const { - return MMO->getPointerInfo(); + return getMemOperand()->getPointerInfo(); } /// Return the address space for the associated pointer @@ -1501,19 +1537,35 @@ class MemSDNode : public SDNode { } /// Update this MemSDNode's MachineMemOperand information - /// to reflect the alignment of NewMMO, if it has a greater alignment. + /// to reflect the alignment of NewMMOs, if they have greater alignment. /// This must only be used when the new alignment applies to all users of - /// this MachineMemOperand. - void refineAlignment(const MachineMemOperand *NewMMO) { - MMO->refineAlignment(NewMMO); + /// these MachineMemOperands. The NewMMOs array must parallel memoperands(). 
+ void refineAlignment(ArrayRef NewMMOs) { + ArrayRef MMOs = memoperands(); + assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch"); + for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs)) + MMO->refineAlignment(NewMMO); + } + + void refineAlignment(MachineMemOperand *NewMMO) { + refineAlignment(ArrayRef(NewMMO)); } - void refineRanges(const MachineMemOperand *NewMMO) { - // If this node has range metadata that is different than NewMMO, clear the - // range metadata. + /// Refine range metadata for all MMOs. The NewMMOs array must parallel + /// memoperands(). For each pair, if ranges differ, the stored range is + /// cleared. + void refineRanges(ArrayRef NewMMOs) { + ArrayRef MMOs = memoperands(); + assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch"); // FIXME: Union the ranges instead? - if (getRanges() && getRanges() != NewMMO->getRanges()) - MMO->clearRanges(); + for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs)) { + if (MMO->getRanges() && MMO->getRanges() != NewMMO->getRanges()) + MMO->clearRanges(); + } + } + + void refineRanges(MachineMemOperand *NewMMO) { + refineRanges(ArrayRef(NewMMO)); } const SDValue &getChain() const { return getOperand(0); } @@ -1626,7 +1678,7 @@ class AtomicSDNode : public MemSDNode { /// when store does not occur. AtomicOrdering getFailureOrdering() const { assert(isCompareAndSwap() && "Must be cmpxchg operation"); - return MMO->getFailureOrdering(); + return getMemOperand()->getFailureOrdering(); } // Methods to support isa and dyn_cast @@ -1666,9 +1718,11 @@ class AtomicSDNode : public MemSDNode { /// opcode (see `SelectionDAGTargetInfo::isTargetMemoryOpcode`). 
class MemIntrinsicSDNode : public MemSDNode { public: - MemIntrinsicSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT MemoryVT, MachineMemOperand *MMO) - : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MMO) { + MemIntrinsicSDNode( + unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemoryVT, + PointerUnion MemRefs) + : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MemRefs) { SDNodeBits.IsMemIntrinsic = true; } diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 442225bdec01f..ada4ffd3bcc89 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1244,15 +1244,32 @@ class LLVM_ABI TargetLoweringBase { }; /// Given an intrinsic, checks if on the target the intrinsic will need to map - /// to a MemIntrinsicNode (touches memory). If this is the case, it returns - /// true and store the intrinsic information into the IntrinsicInfo that was - /// passed to the function. + /// to a MemIntrinsicNode (touches memory). If this is the case, it stores + /// the intrinsic information into the IntrinsicInfo vector passed to the + /// function. The vector may contain multiple entries for intrinsics that + /// access multiple memory locations. + virtual void getTgtMemIntrinsic(SmallVectorImpl &Infos, + const CallBase &I, MachineFunction &MF, + unsigned Intrinsic) const { + // The default implementation forwards to the legacy single-info overload + // for compatibility. + IntrinsicInfo Info; + if (getTgtMemIntrinsic(Info, I, MF, Intrinsic)) + Infos.push_back(Info); + } + +protected: + /// This is a legacy single-info overload. New code should override the + /// SmallVectorImpl overload instead to support multiple memory operands. + /// + /// TODO: Remove this once the refactoring is complete. 
virtual bool getTgtMemIntrinsic(IntrinsicInfo &, const CallBase &, MachineFunction &, unsigned /*Intrinsic*/) const { return false; } +public: /// Returns true if the target can instruction select the specified FP /// immediate natively. If false, the legalizer will materialize the FP /// immediate as a load from a constant pool. diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index a0fe900778cca..126199849b033 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2819,20 +2819,16 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { if (translateKnownIntrinsic(CI, ID, MIRBuilder)) return true; - TargetLowering::IntrinsicInfo Info; - bool IsTgtMemIntrinsic = TLI->getTgtMemIntrinsic(Info, CI, *MF, ID); + SmallVector Infos; + TLI->getTgtMemIntrinsic(Infos, CI, *MF, ID); - return translateIntrinsic(CI, ID, MIRBuilder, - IsTgtMemIntrinsic ? &Info : nullptr); + return translateIntrinsic(CI, ID, MIRBuilder, Infos); } /// Translate a call or callbr to an intrinsic. -/// Depending on whether TLI->getTgtMemIntrinsic() is true, TgtMemIntrinsicInfo -/// is a pointer to the correspondingly populated IntrinsicInfo object. -/// Otherwise, this pointer is null. bool IRTranslator::translateIntrinsic( const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder, - const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo) { + ArrayRef TgtMemIntrinsicInfos) { ArrayRef ResultRegs; if (!CB.getType()->isVoidTy()) ResultRegs = getOrCreateVRegs(CB); @@ -2874,30 +2870,25 @@ bool IRTranslator::translateIntrinsic( } } - // Add a MachineMemOperand if it is a target mem intrinsic. - if (TgtMemIntrinsicInfo) { - const Function *F = CB.getCalledFunction(); + // Add MachineMemOperands for each memory access described by the target. 
+ for (const auto &Info : TgtMemIntrinsicInfos) { + Align Alignment = Info.align.value_or( + DL->getABITypeAlign(Info.memVT.getTypeForEVT(CB.getContext()))); + LLT MemTy = Info.memVT.isSimple() + ? getLLTForMVT(Info.memVT.getSimpleVT()) + : LLT::scalar(Info.memVT.getStoreSizeInBits()); - Align Alignment = TgtMemIntrinsicInfo->align.value_or(DL->getABITypeAlign( - TgtMemIntrinsicInfo->memVT.getTypeForEVT(F->getContext()))); - LLT MemTy = - TgtMemIntrinsicInfo->memVT.isSimple() - ? getLLTForMVT(TgtMemIntrinsicInfo->memVT.getSimpleVT()) - : LLT::scalar(TgtMemIntrinsicInfo->memVT.getStoreSizeInBits()); - - // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic - // didn't yield anything useful. + // TODO: We currently just fallback to address space 0 if + // getTgtMemIntrinsic didn't yield anything useful. MachinePointerInfo MPI; - if (TgtMemIntrinsicInfo->ptrVal) { - MPI = MachinePointerInfo(TgtMemIntrinsicInfo->ptrVal, - TgtMemIntrinsicInfo->offset); - } else if (TgtMemIntrinsicInfo->fallbackAddressSpace) { - MPI = MachinePointerInfo(*TgtMemIntrinsicInfo->fallbackAddressSpace); + if (Info.ptrVal) { + MPI = MachinePointerInfo(Info.ptrVal, Info.offset); + } else if (Info.fallbackAddressSpace) { + MPI = MachinePointerInfo(*Info.fallbackAddressSpace); } MIB.addMemOperand(MF->getMachineMemOperand( - MPI, TgtMemIntrinsicInfo->flags, MemTy, Alignment, CB.getAAMetadata(), - /*Ranges=*/nullptr, TgtMemIntrinsicInfo->ssid, - TgtMemIntrinsicInfo->order, TgtMemIntrinsicInfo->failureOrder)); + MPI, Info.flags, MemTy, Alignment, CB.getAAMetadata(), + /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder)); } if (CB.isConvergent()) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index eb15aa8ce2261..df69f0870d27a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -1208,7 +1208,7 @@ bool 
DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, for (SDNode *Node : N->users()) { auto *LoadStore = dyn_cast(Node); - if (!LoadStore) + if (!LoadStore || !LoadStore->hasUniqueMemOperand()) return false; // Is x[offset2] a legal addressing mode? If so then diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index c49e056dba5ac..302b8059e4df0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -989,9 +989,11 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) { // to check. if (auto *MN = dyn_cast(N)) { ID.AddInteger(MN->getRawSubclassData()); - ID.AddInteger(MN->getPointerInfo().getAddrSpace()); - ID.AddInteger(MN->getMemOperand()->getFlags()); ID.AddInteger(MN->getMemoryVT().getRawBits()); + for (const MachineMemOperand *MMO : MN->memoperands()) { + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); + } } } @@ -1304,7 +1306,7 @@ SelectionDAG::AddModifiedNodeToCSEMaps(SDNode *N) { // recursive merging of other unrelated nodes down the line. Existing->intersectFlagsWith(N->getFlags()); if (auto *MemNode = dyn_cast(Existing)) - MemNode->refineRanges(cast(N)->getMemOperand()); + MemNode->refineRanges(cast(N)->memoperands()); ReplaceAllUsesWith(N, Existing); // N is now dead. Inform the listeners and delete it. 
@@ -9831,6 +9833,14 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef Ops, EVT MemVT, MachineMemOperand *MMO) { + return getMemIntrinsicNode(Opcode, dl, VTList, Ops, MemVT, ArrayRef(MMO)); +} + +SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, + SDVTList VTList, + ArrayRef Ops, EVT MemVT, + ArrayRef MMOs) { + assert(!MMOs.empty() && "Must have at least one MMO"); assert( (Opcode == ISD::INTRINSIC_VOID || Opcode == ISD::INTRINSIC_W_CHAIN || Opcode == ISD::PREFETCH || @@ -9838,30 +9848,47 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, Opcode >= ISD::BUILTIN_OP_END && TSI->isTargetMemoryOpcode(Opcode))) && "Opcode is not a memory-accessing opcode!"); + PointerUnion MemRefs; + if (MMOs.size() == 1) { + MemRefs = MMOs[0]; + } else { + // Allocate: [size_t count][MMO*][MMO*]... + size_t AllocSize = + sizeof(size_t) + MMOs.size() * sizeof(MachineMemOperand *); + void *Buffer = Allocator.Allocate(AllocSize, alignof(size_t)); + size_t *CountPtr = static_cast(Buffer); + *CountPtr = MMOs.size(); + MachineMemOperand **Array = + reinterpret_cast(CountPtr + 1); + llvm::copy(MMOs, Array); + MemRefs = Array; + } + // Memoize the node unless it returns a glue result. 
MemIntrinsicSDNode *N; if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) { FoldingSetNodeID ID; AddNodeIDNode(ID, Opcode, VTList, Ops); ID.AddInteger(getSyntheticNodeSubclassData( - Opcode, dl.getIROrder(), VTList, MemVT, MMO)); - ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); - ID.AddInteger(MMO->getFlags()); + Opcode, dl.getIROrder(), VTList, MemVT, MemRefs)); ID.AddInteger(MemVT.getRawBits()); + for (const MachineMemOperand *MMO : MMOs) { + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); + } void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { - cast(E)->refineAlignment(MMO); + cast(E)->refineAlignment(MMOs); return SDValue(E, 0); } N = newSDNode(Opcode, dl.getIROrder(), dl.getDebugLoc(), - VTList, MemVT, MMO); + VTList, MemVT, MemRefs); createOperands(N, Ops); - - CSEMap.InsertNode(N, IP); + CSEMap.InsertNode(N, IP); } else { N = newSDNode(Opcode, dl.getIROrder(), dl.getDebugLoc(), - VTList, MemVT, MMO); + VTList, MemVT, MemRefs); createOperands(N, Ops); } InsertNode(N); @@ -13285,21 +13312,33 @@ HandleSDNode::~HandleSDNode() { DropOperands(); } -MemSDNode::MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT memvt, MachineMemOperand *mmo) - : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MMO(mmo) { - MemSDNodeBits.IsVolatile = MMO->isVolatile(); - MemSDNodeBits.IsNonTemporal = MMO->isNonTemporal(); - MemSDNodeBits.IsDereferenceable = MMO->isDereferenceable(); - MemSDNodeBits.IsInvariant = MMO->isInvariant(); - - // We check here that the size of the memory operand fits within the size of - // the MMO. This is because the MMO might indicate only a possible address - // range instead of specifying the affected memory addresses precisely. 
- assert( - (!MMO->getType().isValid() || - TypeSize::isKnownLE(memvt.getStoreSize(), MMO->getSize().getValue())) && - "Size mismatch!"); +MemSDNode::MemSDNode( + unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT memvt, + PointerUnion memrefs) + : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MemRefs(memrefs) { + bool IsVolatile = false; + bool IsNonTemporal = false; + bool IsDereferenceable = true; + bool IsInvariant = true; + for (const MachineMemOperand *MMO : memoperands()) { + IsVolatile |= MMO->isVolatile(); + IsNonTemporal |= MMO->isNonTemporal(); + IsDereferenceable &= MMO->isDereferenceable(); + IsInvariant &= MMO->isInvariant(); + } + MemSDNodeBits.IsVolatile = IsVolatile; + MemSDNodeBits.IsNonTemporal = IsNonTemporal; + MemSDNodeBits.IsDereferenceable = IsDereferenceable; + MemSDNodeBits.IsInvariant = IsInvariant; + + // For the single-MMO case, we check here that the size of the memory operand + // fits within the size of the MMO. This is because the MMO might indicate + // only a possible address range instead of specifying the affected memory + // addresses precisely. + assert((getNumMemOperands() != 1 || !getMemOperand()->getType().isValid() || + TypeSize::isKnownLE(memvt.getStoreSize(), + getMemOperand()->getSize().getValue())) && + "Size mismatch!"); } /// Profile - Gather unique data for the node. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 18cb69a47d85f..6045b55130925 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3514,10 +3514,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) { /// - they do not need custom argument handling (no /// TLI.CollectTargetIntrinsicOperands()) void SelectionDAGBuilder::visitCallBrIntrinsic(const CallBrInst &I) { - TargetLowering::IntrinsicInfo Info; - assert(!DAG.getTargetLoweringInfo().getTgtMemIntrinsic( - Info, I, DAG.getMachineFunction(), I.getIntrinsicID()) && - "Intrinsic touches memory"); +#ifndef NDEBUG + SmallVector Infos; + DAG.getTargetLoweringInfo().getTgtMemIntrinsic( + Infos, I, DAG.getMachineFunction(), I.getIntrinsicID()); + assert(Infos.empty() && "Intrinsic touches memory"); +#endif auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I); @@ -5485,14 +5487,15 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic) { auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I); - // Info is set by getTgtMemIntrinsic - TargetLowering::IntrinsicInfo Info; + // Infos is set by getTgtMemIntrinsic. + SmallVector Infos; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - bool IsTgtMemIntrinsic = - TLI.getTgtMemIntrinsic(Info, I, DAG.getMachineFunction(), Intrinsic); + TLI.getTgtMemIntrinsic(Infos, I, DAG.getMachineFunction(), Intrinsic); + // The first (primary) info determines the node opcode. + TargetLowering::IntrinsicInfo *Info = !Infos.empty() ? &Infos[0] : nullptr; - SmallVector Ops = getTargetIntrinsicOperands( - I, HasChain, OnlyLoad, IsTgtMemIntrinsic ? &Info : nullptr); + SmallVector Ops = + getTargetIntrinsicOperands(I, HasChain, OnlyLoad, Info); SDVTList VTs = getTargetIntrinsicVTList(I, HasChain); // Propagate fast-math-flags from IR to node(s). 
@@ -5506,26 +5509,32 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, // In some cases, custom collection of operands from CallInst I may be needed. TLI.CollectTargetIntrinsicOperands(I, Ops, DAG); - if (IsTgtMemIntrinsic) { + if (!Infos.empty()) { // This is target intrinsic that touches memory - // - // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic - // didn't yield anything useful. - MachinePointerInfo MPI; - if (Info.ptrVal) - MPI = MachinePointerInfo(Info.ptrVal, Info.offset); - else if (Info.fallbackAddressSpace) - MPI = MachinePointerInfo(*Info.fallbackAddressSpace); - EVT MemVT = Info.memVT; - LocationSize Size = LocationSize::precise(Info.size); - if (Size.hasValue() && !Size.getValue()) - Size = LocationSize::precise(MemVT.getStoreSize()); - Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT)); - MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( - MPI, Info.flags, Size, Alignment, I.getAAMetadata(), /*Ranges=*/nullptr, - Info.ssid, Info.order, Info.failureOrder); - Result = - DAG.getMemIntrinsicNode(Info.opc, getCurSDLoc(), VTs, Ops, MemVT, MMO); + // Create MachineMemOperands for each memory access described by the target. + MachineFunction &MF = DAG.getMachineFunction(); + SmallVector MMOs; + for (const auto &Info : Infos) { + // TODO: We currently just fallback to address space 0 if + // getTgtMemIntrinsic didn't yield anything useful. 
+ MachinePointerInfo MPI; + if (Info.ptrVal) + MPI = MachinePointerInfo(Info.ptrVal, Info.offset); + else if (Info.fallbackAddressSpace) + MPI = MachinePointerInfo(*Info.fallbackAddressSpace); + EVT MemVT = Info.memVT; + LocationSize Size = LocationSize::precise(Info.size); + if (Size.hasValue() && !Size.getValue()) + Size = LocationSize::precise(MemVT.getStoreSize()); + Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT)); + MachineMemOperand *MMO = MF.getMachineMemOperand( + MPI, Info.flags, Size, Alignment, I.getAAMetadata(), + /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder); + MMOs.push_back(MMO); + } + + Result = DAG.getMemIntrinsicNode(Info->opc, getCurSDLoc(), VTs, Ops, + Info->memVT, MMOs); } else { Result = getTargetNonMemIntrinsicNode(*I.getType(), HasChain, Ops, VTs); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 66ecb40e48954..a213396f3df90 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -933,7 +933,9 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const { OS << ">"; } else if (const MemSDNode *M = dyn_cast(this)) { OS << "<"; - printMemOperand(OS, *M->getMemOperand(), G); + interleaveComma(M->memoperands(), OS, [&](const MachineMemOperand *MMO) { + printMemOperand(OS, *MMO, G); + }); if (auto *A = dyn_cast(M)) if (A->getOpcode() == ISD::ATOMIC_LOAD) { bool doExt = true; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index bd21c95c0ff93..e7cb0a3574b4a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3597,7 +3597,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, } case OPC_RecordMemRef: if (auto *MN = dyn_cast(N)) - MatchedMemRefs.push_back(MN->getMemOperand()); + 
llvm::append_range(MatchedMemRefs, MN->memoperands()); else { LLVM_DEBUG(dbgs() << "Expected MemSDNode "; N->dump(CurDAG); dbgs() << '\n');