diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 0b9a363592..0771f24266 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -735,8 +735,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) { SPIRVType *TranslatedTy = nullptr; if (ET->isPointerTy() && BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) { - TranslatedTy = BM->addUntypedPointerKHRType( - SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc))); + TranslatedTy = BM->addPointerType( + SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc)), + nullptr); } else { ElementType = transType(ET); TranslatedTy = transPointerType(ElementType, AddrSpc); @@ -761,8 +762,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) { return transPointerType(ET, SPIRAS_Private); if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) && !(ET->isTypeArray() || ET->isTypeVector() || ET->isSPIRVOpaqueType())) { - TranslatedTy = BM->addUntypedPointerKHRType( - SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc))); + TranslatedTy = BM->addPointerType( + SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc)), + nullptr); } else { TranslatedTy = BM->addPointerType( SPIRSPIRVAddrSpaceMap::map(static_cast(AddrSpc)), ET); @@ -2347,10 +2349,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, } SPIRVType *VarTy = TranslatedTy; if (V->getType()->getPointerAddressSpace() == SPIRAS_Generic) { - // TODO: refactor addPointerType and addUntypedPointerKHRType in one - // method if possible. if (TranslatedTy->isTypeUntypedPointerKHR()) - VarTy = BM->addUntypedPointerKHRType(StorageClassFunction); + VarTy = BM->addPointerType(StorageClassFunction, nullptr); else VarTy = BM->addPointerType(StorageClassFunction, TranslatedTy->getPointerElementType()); @@ -2697,11 +2697,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB, SPIRVType *LLVMToSPIRVBase::mapType(Type *T, SPIRVType *BT) { assert(!T->isPointerTy() && "Pointer types cannot be stored in the type map"); auto EmplaceStatus = TypeMap.try_emplace(T, BT); - // TODO: Uncomment the assertion, once the type mapping issue is resolved - // assert(EmplaceStatus.second && "The type was already added to the map"); + assert(EmplaceStatus.second && "The type was already added to the map"); SPIRVDBG(dbgs() << "[mapType] " << *T << " => "; spvdbgs() << *BT << '\n'); - if (!EmplaceStatus.second) - return TypeMap[T]; return BT; } @@ -4302,8 +4299,8 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II, SPIRVType *IntegralTy = transType(II->getType()->getStructElementType(1)); // IntegralTy is the type of the result. We want to create a pointer to this // that we can pass to OpenCLLIB::modf to store the integral part. - SPIRVTypePointer *IntegralPtrTy = - BM->addPointerType(StorageClassFunction, IntegralTy); + SPIRVType *GenericPtrTy = BM->addPointerType(StorageClassFunction, IntegralTy); + auto *IntegralPtrTy = dyn_cast(GenericPtrTy); // We need to use the entry BB of the function calling llvm.modf.*, instead // of the current BB. For that, we'll find current BB's parent and get its // first BB, which is the entry BB of the function. @@ -4829,7 +4826,7 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II, auto *SrcTy = PtrOp->getType(); SPIRVType *DstTy = nullptr; if (SrcTy->isTypeUntypedPointerKHR()) - DstTy = BM->addUntypedPointerKHRType(StorageClassFunction); + DstTy = BM->addPointerType(StorageClassFunction, nullptr); else DstTy = BM->addPointerType(StorageClassFunction, SrcTy->getPointerElementType()); diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/lib/SPIRV/libSPIRV/SPIRVModule.cpp index 966b4b621a..eda598ee18 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -259,9 +259,7 @@ class SPIRVModuleImpl : public SPIRVModule { const std::vector &) override; SPIRVTypeInt *addIntegerType(unsigned BitWidth) override; SPIRVTypeOpaque *addOpaqueType(const std::string &) override; - SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *) override; - SPIRVTypeUntypedPointerKHR * - addUntypedPointerKHRType(SPIRVStorageClassKind) override; + SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) override; SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &) override; SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &, @@ -1023,29 +1021,30 @@ SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth, return addType(Ty); } -SPIRVTypePointer * -SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass, - SPIRVType *ElementType) { +SPIRVType *SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass, + SPIRVType *ElementType = nullptr) { + if (ElementType == nullptr) { + // Untyped pointer + auto Loc = UntypedPtrTyMap.find(StorageClass); + if (Loc != UntypedPtrTyMap.end()) + return Loc->second; + + auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass); + UntypedPtrTyMap[StorageClass] = Ty; + return addType(Ty); + } + + // Typed pointer auto Desc = std::make_pair(StorageClass, ElementType); auto Loc = PointerTypeMap.find(Desc); if (Loc != PointerTypeMap.end()) return Loc->second; + auto *Ty = new SPIRVTypePointer(this, getId(), StorageClass, ElementType); PointerTypeMap[Desc] = Ty; return addType(Ty); } -SPIRVTypeUntypedPointerKHR * -SPIRVModuleImpl::addUntypedPointerKHRType(SPIRVStorageClassKind StorageClass) { - auto Loc = UntypedPtrTyMap.find(StorageClass); - if (Loc != UntypedPtrTyMap.end()) - return Loc->second; - - auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass); - UntypedPtrTyMap[StorageClass] = Ty; - return addType(Ty); -} - SPIRVTypeFunction *SPIRVModuleImpl::addFunctionType( SPIRVType *ReturnType, const std::vector &ParameterTypes) { return addType( diff --git a/lib/SPIRV/libSPIRV/SPIRVModule.h b/lib/SPIRV/libSPIRV/SPIRVModule.h index a1cd44b36f..a0565c9019 100644 --- a/lib/SPIRV/libSPIRV/SPIRVModule.h +++ b/lib/SPIRV/libSPIRV/SPIRVModule.h @@ -257,10 +257,7 @@ class SPIRVModule { virtual SPIRVTypeSampledImage *addSampledImageType(SPIRVTypeImage *T) = 0; virtual SPIRVTypeInt *addIntegerType(unsigned) = 0; virtual SPIRVTypeOpaque *addOpaqueType(const std::string &) = 0; - virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, - SPIRVType *) = 0; - virtual SPIRVTypeUntypedPointerKHR * - addUntypedPointerKHRType(SPIRVStorageClassKind) = 0; + virtual SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) = 0; virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &) = 0; virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0; virtual void closeStructType(SPIRVTypeStruct *, bool) = 0; diff --git a/lib/SPIRV/libSPIRV/SPIRVType.h b/lib/SPIRV/libSPIRV/SPIRVType.h index d19f092ec1..420453b03b 100644 --- a/lib/SPIRV/libSPIRV/SPIRVType.h +++ b/lib/SPIRV/libSPIRV/SPIRVType.h @@ -323,6 +323,9 @@ class SPIRVTypePointer : public SPIRVTypePointerBase { std::vector getNonLiteralOperands() const override { return std::vector(1, getEntry(ElemTypeId)); } + static bool classof(const SPIRVEntry *E) { + return E->getOpCode() == OpTypePointer; + } protected: _SPIRV_DEF_ENCDEC3(Id, ElemStorageClass, ElemTypeId)