Skip to content

Commit edba4b3

Browse files
MrSidimsFreddyLeaf
authored andcommitted
Add ComponentTypeInterpretation for joint matrix type (intel#1835)
It specifies how to interpret 'Component Type' when components of a joint matrix are storages for values of different types, for example float for TF32, unsigned short for bfloat16. At this point only tf32 type interpretation is added during SPIR-V generation. Adding it to bf16 is a breaking change and requires adaptation across drivers. Spec update: intel#8175 Signed-off-by: Sidorov, Dmitry [email protected] Original commit: KhronosGroup/SPIRV-LLVM-Translator@b7c5218
1 parent 9cc334c commit edba4b3

File tree

7 files changed

+292
-24
lines changed

7 files changed

+292
-24
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,28 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
445445
(unsigned)S};
446446
if (auto *Use = MT->getUse())
447447
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
448+
auto *CTI = MT->getComponentTypeInterpretation();
449+
if (!CTI)
450+
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
451+
transTypeToOCLTypeName(MT->getCompType()),
452+
Params, !UseTPT));
453+
std::string ComponentTypeName;
454+
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
455+
case internal::InternalJointMatrixCTI::TF32:
456+
ComponentTypeName = "tf32";
457+
break;
458+
case internal::InternalJointMatrixCTI::Bfloat16:
459+
ComponentTypeName = "bfloat16";
460+
break;
461+
case internal::InternalJointMatrixCTI::PackedInt2:
462+
case internal::InternalJointMatrixCTI::PackedInt4:
463+
// Do nothing just now
464+
break;
465+
default:
466+
llvm_unreachable("Unexpected joint matrix component type");
467+
}
448468
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
449-
transTypeToOCLTypeName(MT->getCompType()),
450-
Params, !UseTPT));
469+
ComponentTypeName, Params, !UseTPT));
451470
}
452471
case OpTypeForwardPointer: {
453472
SPIRVTypeForwardPointer *FP =

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
614614

615615
// Representation in LLVM IR before the translator is a pointer to an opaque
616616
// structure:
617-
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
617+
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_%scope%_%use%
618618
// Here we check the structure name yet again. Another option would be to
619619
// check SPIR-V friendly function calls (by their name) and obtain return
620620
// or their parameter types, assuming, that the appropriate types are Matrix
@@ -625,6 +625,18 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
625625
// simply not true.
626626
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
627627
SmallVector<std::string, 8> Postfixes) {
628+
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
629+
unsigned long long N = 0;
630+
if (consumeUnsignedInteger(Postfix, 10, N))
631+
BM->getErrorLog().checkError(
632+
false, SPIRVEC_InvalidLlvmModule,
633+
"TypeJointMatrixINTEL expects integer parameters");
634+
return getUInt32(M, N);
635+
};
636+
std::vector<SPIRVValue *> Args;
637+
for (size_t I = 1; I != Postfixes.size(); ++I)
638+
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
639+
628640
Type *ElemTy = nullptr;
629641
StringRef Ty{Postfixes[0]};
630642
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
@@ -633,32 +645,30 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
633645
.Case("int", 32)
634646
.Case("long", 64)
635647
.Default(0);
636-
if (NumBits)
648+
if (NumBits) {
637649
ElemTy = IntegerType::get(M->getContext(), NumBits);
638-
else if (Ty == "half")
650+
} else if (Ty == "half") {
639651
ElemTy = Type::getHalfTy(M->getContext());
640-
else if (Ty == "float")
652+
} else if (Ty == "float") {
641653
ElemTy = Type::getFloatTy(M->getContext());
642-
else if (Ty == "double")
654+
} else if (Ty == "double") {
643655
ElemTy = Type::getDoubleTy(M->getContext());
644-
else if (Ty == "bfloat16")
656+
} else if (Ty == "bfloat16") {
645657
ElemTy = Type::getInt16Ty(M->getContext());
646-
else
658+
// TODO: add BF16 CTI when we do breaking change
659+
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
660+
// internal::InternalJointMatrixCTI::Bfloat16)));
661+
// Args.push_back(CTI);
662+
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
663+
} else if (Ty == "tf32") {
664+
ElemTy = Type::getFloatTy(M->getContext());
665+
auto *CTI = transConstant(getUInt32(
666+
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
667+
Args.push_back(CTI);
668+
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
669+
} else {
647670
llvm_unreachable("Unexpected type for matrix!");
648-
649-
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
650-
unsigned long long N = 0;
651-
if (consumeUnsignedInteger(Postfix, 10, N)) {
652-
BM->getErrorLog().checkError(
653-
false, SPIRVEC_InvalidLlvmModule,
654-
"TypeJointMatrixINTEL expects integer parameters");
655-
return 0;
656-
}
657-
return getUInt32(M, N);
658-
};
659-
std::vector<SPIRVValue *> Args;
660-
for (size_t I = 1; I != Postfixes.size(); ++I)
661-
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
671+
}
662672
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
663673
}
664674

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
205205
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
206206
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
207207
{internal::CapabilityJointMatrixINTEL});
208+
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
209+
{internal::CapabilityJointMatrixINTEL});
210+
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
211+
{internal::CapabilityJointMatrixINTEL});
212+
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
213+
{internal::CapabilityJointMatrixINTEL});
214+
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
215+
{internal::CapabilityJointMatrixINTEL});
208216
}
209217

210218
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,14 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
649649
"TensorFloat32ConversionINTEL");
650650
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
651651
"JointMatrixWIInstructionsINTEL");
652+
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
653+
"JointMatrixTF32ComponentTypeINTEL");
654+
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
655+
"JointMatrixBF16ComponentTypeINTEL");
656+
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
657+
"JointMatrixPackedInt2ComponentTypeINTEL");
658+
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
659+
"JointMatrixPackedInt4ComponentTypeINTEL");
652660
}
653661
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
654662

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,9 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10881088
SPIRVValue *getLayout() const { return Args[2]; }
10891089
SPIRVValue *getScope() const { return Args[3]; }
10901090
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
1091+
SPIRVValue *getComponentTypeInterpretation() const {
1092+
return Args.size() > 5 ? Args[5] : nullptr;
1093+
}
10911094
};
10921095

10931096
} // namespace SPIRV

llvm-spirv/lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ enum InternalCapability {
9999
ICapabilityComplexFloatMulDivINTEL = 6414,
100100
ICapabilityTensorFloat32ConversionINTEL = 6425,
101101
ICapabilityMaskedGatherScatterINTEL = 6427,
102-
ICapabilityJointMatrixWIInstructionsINTEL = 6435
102+
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
103+
ICapabilityJointMatrixTF32ComponentTypeINTEL = 6436,
104+
ICapabilityJointMatrixBF16ComponentTypeINTEL = 6437,
105+
ICapabilityJointMatrixPackedInt2ComponentTypeINTEL = 6438,
106+
ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439
103107
};
104108

105109
enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
@@ -120,6 +124,14 @@ enum InternalJointMatrixLayout {
120124

121125
enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
122126

127+
enum InternalJointMatrixCTI {
128+
None = 0,
129+
TF32 = 1,
130+
Bfloat16 = 2,
131+
PackedInt2 = 3,
132+
PackedInt4 = 4
133+
};
134+
123135
enum InternalBuiltIn {
124136
IBuiltInSubDeviceIDINTEL = 6135,
125137
IBuiltInGlobalHWThreadIDINTEL = 6136,
@@ -128,6 +140,10 @@ enum InternalBuiltIn {
128140
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
129141
_SPIRV_OP(Capability, JointMatrixINTEL)
130142
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
143+
_SPIRV_OP(Capability, JointMatrixTF32ComponentTypeINTEL)
144+
_SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
145+
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
146+
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
131147
_SPIRV_OP(Op, TypeJointMatrixINTEL)
132148
_SPIRV_OP(Op, JointMatrixLoadINTEL)
133149
_SPIRV_OP(Op, JointMatrixStoreINTEL)

0 commit comments

Comments
 (0)