Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
SPIRVType *TranslatedTy = nullptr;
if (ET->isPointerTy() &&
BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers)) {
TranslatedTy = BM->addUntypedPointerKHRType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
TranslatedTy = BM->addPointerType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)),
nullptr);
} else {
ElementType = transType(ET);
TranslatedTy = transPointerType(ElementType, AddrSpc);
Expand All @@ -761,8 +762,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
return transPointerType(ET, SPIRAS_Private);
if (BM->isAllowedToUseExtension(ExtensionID::SPV_KHR_untyped_pointers) &&
!(ET->isTypeArray() || ET->isTypeVector() || ET->isSPIRVOpaqueType())) {
TranslatedTy = BM->addUntypedPointerKHRType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)));
TranslatedTy = BM->addPointerType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)),
nullptr);
} else {
TranslatedTy = BM->addPointerType(
SPIRSPIRVAddrSpaceMap::map(static_cast<SPIRAddressSpace>(AddrSpc)), ET);
Expand Down Expand Up @@ -2347,10 +2349,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
}
SPIRVType *VarTy = TranslatedTy;
if (V->getType()->getPointerAddressSpace() == SPIRAS_Generic) {
// TODO: refactor addPointerType and addUntypedPointerKHRType in one
// method if possible.
if (TranslatedTy->isTypeUntypedPointerKHR())
VarTy = BM->addUntypedPointerKHRType(StorageClassFunction);
VarTy = BM->addPointerType(StorageClassFunction, nullptr);
else
VarTy = BM->addPointerType(StorageClassFunction,
TranslatedTy->getPointerElementType());
Expand Down Expand Up @@ -2697,11 +2697,8 @@ LLVMToSPIRVBase::transValueWithoutDecoration(Value *V, SPIRVBasicBlock *BB,
SPIRVType *LLVMToSPIRVBase::mapType(Type *T, SPIRVType *BT) {
assert(!T->isPointerTy() && "Pointer types cannot be stored in the type map");
auto EmplaceStatus = TypeMap.try_emplace(T, BT);
// TODO: Uncomment the assertion, once the type mapping issue is resolved
// assert(EmplaceStatus.second && "The type was already added to the map");
assert(EmplaceStatus.second && "The type was already added to the map");
SPIRVDBG(dbgs() << "[mapType] " << *T << " => "; spvdbgs() << *BT << '\n');
if (!EmplaceStatus.second)
return TypeMap[T];
return BT;
}

Expand Down Expand Up @@ -4302,8 +4299,8 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
SPIRVType *IntegralTy = transType(II->getType()->getStructElementType(1));
// IntegralTy is the type of the result. We want to create a pointer to this
// that we can pass to OpenCLLIB::modf to store the integral part.
SPIRVTypePointer *IntegralPtrTy =
BM->addPointerType(StorageClassFunction, IntegralTy);
SPIRVType *GenericPtrTy = BM->addPointerType(StorageClassFunction, IntegralTy);
auto *IntegralPtrTy = dyn_cast<SPIRVTypePointer>(GenericPtrTy);
// We need to use the entry BB of the function calling llvm.modf.*, instead
// of the current BB. For that, we'll find current BB's parent and get its
// first BB, which is the entry BB of the function.
Expand Down Expand Up @@ -4829,7 +4826,7 @@ SPIRVValue *LLVMToSPIRVBase::transIntrinsicInst(IntrinsicInst *II,
auto *SrcTy = PtrOp->getType();
SPIRVType *DstTy = nullptr;
if (SrcTy->isTypeUntypedPointerKHR())
DstTy = BM->addUntypedPointerKHRType(StorageClassFunction);
DstTy = BM->addPointerType(StorageClassFunction, nullptr);
else
DstTy = BM->addPointerType(StorageClassFunction,
SrcTy->getPointerElementType());
Expand Down
33 changes: 16 additions & 17 deletions lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,7 @@ class SPIRVModuleImpl : public SPIRVModule {
const std::vector<SPIRVType *> &) override;
SPIRVTypeInt *addIntegerType(unsigned BitWidth) override;
SPIRVTypeOpaque *addOpaqueType(const std::string &) override;
SPIRVTypePointer *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
SPIRVTypeUntypedPointerKHR *
addUntypedPointerKHRType(SPIRVStorageClassKind) override;
SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) override;
SPIRVTypeImage *addImageType(SPIRVType *,
const SPIRVTypeImageDescriptor &) override;
SPIRVTypeImage *addImageType(SPIRVType *, const SPIRVTypeImageDescriptor &,
Expand Down Expand Up @@ -1023,29 +1021,30 @@ SPIRVTypeFloat *SPIRVModuleImpl::addFloatType(unsigned BitWidth,
return addType(Ty);
}

SPIRVTypePointer *
SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
SPIRVType *ElementType) {
SPIRVType *SPIRVModuleImpl::addPointerType(SPIRVStorageClassKind StorageClass,
SPIRVType *ElementType = nullptr) {
if (ElementType == nullptr) {
// Untyped pointer
auto Loc = UntypedPtrTyMap.find(StorageClass);
if (Loc != UntypedPtrTyMap.end())
return Loc->second;

auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
UntypedPtrTyMap[StorageClass] = Ty;
return addType(Ty);
}

// Typed pointer
auto Desc = std::make_pair(StorageClass, ElementType);
auto Loc = PointerTypeMap.find(Desc);
if (Loc != PointerTypeMap.end())
return Loc->second;

auto *Ty = new SPIRVTypePointer(this, getId(), StorageClass, ElementType);
PointerTypeMap[Desc] = Ty;
return addType(Ty);
}

SPIRVTypeUntypedPointerKHR *
SPIRVModuleImpl::addUntypedPointerKHRType(SPIRVStorageClassKind StorageClass) {
auto Loc = UntypedPtrTyMap.find(StorageClass);
if (Loc != UntypedPtrTyMap.end())
return Loc->second;

auto *Ty = new SPIRVTypeUntypedPointerKHR(this, getId(), StorageClass);
UntypedPtrTyMap[StorageClass] = Ty;
return addType(Ty);
}

SPIRVTypeFunction *SPIRVModuleImpl::addFunctionType(
SPIRVType *ReturnType, const std::vector<SPIRVType *> &ParameterTypes) {
return addType(
Expand Down
5 changes: 1 addition & 4 deletions lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,7 @@ class SPIRVModule {
virtual SPIRVTypeSampledImage *addSampledImageType(SPIRVTypeImage *T) = 0;
virtual SPIRVTypeInt *addIntegerType(unsigned) = 0;
virtual SPIRVTypeOpaque *addOpaqueType(const std::string &) = 0;
virtual SPIRVTypePointer *addPointerType(SPIRVStorageClassKind,
SPIRVType *) = 0;
virtual SPIRVTypeUntypedPointerKHR *
addUntypedPointerKHRType(SPIRVStorageClassKind) = 0;
virtual SPIRVType *addPointerType(SPIRVStorageClassKind, SPIRVType *) = 0;
virtual SPIRVTypeStruct *openStructType(unsigned, const std::string &) = 0;
virtual SPIRVEntry *addTypeStructContinuedINTEL(unsigned NumMembers) = 0;
virtual void closeStructType(SPIRVTypeStruct *, bool) = 0;
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ class SPIRVTypePointer : public SPIRVTypePointerBase<OpTypePointer, 4> {
std::vector<SPIRVEntry *> getNonLiteralOperands() const override {
return std::vector<SPIRVEntry *>(1, getEntry(ElemTypeId));
}
static bool classof(const SPIRVEntry *E) {
return E->getOpCode() == OpTypePointer;
}

protected:
_SPIRV_DEF_ENCDEC3(Id, ElemStorageClass, ElemTypeId)
Expand Down
Loading