diff --git a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h index 3828d859212cb..2f3f55a58a517 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/IRTranslator.h @@ -299,7 +299,7 @@ class IRTranslator : public MachineFunctionPass { bool translateIntrinsic( const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder, - const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo = nullptr); + ArrayRef TgtMemIntrinsicInfos = {}); /// When an invoke or a cleanupret unwinds to the next EH pad, there are /// many places it could ultimately go. In the IR, we have a single unwind diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index 89619da7c9f50..ed695dc990bae 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -438,10 +438,18 @@ class SelectionDAG { template static uint16_t getSyntheticNodeSubclassData(unsigned Opc, unsigned Order, - SDVTList VTs, EVT MemoryVT, - MachineMemOperand *MMO) { + SDVTList VTs, EVT MemoryVT, + MachineMemOperand *MMO) { return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MMO) - .getRawSubclassData(); + .getRawSubclassData(); + } + + template + static uint16_t getSyntheticNodeSubclassData( + unsigned Opc, unsigned Order, SDVTList VTs, EVT MemoryVT, + PointerUnion MemRefs) { + return SDNodeTy(Opc, Order, DebugLoc(), VTs, MemoryVT, MemRefs) + .getRawSubclassData(); } void createOperands(SDNode *Node, ArrayRef Vals); @@ -1481,6 +1489,12 @@ class SelectionDAG { SDVTList VTList, ArrayRef Ops, EVT MemVT, MachineMemOperand *MMO); + /// getMemIntrinsicNode - Creates a MemIntrinsicNode with multiple MMOs. 
+ LLVM_ABI SDValue getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, + SDVTList VTList, ArrayRef Ops, + EVT MemVT, + ArrayRef MMOs); + /// Creates a LifetimeSDNode that starts (`IsStart==true`) or ends /// (`IsStart==false`) the lifetime of the `FrameIndex`. LLVM_ABI SDValue getLifetimeNode(bool IsStart, const SDLoc &dl, SDValue Chain, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 536dca4602c03..a50bd8dab5407 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1411,19 +1411,26 @@ class MemSDNode : public SDNode { EVT MemoryVT; protected: - /// Memory reference information. - MachineMemOperand *MMO; + /// Memory reference information. Must always have at least one MMO. + /// - MachineMemOperand*: exactly 1 MMO (common case) + /// - MachineMemOperand**: pointer to array, size at offset -1 + PointerUnion MemRefs; public: - LLVM_ABI MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT memvt, MachineMemOperand *MMO); + /// Constructor that supports single or multiple MMOs. For single MMO, pass + /// the MMO pointer directly. For multiple MMOs, pre-allocate storage with + /// count at offset -1 and pass pointer to array. 
+ LLVM_ABI + MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT memvt, + PointerUnion memrefs); - bool readMem() const { return MMO->isLoad(); } - bool writeMem() const { return MMO->isStore(); } + bool readMem() const { return getMemOperand()->isLoad(); } + bool writeMem() const { return getMemOperand()->isStore(); } /// Returns alignment and volatility of the memory access - Align getBaseAlign() const { return MMO->getBaseAlign(); } - Align getAlign() const { return MMO->getAlign(); } + Align getBaseAlign() const { return getMemOperand()->getBaseAlign(); } + Align getAlign() const { return getMemOperand()->getAlign(); } /// Return the SubclassData value, without HasDebugValue. This contains an /// encoding of the volatile flag, as well as bits used by subclasses. This @@ -1450,36 +1457,40 @@ class MemSDNode : public SDNode { bool isInvariant() const { return MemSDNodeBits.IsInvariant; } // Returns the offset from the location of the access. - int64_t getSrcValueOffset() const { return MMO->getOffset(); } + int64_t getSrcValueOffset() const { return getMemOperand()->getOffset(); } /// Returns the AA info that describes the dereference. - AAMDNodes getAAInfo() const { return MMO->getAAInfo(); } + AAMDNodes getAAInfo() const { return getMemOperand()->getAAInfo(); } /// Returns the Ranges that describes the dereference. - const MDNode *getRanges() const { return MMO->getRanges(); } + const MDNode *getRanges() const { return getMemOperand()->getRanges(); } /// Returns the synchronization scope ID for this memory operation. - SyncScope::ID getSyncScopeID() const { return MMO->getSyncScopeID(); } + SyncScope::ID getSyncScopeID() const { + return getMemOperand()->getSyncScopeID(); + } /// Return the atomic ordering requirements for this memory operation. For /// cmpxchg atomic operations, return the atomic ordering requirements when /// store occurs. 
AtomicOrdering getSuccessOrdering() const { - return MMO->getSuccessOrdering(); + return getMemOperand()->getSuccessOrdering(); } /// Return a single atomic ordering that is at least as strong as both the /// success and failure orderings for an atomic operation. (For operations /// other than cmpxchg, this is equivalent to getSuccessOrdering().) - AtomicOrdering getMergedOrdering() const { return MMO->getMergedOrdering(); } + AtomicOrdering getMergedOrdering() const { + return getMemOperand()->getMergedOrdering(); + } /// Return true if the memory operation ordering is Unordered or higher. - bool isAtomic() const { return MMO->isAtomic(); } + bool isAtomic() const { return getMemOperand()->isAtomic(); } /// Returns true if the memory operation doesn't imply any ordering /// constraints on surrounding memory operations beyond the normal memory /// aliasing rules. - bool isUnordered() const { return MMO->isUnordered(); } + bool isUnordered() const { return getMemOperand()->isUnordered(); } /// Returns true if the memory operation is neither atomic or volatile. bool isSimple() const { return !isAtomic() && !isVolatile(); } @@ -1487,12 +1498,37 @@ class MemSDNode : public SDNode { /// Return the type of the in-memory value. EVT getMemoryVT() const { return MemoryVT; } - /// Return a MachineMemOperand object describing the memory + /// Return the unique MachineMemOperand object describing the memory /// reference performed by operation. - MachineMemOperand *getMemOperand() const { return MMO; } + /// Asserts if multiple MMOs are present - use memoperands() instead. + MachineMemOperand *getMemOperand() const { + assert(!isa(MemRefs) && + "Use memoperands() for nodes with multiple memory operands"); + return cast(MemRefs); + } + + /// Return the number of memory operands. 
+ size_t getNumMemOperands() const { + if (isa(MemRefs)) + return 1; + MachineMemOperand **Array = cast(MemRefs); + return reinterpret_cast(Array)[-1]; + } + + /// Return true if this node has exactly one memory operand. + bool hasUniqueMemOperand() const { return isa(MemRefs); } + + /// Return the memory operands for this node. + ArrayRef memoperands() const { + if (isa(MemRefs)) + return ArrayRef(MemRefs.getAddrOfPtr1(), 1); + MachineMemOperand **Array = cast(MemRefs); + size_t Count = reinterpret_cast(Array)[-1]; + return ArrayRef(Array, Count); + } const MachinePointerInfo &getPointerInfo() const { - return MMO->getPointerInfo(); + return getMemOperand()->getPointerInfo(); } /// Return the address space for the associated pointer @@ -1501,19 +1537,35 @@ class MemSDNode : public SDNode { } /// Update this MemSDNode's MachineMemOperand information - /// to reflect the alignment of NewMMO, if it has a greater alignment. + /// to reflect the alignment of NewMMOs, if they have greater alignment. /// This must only be used when the new alignment applies to all users of - /// this MachineMemOperand. - void refineAlignment(const MachineMemOperand *NewMMO) { - MMO->refineAlignment(NewMMO); + /// these MachineMemOperands. The NewMMOs array must parallel memoperands(). + void refineAlignment(ArrayRef NewMMOs) { + ArrayRef MMOs = memoperands(); + assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch"); + for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs)) + MMO->refineAlignment(NewMMO); + } + + void refineAlignment(MachineMemOperand *NewMMO) { + refineAlignment(ArrayRef(NewMMO)); } - void refineRanges(const MachineMemOperand *NewMMO) { - // If this node has range metadata that is different than NewMMO, clear the - // range metadata. + /// Refine range metadata for all MMOs. The NewMMOs array must parallel + /// memoperands(). For each pair, if ranges differ, the stored range is + /// cleared. 
+ void refineRanges(ArrayRef NewMMOs) { + ArrayRef MMOs = memoperands(); + assert(NewMMOs.size() == MMOs.size() && "MMO count mismatch"); // FIXME: Union the ranges instead? - if (getRanges() && getRanges() != NewMMO->getRanges()) - MMO->clearRanges(); + for (auto [MMO, NewMMO] : zip(MMOs, NewMMOs)) { + if (MMO->getRanges() && MMO->getRanges() != NewMMO->getRanges()) + MMO->clearRanges(); + } + } + + void refineRanges(MachineMemOperand *NewMMO) { + refineRanges(ArrayRef(NewMMO)); } const SDValue &getChain() const { return getOperand(0); } @@ -1626,7 +1678,7 @@ class AtomicSDNode : public MemSDNode { /// when store does not occur. AtomicOrdering getFailureOrdering() const { assert(isCompareAndSwap() && "Must be cmpxchg operation"); - return MMO->getFailureOrdering(); + return getMemOperand()->getFailureOrdering(); } // Methods to support isa and dyn_cast @@ -1666,9 +1718,11 @@ class AtomicSDNode : public MemSDNode { /// opcode (see `SelectionDAGTargetInfo::isTargetMemoryOpcode`). class MemIntrinsicSDNode : public MemSDNode { public: - MemIntrinsicSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT MemoryVT, MachineMemOperand *MMO) - : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MMO) { + MemIntrinsicSDNode( + unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemoryVT, + PointerUnion MemRefs) + : MemSDNode(Opc, Order, dl, VTs, MemoryVT, MemRefs) { SDNodeBits.IsMemIntrinsic = true; } diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 442225bdec01f..ada4ffd3bcc89 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -1244,15 +1244,32 @@ class LLVM_ABI TargetLoweringBase { }; /// Given an intrinsic, checks if on the target the intrinsic will need to map - /// to a MemIntrinsicNode (touches memory). 
If this is the case, it returns - /// true and store the intrinsic information into the IntrinsicInfo that was - /// passed to the function. + /// to a MemIntrinsicNode (touches memory). If this is the case, it stores + /// the intrinsic information into the IntrinsicInfo vector passed to the + /// function. The vector may contain multiple entries for intrinsics that + /// access multiple memory locations. + virtual void getTgtMemIntrinsic(SmallVectorImpl &Infos, + const CallBase &I, MachineFunction &MF, + unsigned Intrinsic) const { + // The default implementation forwards to the legacy single-info overload + // for compatibility. + IntrinsicInfo Info; + if (getTgtMemIntrinsic(Info, I, MF, Intrinsic)) + Infos.push_back(Info); + } + +protected: + /// This is a legacy single-info overload. New code should override the + /// SmallVectorImpl overload instead to support multiple memory operands. + /// + /// TODO: Remove this once the refactoring is complete. virtual bool getTgtMemIntrinsic(IntrinsicInfo &, const CallBase &, MachineFunction &, unsigned /*Intrinsic*/) const { return false; } +public: /// Returns true if the target can instruction select the specified FP /// immediate natively. If false, the legalizer will materialize the FP /// immediate as a load from a constant pool. diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index a0fe900778cca..126199849b033 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2819,20 +2819,16 @@ bool IRTranslator::translateCall(const User &U, MachineIRBuilder &MIRBuilder) { if (translateKnownIntrinsic(CI, ID, MIRBuilder)) return true; - TargetLowering::IntrinsicInfo Info; - bool IsTgtMemIntrinsic = TLI->getTgtMemIntrinsic(Info, CI, *MF, ID); + SmallVector Infos; + TLI->getTgtMemIntrinsic(Infos, CI, *MF, ID); - return translateIntrinsic(CI, ID, MIRBuilder, - IsTgtMemIntrinsic ? 
&Info : nullptr); + return translateIntrinsic(CI, ID, MIRBuilder, Infos); } /// Translate a call or callbr to an intrinsic. -/// Depending on whether TLI->getTgtMemIntrinsic() is true, TgtMemIntrinsicInfo -/// is a pointer to the correspondingly populated IntrinsicInfo object. -/// Otherwise, this pointer is null. bool IRTranslator::translateIntrinsic( const CallBase &CB, Intrinsic::ID ID, MachineIRBuilder &MIRBuilder, - const TargetLowering::IntrinsicInfo *TgtMemIntrinsicInfo) { + ArrayRef TgtMemIntrinsicInfos) { ArrayRef ResultRegs; if (!CB.getType()->isVoidTy()) ResultRegs = getOrCreateVRegs(CB); @@ -2874,30 +2870,25 @@ bool IRTranslator::translateIntrinsic( } } - // Add a MachineMemOperand if it is a target mem intrinsic. - if (TgtMemIntrinsicInfo) { - const Function *F = CB.getCalledFunction(); + // Add MachineMemOperands for each memory access described by the target. + for (const auto &Info : TgtMemIntrinsicInfos) { + Align Alignment = Info.align.value_or( + DL->getABITypeAlign(Info.memVT.getTypeForEVT(CB.getContext()))); + LLT MemTy = Info.memVT.isSimple() + ? getLLTForMVT(Info.memVT.getSimpleVT()) + : LLT::scalar(Info.memVT.getStoreSizeInBits()); - Align Alignment = TgtMemIntrinsicInfo->align.value_or(DL->getABITypeAlign( - TgtMemIntrinsicInfo->memVT.getTypeForEVT(F->getContext()))); - LLT MemTy = - TgtMemIntrinsicInfo->memVT.isSimple() - ? getLLTForMVT(TgtMemIntrinsicInfo->memVT.getSimpleVT()) - : LLT::scalar(TgtMemIntrinsicInfo->memVT.getStoreSizeInBits()); - - // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic - // didn't yield anything useful. + // TODO: We currently just fall back to address space 0 if + // getTgtMemIntrinsic didn't yield anything useful. 
MachinePointerInfo MPI; - if (TgtMemIntrinsicInfo->ptrVal) { - MPI = MachinePointerInfo(TgtMemIntrinsicInfo->ptrVal, - TgtMemIntrinsicInfo->offset); - } else if (TgtMemIntrinsicInfo->fallbackAddressSpace) { - MPI = MachinePointerInfo(*TgtMemIntrinsicInfo->fallbackAddressSpace); + if (Info.ptrVal) { + MPI = MachinePointerInfo(Info.ptrVal, Info.offset); + } else if (Info.fallbackAddressSpace) { + MPI = MachinePointerInfo(*Info.fallbackAddressSpace); } MIB.addMemOperand(MF->getMachineMemOperand( - MPI, TgtMemIntrinsicInfo->flags, MemTy, Alignment, CB.getAAMetadata(), - /*Ranges=*/nullptr, TgtMemIntrinsicInfo->ssid, - TgtMemIntrinsicInfo->order, TgtMemIntrinsicInfo->failureOrder)); + MPI, Info.flags, MemTy, Alignment, CB.getAAMetadata(), + /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder)); } if (CB.isConvergent()) { diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index eb15aa8ce2261..df69f0870d27a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -1208,7 +1208,7 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc, for (SDNode *Node : N->users()) { auto *LoadStore = dyn_cast(Node); - if (!LoadStore) + if (!LoadStore || !LoadStore->hasUniqueMemOperand()) return false; // Is x[offset2] a legal addressing mode? If so then diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index c49e056dba5ac..302b8059e4df0 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -989,9 +989,11 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) { // to check. 
if (auto *MN = dyn_cast(N)) { ID.AddInteger(MN->getRawSubclassData()); - ID.AddInteger(MN->getPointerInfo().getAddrSpace()); - ID.AddInteger(MN->getMemOperand()->getFlags()); ID.AddInteger(MN->getMemoryVT().getRawBits()); + for (const MachineMemOperand *MMO : MN->memoperands()) { + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); + } } } @@ -1304,7 +1306,7 @@ SelectionDAG::AddModifiedNodeToCSEMaps(SDNode *N) { // recursive merging of other unrelated nodes down the line. Existing->intersectFlagsWith(N->getFlags()); if (auto *MemNode = dyn_cast(Existing)) - MemNode->refineRanges(cast(N)->getMemOperand()); + MemNode->refineRanges(cast(N)->memoperands()); ReplaceAllUsesWith(N, Existing); // N is now dead. Inform the listeners and delete it. @@ -9831,6 +9833,14 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, SDVTList VTList, ArrayRef Ops, EVT MemVT, MachineMemOperand *MMO) { + return getMemIntrinsicNode(Opcode, dl, VTList, Ops, MemVT, ArrayRef(MMO)); +} + +SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, + SDVTList VTList, + ArrayRef Ops, EVT MemVT, + ArrayRef MMOs) { + assert(!MMOs.empty() && "Must have at least one MMO"); assert( (Opcode == ISD::INTRINSIC_VOID || Opcode == ISD::INTRINSIC_W_CHAIN || Opcode == ISD::PREFETCH || @@ -9838,30 +9848,47 @@ SDValue SelectionDAG::getMemIntrinsicNode(unsigned Opcode, const SDLoc &dl, Opcode >= ISD::BUILTIN_OP_END && TSI->isTargetMemoryOpcode(Opcode))) && "Opcode is not a memory-accessing opcode!"); + PointerUnion MemRefs; + if (MMOs.size() == 1) { + MemRefs = MMOs[0]; + } else { + // Allocate: [size_t count][MMO*][MMO*]... 
+ size_t AllocSize = + sizeof(size_t) + MMOs.size() * sizeof(MachineMemOperand *); + void *Buffer = Allocator.Allocate(AllocSize, alignof(size_t)); + size_t *CountPtr = static_cast(Buffer); + *CountPtr = MMOs.size(); + MachineMemOperand **Array = + reinterpret_cast(CountPtr + 1); + llvm::copy(MMOs, Array); + MemRefs = Array; + } + // Memoize the node unless it returns a glue result. MemIntrinsicSDNode *N; if (VTList.VTs[VTList.NumVTs-1] != MVT::Glue) { FoldingSetNodeID ID; AddNodeIDNode(ID, Opcode, VTList, Ops); ID.AddInteger(getSyntheticNodeSubclassData( - Opcode, dl.getIROrder(), VTList, MemVT, MMO)); - ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); - ID.AddInteger(MMO->getFlags()); + Opcode, dl.getIROrder(), VTList, MemVT, MemRefs)); ID.AddInteger(MemVT.getRawBits()); + for (const MachineMemOperand *MMO : MMOs) { + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + ID.AddInteger(MMO->getFlags()); + } void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { - cast(E)->refineAlignment(MMO); + cast(E)->refineAlignment(MMOs); return SDValue(E, 0); } N = newSDNode(Opcode, dl.getIROrder(), dl.getDebugLoc(), - VTList, MemVT, MMO); + VTList, MemVT, MemRefs); createOperands(N, Ops); - - CSEMap.InsertNode(N, IP); + CSEMap.InsertNode(N, IP); } else { N = newSDNode(Opcode, dl.getIROrder(), dl.getDebugLoc(), - VTList, MemVT, MMO); + VTList, MemVT, MemRefs); createOperands(N, Ops); } InsertNode(N); @@ -13285,21 +13312,33 @@ HandleSDNode::~HandleSDNode() { DropOperands(); } -MemSDNode::MemSDNode(unsigned Opc, unsigned Order, const DebugLoc &dl, - SDVTList VTs, EVT memvt, MachineMemOperand *mmo) - : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MMO(mmo) { - MemSDNodeBits.IsVolatile = MMO->isVolatile(); - MemSDNodeBits.IsNonTemporal = MMO->isNonTemporal(); - MemSDNodeBits.IsDereferenceable = MMO->isDereferenceable(); - MemSDNodeBits.IsInvariant = MMO->isInvariant(); - - // We check here that the size of the memory operand fits within the size of - // the 
MMO. This is because the MMO might indicate only a possible address - // range instead of specifying the affected memory addresses precisely. - assert( - (!MMO->getType().isValid() || - TypeSize::isKnownLE(memvt.getStoreSize(), MMO->getSize().getValue())) && - "Size mismatch!"); +MemSDNode::MemSDNode( + unsigned Opc, unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT memvt, + PointerUnion memrefs) + : SDNode(Opc, Order, dl, VTs), MemoryVT(memvt), MemRefs(memrefs) { + bool IsVolatile = false; + bool IsNonTemporal = false; + bool IsDereferenceable = true; + bool IsInvariant = true; + for (const MachineMemOperand *MMO : memoperands()) { + IsVolatile |= MMO->isVolatile(); + IsNonTemporal |= MMO->isNonTemporal(); + IsDereferenceable &= MMO->isDereferenceable(); + IsInvariant &= MMO->isInvariant(); + } + MemSDNodeBits.IsVolatile = IsVolatile; + MemSDNodeBits.IsNonTemporal = IsNonTemporal; + MemSDNodeBits.IsDereferenceable = IsDereferenceable; + MemSDNodeBits.IsInvariant = IsInvariant; + + // For the single-MMO case, we check here that the size of the memory operand + // fits within the size of the MMO. This is because the MMO might indicate + // only a possible address range instead of specifying the affected memory + // addresses precisely. + assert((getNumMemOperands() != 1 || !getMemOperand()->getType().isValid() || + TypeSize::isKnownLE(memvt.getStoreSize(), + getMemOperand()->getSize().getValue())) && + "Size mismatch!"); } /// Profile - Gather unique data for the node. 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 18cb69a47d85f..6045b55130925 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3514,10 +3514,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) { /// - they do not need custom argument handling (no /// TLI.CollectTargetIntrinsicOperands()) void SelectionDAGBuilder::visitCallBrIntrinsic(const CallBrInst &I) { - TargetLowering::IntrinsicInfo Info; - assert(!DAG.getTargetLoweringInfo().getTgtMemIntrinsic( - Info, I, DAG.getMachineFunction(), I.getIntrinsicID()) && - "Intrinsic touches memory"); +#ifndef NDEBUG + SmallVector Infos; + DAG.getTargetLoweringInfo().getTgtMemIntrinsic( + Infos, I, DAG.getMachineFunction(), I.getIntrinsicID()); + assert(Infos.empty() && "Intrinsic touches memory"); +#endif auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I); @@ -5485,14 +5487,15 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic) { auto [HasChain, OnlyLoad] = getTargetIntrinsicCallProperties(I); - // Info is set by getTgtMemIntrinsic - TargetLowering::IntrinsicInfo Info; + // Infos is set by getTgtMemIntrinsic. + SmallVector Infos; const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - bool IsTgtMemIntrinsic = - TLI.getTgtMemIntrinsic(Info, I, DAG.getMachineFunction(), Intrinsic); + TLI.getTgtMemIntrinsic(Infos, I, DAG.getMachineFunction(), Intrinsic); + // The first (primary) info determines the node opcode. + TargetLowering::IntrinsicInfo *Info = !Infos.empty() ? &Infos[0] : nullptr; - SmallVector Ops = getTargetIntrinsicOperands( - I, HasChain, OnlyLoad, IsTgtMemIntrinsic ? &Info : nullptr); + SmallVector Ops = + getTargetIntrinsicOperands(I, HasChain, OnlyLoad, Info); SDVTList VTs = getTargetIntrinsicVTList(I, HasChain); // Propagate fast-math-flags from IR to node(s). 
@@ -5506,26 +5509,32 @@ void SelectionDAGBuilder::visitTargetIntrinsic(const CallInst &I, // In some cases, custom collection of operands from CallInst I may be needed. TLI.CollectTargetIntrinsicOperands(I, Ops, DAG); - if (IsTgtMemIntrinsic) { + if (!Infos.empty()) { // This is target intrinsic that touches memory - // - // TODO: We currently just fallback to address space 0 if getTgtMemIntrinsic - // didn't yield anything useful. - MachinePointerInfo MPI; - if (Info.ptrVal) - MPI = MachinePointerInfo(Info.ptrVal, Info.offset); - else if (Info.fallbackAddressSpace) - MPI = MachinePointerInfo(*Info.fallbackAddressSpace); - EVT MemVT = Info.memVT; - LocationSize Size = LocationSize::precise(Info.size); - if (Size.hasValue() && !Size.getValue()) - Size = LocationSize::precise(MemVT.getStoreSize()); - Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT)); - MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( - MPI, Info.flags, Size, Alignment, I.getAAMetadata(), /*Ranges=*/nullptr, - Info.ssid, Info.order, Info.failureOrder); - Result = - DAG.getMemIntrinsicNode(Info.opc, getCurSDLoc(), VTs, Ops, MemVT, MMO); + // Create MachineMemOperands for each memory access described by the target. + MachineFunction &MF = DAG.getMachineFunction(); + SmallVector MMOs; + for (const auto &Info : Infos) { + // TODO: We currently just fall back to address space 0 if + // getTgtMemIntrinsic didn't yield anything useful. 
+ MachinePointerInfo MPI; + if (Info.ptrVal) + MPI = MachinePointerInfo(Info.ptrVal, Info.offset); + else if (Info.fallbackAddressSpace) + MPI = MachinePointerInfo(*Info.fallbackAddressSpace); + EVT MemVT = Info.memVT; + LocationSize Size = LocationSize::precise(Info.size); + if (Size.hasValue() && !Size.getValue()) + Size = LocationSize::precise(MemVT.getStoreSize()); + Align Alignment = Info.align.value_or(DAG.getEVTAlign(MemVT)); + MachineMemOperand *MMO = MF.getMachineMemOperand( + MPI, Info.flags, Size, Alignment, I.getAAMetadata(), + /*Ranges=*/nullptr, Info.ssid, Info.order, Info.failureOrder); + MMOs.push_back(MMO); + } + + Result = DAG.getMemIntrinsicNode(Info->opc, getCurSDLoc(), VTs, Ops, + Info->memVT, MMOs); } else { Result = getTargetNonMemIntrinsicNode(*I.getType(), HasChain, Ops, VTs); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp index 66ecb40e48954..a213396f3df90 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -933,7 +933,9 @@ void SDNode::print_details(raw_ostream &OS, const SelectionDAG *G) const { OS << ">"; } else if (const MemSDNode *M = dyn_cast(this)) { OS << "<"; - printMemOperand(OS, *M->getMemOperand(), G); + interleaveComma(M->memoperands(), OS, [&](const MachineMemOperand *MMO) { + printMemOperand(OS, *MMO, G); + }); if (auto *A = dyn_cast(M)) if (A->getOpcode() == ISD::ATOMIC_LOAD) { bool doExt = true; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index bd21c95c0ff93..e7cb0a3574b4a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3597,7 +3597,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, } case OPC_RecordMemRef: if (auto *MN = dyn_cast(N)) - MatchedMemRefs.push_back(MN->getMemOperand()); + 
llvm::append_range(MatchedMemRefs, MN->memoperands()); else { LLVM_DEBUG(dbgs() << "Expected MemSDNode "; N->dump(CurDAG); dbgs() << '\n');