Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,28 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
(unsigned)S};
if (auto *Use = MT->getUse())
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
auto *CTI = MT->getComponentTypeInterpretation();
if (!CTI)
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
std::string ComponentTypeName;
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
case internal::InternalJointMatrixCTI::TF32:
ComponentTypeName = "tf32";
break;
case internal::InternalJointMatrixCTI::Bfloat16:
ComponentTypeName = "bfloat16";
break;
case internal::InternalJointMatrixCTI::PackedInt2:
case internal::InternalJointMatrixCTI::PackedInt4:
// Do nothing just now
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: add "TODO" mark here

break;
default:
llvm_unreachable("Unexpected joint matrix component type");
}
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
ComponentTypeName, Params, !UseTPT));
}
case OpTypeForwardPointer: {
SPIRVTypeForwardPointer *FP =
Expand Down
52 changes: 32 additions & 20 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,20 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
// simply not true.
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
SmallVector<std::string, 8> Postfixes) {
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N)) {
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return 0;
}
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));

Type *ElemTy = nullptr;
StringRef Ty{Postfixes[0]};
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
Expand All @@ -629,32 +643,30 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
.Case("int", 32)
.Case("long", 64)
.Default(0);
if (NumBits)
if (NumBits) {
ElemTy = IntegerType::get(M->getContext(), NumBits);
else if (Ty == "half")
} else if (Ty == "half") {
ElemTy = Type::getHalfTy(M->getContext());
else if (Ty == "float")
} else if (Ty == "float") {
ElemTy = Type::getFloatTy(M->getContext());
else if (Ty == "double")
} else if (Ty == "double") {
ElemTy = Type::getDoubleTy(M->getContext());
else if (Ty == "bfloat16")
} else if (Ty == "bfloat16") {
ElemTy = Type::getInt16Ty(M->getContext());
else
// TODO: add BF16 CTI when we do breaking change
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
// internal::InternalJointMatrixCTI::Bfloat16)));
// Args.push_back(CTI);
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
} else if (Ty == "tf32") {
ElemTy = Type::getFloatTy(M->getContext());
auto *CTI = transConstant(getUInt32(
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
Args.push_back(CTI);
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
} else {
llvm_unreachable("Unexpected type for matrix!");

auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N)) {
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return 0;
}
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
}
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
}

Expand Down
8 changes: 8 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,14 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
}

template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
Expand Down
8 changes: 8 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,14 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
"TensorFloat32ConversionINTEL");
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
"JointMatrixWIInstructionsINTEL");
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
"JointMatrixTF32ComponentTypeINTEL");
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
"JointMatrixBF16ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
"JointMatrixPackedInt2ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
"JointMatrixPackedInt4ComponentTypeINTEL");
}
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,9 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
SPIRVValue *getLayout() const { return Args[2]; }
SPIRVValue *getScope() const { return Args[3]; }
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
SPIRVValue *getComponentTypeInterpretation() const {
return Args.size() > 5 ? Args[5] : nullptr;
}
};

} // namespace SPIRV
Expand Down
17 changes: 16 additions & 1 deletion lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ enum InternalCapability {
ICapabilityComplexFloatMulDivINTEL = 6414,
ICapabilityTensorFloat32ConversionINTEL = 6425,
ICapabilityMaskedGatherScatterINTEL = 6427,
ICapabilityJointMatrixWIInstructionsINTEL = 6435
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
ICapabilityJointMatrixTF32ComponentTypeINTEL = 6436,
ICapabilityJointMatrixBF16ComponentTypeINTEL = 6437,
ICapabilityJointMatrixPackedInt2ComponentTypeINTEL = 6438,
ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439
};

enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
Expand All @@ -99,6 +103,13 @@ enum InternalJointMatrixLayout {

enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

enum InternalJointMatrixCTI {
TF32 = 0,
Bfloat16 = 1,
PackedInt2 = 2,
PackedInt4 = 3
};

enum InternalBuiltIn {
IBuiltInSubDeviceIDINTEL = 6135,
IBuiltInGlobalHWThreadIDINTEL = 6136,
Expand All @@ -107,6 +118,10 @@ enum InternalBuiltIn {
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
_SPIRV_OP(Capability, JointMatrixINTEL)
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
_SPIRV_OP(Capability, JointMatrixTF32ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
_SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
Expand Down
Loading