Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPIR-V] Improve Tablegen instruction selection and account for a pointer size of the target #88725

Merged
merged 1 commit into from
Apr 17, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR resolves the issue that SPIR-V Backend uses the notion of a pointer size of the target, most notably, in legalizer code, but Tablegen instruction selection in SPIR-V Backend doesn't account for a pointer size of the target. See #88723 for a detailed description. There are 3 test cases attached to the PR that reproduced the issue, when dealing with spirv32-spirv64 differences, and are working correctly now with this PR.

@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR resolves the issue that SPIR-V Backend uses the notion of a pointer size of the target, most notably, in legalizer code, but Tablegen instruction selection in SPIR-V Backend doesn't account for a pointer size of the target. See #88723 for a detailed description. There are 3 test cases attached to the PR that reproduced the issue, when dealing with spirv32-spirv64 differences, and are working correctly now with this PR.


Patch is 20.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88725.diff

13 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+2-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp (+4-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+12-6)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+6-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+3-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+45-17)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp (+2-18)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td (+2-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td (+28-21)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/select-phi.ll (+4)
  • (added) llvm/test/CodeGen/SPIRV/instructions/select-ptr-load.ll (+25)
  • (modified) llvm/test/CodeGen/SPIRV/instructions/select.ll (+3)
  • (renamed) llvm/test/CodeGen/SPIRV/select-builtin.ll (+2)
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 05e41e06248e35..cebe230d3e8ce3 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -653,7 +653,8 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
   auto MRI = MIRBuilder.getMRI();
   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
   if (Reg != ResVReg) {
-    LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
+    LLT RegLLTy =
+        LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
     MRI->setType(Reg, RegLLTy);
     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
   } else {
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
index aacfecc1e313f0..af98f2f8804593 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.cpp
@@ -247,8 +247,10 @@ void SPIRVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
 
 bool SPIRVInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
   if (MI.getOpcode() == SPIRV::GET_ID || MI.getOpcode() == SPIRV::GET_fID ||
-      MI.getOpcode() == SPIRV::GET_pID || MI.getOpcode() == SPIRV::GET_vfID ||
-      MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID) {
+      MI.getOpcode() == SPIRV::GET_pID32 ||
+      MI.getOpcode() == SPIRV::GET_pID64 || MI.getOpcode() == SPIRV::GET_vfID ||
+      MI.getOpcode() == SPIRV::GET_vID || MI.getOpcode() == SPIRV::GET_vpID32 ||
+      MI.getOpcode() == SPIRV::GET_vpID64) {
     auto &MRI = MI.getMF()->getRegInfo();
     MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
     MI.eraseFromParent();
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index a3f981457c8daa..151d0ec1fe5690 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -19,10 +19,12 @@ let isCodeGenOnly=1 in {
   def DECL_TYPE: Pseudo<(outs ANYID:$dst_id), (ins ANYID:$src_id, TYPE:$src_ty)>;
   def GET_ID: Pseudo<(outs ID:$dst_id), (ins ANYID:$src)>;
   def GET_fID: Pseudo<(outs fID:$dst_id), (ins ANYID:$src)>;
-  def GET_pID: Pseudo<(outs pID:$dst_id), (ins ANYID:$src)>;
+  def GET_pID32: Pseudo<(outs pID32:$dst_id), (ins ANYID:$src)>;
+  def GET_pID64: Pseudo<(outs pID64:$dst_id), (ins ANYID:$src)>;
   def GET_vID: Pseudo<(outs vID:$dst_id), (ins ANYID:$src)>;
   def GET_vfID: Pseudo<(outs vfID:$dst_id), (ins ANYID:$src)>;
-  def GET_vpID: Pseudo<(outs vpID:$dst_id), (ins ANYID:$src)>;
+  def GET_vpID32: Pseudo<(outs vpID32:$dst_id), (ins ANYID:$src)>;
+  def GET_vpID64: Pseudo<(outs vpID64:$dst_id), (ins ANYID:$src)>;
 }
 
 def SPVTypeBin : SDTypeProfile<1, 2, []>;
@@ -66,8 +68,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP =
     def SIVCond: TernOpTyped<name, opCode, vID, ID, node>;
   }
   if genP then {
-    def SPSCond: TernOpTyped<name, opCode, ID, pID, node>;
-    def SPVCond: TernOpTyped<name, opCode, vID, pID, node>;
+    def SPSCond32: TernOpTyped<name, opCode, ID, pID32, node>;
+    def SPVCond32: TernOpTyped<name, opCode, vID, pID32, node>;
+    def SPSCond64: TernOpTyped<name, opCode, ID, pID64, node>;
+    def SPVCond64: TernOpTyped<name, opCode, vID, pID64, node>;
   }
   if genV then {
     if genF then {
@@ -79,8 +83,10 @@ multiclass TernOpTypedGen<string name, bits<16> opCode, SDNode node, bit genP =
       def VIVCond: TernOpTyped<name, opCode, vID, vID, node>;
     }
     if genP then {
-      def VPSCond: TernOpTyped<name, opCode, ID, vpID, node>;
-      def VPVCond: TernOpTyped<name, opCode, vID, vpID, node>;
+      def VPSCond32: TernOpTyped<name, opCode, ID, vpID32, node>;
+      def VPVCond32: TernOpTyped<name, opCode, vID, vpID32, node>;
+      def VPSCond64: TernOpTyped<name, opCode, ID, vpID64, node>;
+      def VPVCond64: TernOpTyped<name, opCode, vID, vpID64, node>;
     }
   }
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index c1c0fc4b7dd489..afae239eb5136e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -284,14 +284,18 @@ bool SPIRVInstructionSelector::select(MachineInstr &I) {
   // If it's not a GMIR instruction, we've selected it already.
   if (!isPreISelGenericOpcode(Opcode)) {
     if (Opcode == SPIRV::ASSIGN_TYPE) { // These pseudos aren't needed any more.
-      auto *Def = MRI->getVRegDef(I.getOperand(1).getReg());
+      Register DstReg = I.getOperand(0).getReg();
+      Register SrcReg = I.getOperand(1).getReg();
+      auto *Def = MRI->getVRegDef(SrcReg);
       if (isTypeFoldingSupported(Def->getOpcode())) {
+        if (MRI->getType(DstReg).isPointer())
+          MRI->setType(DstReg, LLT::scalar(32));
         bool Res = selectImpl(I, *CoverageInfo);
         assert(Res || Def->getOpcode() == TargetOpcode::G_CONSTANT);
         if (Res)
           return Res;
       }
-      MRI->replaceRegWith(I.getOperand(1).getReg(), I.getOperand(0).getReg());
+      MRI->replaceRegWith(SrcReg, DstReg);
       I.removeFromParent();
       return true;
     } else if (I.getNumDefs() == 1) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
index f069a92ac68683..d652b5de608086 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp
@@ -55,8 +55,9 @@ extern void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
 
 static bool isMetaInstrGET(unsigned Opcode) {
   return Opcode == SPIRV::GET_ID || Opcode == SPIRV::GET_fID ||
-         Opcode == SPIRV::GET_pID || Opcode == SPIRV::GET_vID ||
-         Opcode == SPIRV::GET_vfID || Opcode == SPIRV::GET_vpID;
+         Opcode == SPIRV::GET_pID32 || Opcode == SPIRV::GET_pID64 ||
+         Opcode == SPIRV::GET_vID || Opcode == SPIRV::GET_vfID ||
+         Opcode == SPIRV::GET_vpID32 || Opcode == SPIRV::GET_vpID64;
 }
 
 static bool mayBeInserted(unsigned Opcode) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index 2c964595fc39e8..ed03154d4c8dd6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -223,11 +223,12 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
 }
 
 static std::pair<Register, unsigned>
-createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
+createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
                const SPIRVGlobalRegistry &GR) {
-  LLT NewT = LLT::scalar(32);
-  SPIRVType *SpvType = GR.getSPIRVTypeForVReg(ValReg);
+  if (!SpvType)
+    SpvType = GR.getSPIRVTypeForVReg(SrcReg);
   assert(SpvType && "VReg is expected to have SPIRV type");
+  LLT NewT = LLT::scalar(32);
   bool IsFloat = SpvType->getOpcode() == SPIRV::OpTypeFloat;
   bool IsVectorFloat =
       SpvType->getOpcode() == SPIRV::OpTypeVector &&
@@ -236,14 +237,38 @@ createNewIdReg(Register ValReg, unsigned Opcode, MachineRegisterInfo &MRI,
   IsFloat |= IsVectorFloat;
   auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
   auto DstClass = IsFloat ? &SPIRV::fIDRegClass : &SPIRV::IDRegClass;
-  if (MRI.getType(ValReg).isPointer()) {
-    NewT = LLT::pointer(0, 32);
-    GetIdOp = SPIRV::GET_pID;
-    DstClass = &SPIRV::pIDRegClass;
-  } else if (MRI.getType(ValReg).isVector()) {
+  if (MRI.getType(SrcReg).isPointer()) {
+    unsigned PtrSz = GR.getPointerSize();
+    NewT = LLT::pointer(0, PtrSz);
+    bool IsVec = MRI.getType(SrcReg).isVector();
+    if (IsVec)
+      NewT = LLT::fixed_vector(2, NewT);
+    if (PtrSz == 64) {
+      if (IsVec) {
+        GetIdOp = SPIRV::GET_vpID64;
+        DstClass = &SPIRV::vpID64RegClass;
+      } else {
+        GetIdOp = SPIRV::GET_pID64;
+        DstClass = &SPIRV::pID64RegClass;
+      }
+    } else {
+      if (IsVec) {
+        GetIdOp = SPIRV::GET_vpID32;
+        DstClass = &SPIRV::vpID32RegClass;
+      } else {
+        GetIdOp = SPIRV::GET_pID32;
+        DstClass = &SPIRV::pID32RegClass;
+      }
+    }
+  } else if (MRI.getType(SrcReg).isVector()) {
     NewT = LLT::fixed_vector(2, NewT);
-    GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
-    DstClass = IsFloat ? &SPIRV::vfIDRegClass : &SPIRV::vIDRegClass;
+    if (IsFloat) {
+      GetIdOp = SPIRV::GET_vfID;
+      DstClass = &SPIRV::vfIDRegClass;
+    } else {
+      GetIdOp = SPIRV::GET_vID;
+      DstClass = &SPIRV::vIDRegClass;
+    }
   }
   Register IdReg = MRI.createGenericVirtualRegister(NewT);
   MRI.setRegClass(IdReg, DstClass);
@@ -264,6 +289,7 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
   MIB.setInsertPt(*Def->getParent(),
                   (Def->getNextNode() ? Def->getNextNode()->getIterator()
                                       : Def->getParent()->end()));
+  SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
     MRI.setRegClass(NewReg, RC);
@@ -271,7 +297,6 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
     MRI.setRegClass(NewReg, &SPIRV::IDRegClass);
     MRI.setRegClass(Reg, &SPIRV::IDRegClass);
   }
-  SpirvTy = SpirvTy ? SpirvTy : GR->getOrCreateSPIRVType(Ty, MIB);
   GR->assignSPIRVTypeToVReg(SpirvTy, Reg, MIB.getMF());
   // This is to make it convenient for Legalizer to get the SPIRVType
   // when processing the actual MI (i.e. not pseudo one).
@@ -290,11 +315,11 @@ Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpirvTy,
 
 void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
                   MachineRegisterInfo &MRI, SPIRVGlobalRegistry *GR) {
-  unsigned Opc = MI.getOpcode();
   assert(MI.getNumDefs() > 0 && MRI.hasOneUse(MI.getOperand(0).getReg()));
   MachineInstr &AssignTypeInst =
       *(MRI.use_instr_begin(MI.getOperand(0).getReg()));
-  auto NewReg = createNewIdReg(MI.getOperand(0).getReg(), Opc, MRI, *GR).first;
+  auto NewReg =
+      createNewIdReg(nullptr, MI.getOperand(0).getReg(), MRI, *GR).first;
   AssignTypeInst.getOperand(1).setReg(NewReg);
   MI.getOperand(0).setReg(NewReg);
   MIB.setInsertPt(*MI.getParent(),
@@ -303,7 +328,7 @@ void processInstr(MachineInstr &MI, MachineIRBuilder &MIB,
   for (auto &Op : MI.operands()) {
     if (!Op.isReg() || Op.isDef())
       continue;
-    auto IdOpInfo = createNewIdReg(Op.getReg(), Opc, MRI, *GR);
+    auto IdOpInfo = createNewIdReg(nullptr, Op.getReg(), MRI, *GR);
     MIB.buildInstr(IdOpInfo.second).addDef(IdOpInfo.first).addUse(Op.getReg());
     Op.setReg(IdOpInfo.first);
   }
@@ -419,6 +444,7 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
         processInstr(MI, MIB, MRI, GR);
     }
   }
+
   for (MachineBasicBlock &MBB : MF) {
     for (MachineInstr &MI : MBB) {
       // We need to rewrite dst types for ASSIGN_TYPE instrs to be able
@@ -431,16 +457,18 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
       if (!isTypeFoldingSupported(Opcode))
         continue;
       Register DstReg = MI.getOperand(0).getReg();
-      if (MRI.getType(DstReg).isVector())
+      bool IsDstPtr = MRI.getType(DstReg).isPointer();
+      if (IsDstPtr || MRI.getType(DstReg).isVector())
         MRI.setRegClass(DstReg, &SPIRV::IDRegClass);
       // Don't need to reset type of register holding constant and used in
-      // G_ADDRSPACE_CAST, since it braaks legalizer.
+      // G_ADDRSPACE_CAST, since it breaks legalizer.
       if (Opcode == TargetOpcode::G_CONSTANT && MRI.hasOneUse(DstReg)) {
         MachineInstr &UseMI = *MRI.use_instr_begin(DstReg);
         if (UseMI.getOpcode() == TargetOpcode::G_ADDRSPACE_CAST)
           continue;
       }
-      MRI.setType(DstReg, LLT::scalar(32));
+      MRI.setType(DstReg, IsDstPtr ? LLT::pointer(0, GR->getPointerSize())
+                                   : LLT::scalar(32));
     }
   }
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
index 5983c9229cb3c2..ecd99f1840d7e0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBankInfo.cpp
@@ -27,23 +27,7 @@ using namespace llvm;
 const RegisterBank &
 SPIRVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
                                               LLT Ty) const {
-  switch (RC.getID()) {
-  case SPIRV::TYPERegClassID:
+  if (RC.getID() == SPIRV::TYPERegClassID)
     return SPIRV::TYPERegBank;
-  case SPIRV::pIDRegClassID:
-  case SPIRV::IDRegClassID:
-    return SPIRV::IDRegBank;
-  case SPIRV::fIDRegClassID:
-    return SPIRV::fIDRegBank;
-  case SPIRV::vIDRegClassID:
-    return SPIRV::vIDRegBank;
-  case SPIRV::vfIDRegClassID:
-    return SPIRV::vfIDRegBank;
-  case SPIRV::vpIDRegClassID:
-    return SPIRV::vpIDRegBank;
-  case SPIRV::ANYIDRegClassID:
-  case SPIRV::ANYRegClassID:
-    return SPIRV::IDRegBank;
-  }
-  llvm_unreachable("Unknown register class");
+  return SPIRV::IDRegBank;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
index c7f1e172f3d4f1..dea2ef402d3d97 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterBanks.td
@@ -8,9 +8,6 @@
 
 // Although RegisterBankSelection is disabled we need to distinct the banks
 // as InstructionSelector RegClass checking code relies on them
-def IDRegBank : RegisterBank<"IDBank", [ID]>;
-def fIDRegBank : RegisterBank<"fIDBank", [fID]>;
-def vIDRegBank : RegisterBank<"vIDBank", [vID]>;
-def vfIDRegBank : RegisterBank<"vfIDBank", [vfID]>;
-def vpIDRegBank : RegisterBank<"vpIDBank", [vpID]>;
+
 def TYPERegBank : RegisterBank<"TYPEBank", [TYPE]>;
+def IDRegBank : RegisterBank<"IDBank", [ID, fID, pID32, pID64, vID, vfID, vpID32, vpID64]>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
index 6d2bfb91a97f12..9231d22e8d8362 100644
--- a/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVRegisterInfo.td
@@ -11,39 +11,46 @@
 //===----------------------------------------------------------------------===//
 
 let Namespace = "SPIRV" in {
-  def p0 : PtrValueType <i32, 0>;
-
-  class P0Vec<ValueType scalar>
-      : PtrValueType <scalar, 0> {
-    let nElem = 2;
-    let ElementType = p0;
-    let isInteger = false;
-    let isFP = false;
-    let isVector = true;
+  // Pointer types for patterns with the GlobalISelEmitter
+  def p32 : PtrValueType <i32, 0>;
+  def p64 : PtrValueType <i64, 0>;
+
+  class VTPtrVec<int nelem, PtrValueType ptr>
+      : VTVec<nelem, ValueType<ptr.Size, ptr.Value>, ptr.Value> {
+    int isPointer = true;
   }
 
-  def v2p0 : P0Vec<i32>;
-  // All registers are for 32-bit identifiers, so have a single dummy register
+  def v2p32 : VTPtrVec<2, p32>;
+  def v2p64 : VTPtrVec<2, p64>;
 
-  // Class for registers that are the result of OpTypeXXX instructions
+  // Class for type registers
   def TYPE0 : Register<"TYPE0">;
   def TYPE : RegisterClass<"SPIRV", [i32], 32, (add TYPE0)>;
 
-  // Class for every other non-type ID
+  // Class for non-type registers
   def ID0 : Register<"ID0">;
-  def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
   def fID0 : Register<"fID0">;
-  def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
-  def pID0 : Register<"pID0">;
-  def pID : RegisterClass<"SPIRV", [p0], 32, (add pID0)>;
+  def pID320 : Register<"pID320">;
+  def pID640 : Register<"pID640">;
   def vID0 : Register<"vID0">;
-  def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
   def vfID0 : Register<"vfID0">;
+  def vpID320 : Register<"vpID320">;
+  def vpID640 : Register<"vpID640">;
+
+  def ID : RegisterClass<"SPIRV", [i32], 32, (add ID0)>;
+  def fID : RegisterClass<"SPIRV", [f32], 32, (add fID0)>;
+  def pID32 : RegisterClass<"SPIRV", [p32], 32, (add pID320)>;
+  def pID64 : RegisterClass<"SPIRV", [p64], 32, (add pID640)>;
+  def vID : RegisterClass<"SPIRV", [v2i32], 32, (add vID0)>;
   def vfID : RegisterClass<"SPIRV", [v2f32], 32, (add vfID0)>;
-  def vpID0 : Register<"vpID0">;
-  def vpID : RegisterClass<"SPIRV", [v2p0], 32, (add vpID0)>;
+  def vpID32 : RegisterClass<"SPIRV", [v2p32], 32, (add vpID320)>;
+  def vpID64 : RegisterClass<"SPIRV", [v2p64], 32, (add vpID640)>;
 
-  def ANYID : RegisterClass<"SPIRV", [i32, f32, p0, v2i32, v2f32], 32, (add ID, fID, pID, vID, vfID)>;
+  def ANYID : RegisterClass<
+      "SPIRV",
+      [i32, f32, p32, p64, v2i32, v2f32, v2p32, v2p64],
+      32,
+      (add ID0, fID0, pID320, pID640, vID0, vfID0, vpID320, vpID640)>;
 
   // A few instructions like OpName can take ids from both type and non-type
   // instructions, so we need a super-class to allow for both to count as valid
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
index afc75c616f023b..3828fe89e60aec 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-phi.ll
@@ -1,6 +1,10 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %}
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --translator-compatibility-mode %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
 ; CHECK-DAG: %[[Long:.*]] = OpTypeInt 32 0
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select-ptr-load.ll b/llvm/test/CodeGen/SPIRV/instructions/select-ptr-load.ll
new file mode 100644
index 00000000000000..0ff28952f8081a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/instructions/select-ptr-load.ll
@@ -0,0 +1,25 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[Float:.*]] = OpTypeFloat 32
+; CHECK-SPIRV-DAG: %[[FloatPtr:.*]] = OpTypePointer Function %[[Float]]
+; CHECK-SPIRV: OpInBoundsPtrAccessChain %[[FloatPtr]]
+; CHECK-SPIRV: OpInBoundsPtrAccessChain %[[FloatPtr]]
+; CHECK-SPIRV: OpSelect %[[FloatPtr]]
+; CHECK-SPIRV: OpLoad %[[Float]]
+
+%struct = type { [3 x float] }
+
+define spir_kernel void @bar(i1 %sw) {
+entry:
+  %var1 = alloca %struct
+  %var2 = alloca %struct
+  %elem1 = getelementptr inbounds [3 x float], ptr %var1, i64 0, i64 0
+  %elem2 = getelementptr inbounds [3 x float], ptr %var2, i64 0, i64 1
+  %elem = select i1 %sw, ptr %elem1, ptr %elem2
+  %res = load float, ptr %elem
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/instructions/select.ll b/llvm/test/CodeGen/SPIRV/instructions/select.ll
index c4176b17abb449..9234b97157d9d8 100644
--- a/llvm/test/CodeGen/SPIRV/instructions/select.ll
+++ b/llvm/test/CodeGen/SPIRV/instructions/select.ll
@@ -1,6 +1,9 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
 ; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
 ; CHECK-DAG:  OpName [[SCALARi32:%.+]] "select_i32"
 ; CHECK-DAG:  OpName [[SCALARPTR:%.+]] "select_ptr"
 ; CHECK-DAG:  OpName [[VEC2i32:%.+]] "select_i32v2"
diff --git a/llvm/test/CodeGen/SPIRV/select.ll b/llvm/test/CodeGen/SPIRV/select-builtin.ll
similarity index 67%
rename from llvm/test/CodeGen/SPIRV/select.ll
rename to llvm/test/CodeGen/SPIRV/select-builtin.ll
index b34e91be1dbcda..6717970d160fcf 100644
--- a/llvm/test/CodeGen/SPI...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants