Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 11 additions & 20 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,26 +450,17 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
auto *CTI = MT->getComponentTypeInterpretation();
if (!CTI)
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
std::string ComponentTypeName;
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
case internal::InternalJointMatrixCTI::TF32:
ComponentTypeName = "tf32";
break;
case internal::InternalJointMatrixCTI::Bfloat16:
ComponentTypeName = "bfloat16";
break;
case internal::InternalJointMatrixCTI::PackedInt2:
case internal::InternalJointMatrixCTI::PackedInt4:
// Do nothing just now
break;
default:
llvm_unreachable("Unexpected joint matrix component type");
}
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
ComponentTypeName, Params, !UseTPT));
return mapType(
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
transType(MT->getCompType()), Params));
const unsigned CTIValue =
static_cast<SPIRVConstant *>(CTI)->getZExtIntValue();
assert(CTIValue <= internal::InternalJointMatrixCTI::PackedInt4 &&
"Unknown matrix component type interpretation");
Params.push_back(CTIValue);
return mapType(
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
transType(MT->getCompType()), Params));
}
case OpTypeForwardPointer: {
SPIRVTypeForwardPointer *FP =
Expand Down
77 changes: 0 additions & 77 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,21 +628,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
transType(ET)));
}
} else {
// JointMatrixINTEL type is not necessarily an opaque type, it can be
// represented as a structure with pointer to a multidimensional array
// member.
if (ST && ST->hasName()) {
StringRef STName = ST->getName();
if (STName.startswith(kSPIRVTypeName::PrefixAndDelim)) {
SmallVector<std::string, 8> Postfixes;
auto TN = decodeSPIRVTypeName(STName, Postfixes);
if (TN == kSPIRVTypeName::JointMatrixINTEL) {
SPIRVType *TranslatedTy = transSPIRVJointMatrixINTELType(Postfixes);
PointeeTypeMap[TypeKey] = TranslatedTy;
return TranslatedTy;
}
}
}
SPIRVType *ElementType = transType(ET);
// ET, as a recursive type, may contain exactly the same pointer T, so it
// may happen that after translation of ET we already have translated T,
Expand Down Expand Up @@ -673,66 +658,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
return TranslatedTy;
}

// Representation in LLVM IR before the translator is a pointer to an opaque
// structure:
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_%scope%_%use%
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
// structure type. But in the near future, we will reuse Composite
// instructions to do, for example, matrix initialization directly on AMX
// register by OpCompositeConstruct. And we can't claim, that the Result type
// of OpCompositeConstruct instruction is always the joint matrix type, it's
// simply not true.
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
SmallVector<std::string, 8> Postfixes) {
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N))
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));

Type *ElemTy = nullptr;
StringRef Ty{Postfixes[0]};
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
.Case("char", 8)
.Case("short", 16)
.Case("int", 32)
.Case("long", 64)
.Default(0);
if (NumBits) {
ElemTy = IntegerType::get(M->getContext(), NumBits);
} else if (Ty == "half") {
ElemTy = Type::getHalfTy(M->getContext());
} else if (Ty == "float") {
ElemTy = Type::getFloatTy(M->getContext());
} else if (Ty == "double") {
ElemTy = Type::getDoubleTy(M->getContext());
} else if (Ty == "bfloat16") {
ElemTy = Type::getInt16Ty(M->getContext());
// TODO: add BF16 CTI when we do breaking change
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
// internal::InternalJointMatrixCTI::Bfloat16)));
// Args.push_back(CTI);
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
} else if (Ty == "tf32") {
ElemTy = Type::getFloatTy(M->getContext());
auto *CTI = transConstant(getUInt32(
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
Args.push_back(CTI);
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
} else {
llvm_unreachable("Unexpected type for matrix!");
}
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
}

SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
unsigned AddrSpace) {
std::pair<StringRef, unsigned> Key = {STName, AddrSpace};
Expand Down Expand Up @@ -789,8 +714,6 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
return SaveType(BM->addQueueType());
else if (TN == kSPIRVTypeName::PipeStorage)
return SaveType(BM->addPipeStorageType());
else if (TN == kSPIRVTypeName::JointMatrixINTEL)
return SaveType(transSPIRVJointMatrixINTELType(Postfixes));
else if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute) &&
TN == kSPIRVTypeName::BufferSurfaceINTEL) {
auto Access = getAccessQualifier(STName);
Expand Down
Loading