Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/SPIRV/OCLUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
case OpTypeSampler:
return SPIRV_SAMPLER_T_ADDR_SPACE;
case internal::OpTypeJointMatrixINTEL:
case internal::OpTypeJointMatrixINTELv2:
case OpTypeCooperativeMatrixKHR:
return SPIRAS_Global;
default:
Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ const static char PipeStorage[] = "PipeStorage";
const static char ConstantPipeStorage[] = "ConstantPipeStorage";
const static char VmeImageINTEL[] = "VmeImageINTEL";
const static char JointMatrixINTEL[] = "JointMatrixINTEL";
const static char BufferSurfaceINTEL[] = "BufferSurfaceINTEL";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sarnex @vmustya please correct me if I'm wrong - this is being deprecated, right?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sarnex @vmustya any updates on this?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, yes as far as I know it's being deprecated but @vmustya should confirm.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, the BufferSurfaceINTEL stateful raw buffers are still used for some workloads.

const static char CooperativeMatrixKHR[] = "CooperativeMatrixKHR";
} // namespace kSPIRVTypeName

Expand Down Expand Up @@ -976,6 +977,7 @@ template <> inline void SPIRVMap<std::string, Op, SPIRVOpaqueType>::init() {
_SPIRV_OP(AvcRefResultINTEL)
_SPIRV_OP(AvcSicResultINTEL)
_SPIRV_OP(VmeImageINTEL)
_SPIRV_OP(BufferSurfaceINTEL)
_SPIRV_OP(CooperativeMatrixKHR)
#undef _SPIRV_OP
add("JointMatrixINTEL", internal::OpTypeJointMatrixINTEL);
Expand Down
25 changes: 18 additions & 7 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,15 +472,26 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
SmallVector<unsigned, 5> Params = {(unsigned)R, (unsigned)C, (unsigned)L,
(unsigned)S};
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
if (auto *Layout = MT->getLayout())
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
Params.push_back(
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
if (auto *Use = MT->getUse())
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
transTypeToOCLTypeName(MT->getCompType()),
Params, !UseTPT));
auto *CTI = MT->getComponentTypeInterpretation();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This patch contradicts current SPV_INTEL_joint_matrix specification, where type interpretation is a part of MulAdd. Note, IGC also expects type interpretation be a part of MulAdd and not a part of the type. Feel free to IM me to discuss this.

Copy link
Contributor

@MrSidims MrSidims May 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mateuszchudyk please follow up with @vmaksimo about this comment (I'm OOO for the next 2 weeks), she will point out on spec changes and IGC patch, that adds type interpretation support.
Basically I'm fine to merge it as is, but please make sure, that we know, what we are doing with matrix special types. Please also note, that there is no such thing as Int4 interpretation (it is now a proper TypeInt 4 - see SPV_INTEL_int4 (this is to be backported by us soon)) and Int2 interpretation (and there won't be any counterpart).

if (!CTI)
return mapType(
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
transType(MT->getCompType()), Params));
const unsigned CTIValue =
static_cast<SPIRVConstant *>(CTI)->getZExtIntValue();
assert(CTIValue <= internal::InternalJointMatrixCTI::PackedInt4 &&
"Unknown matrix component type interpretation");
Params.push_back(CTIValue);
return mapType(
T, llvm::TargetExtType::get(*Context, "spirv.JointMatrixINTEL",
transType(MT->getCompType()), Params));
}
case OpTypeCooperativeMatrixKHR: {
auto *MT = static_cast<SPIRVTypeCooperativeMatrixKHR *>(T);
Expand Down
69 changes: 1 addition & 68 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,21 +649,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(Type *ET, unsigned AddrSpc) {
transType(ET)));
}
} else {
// JointMatrixINTEL type is not necessarily an opaque type, it can be
// represented as a structure with pointer to a multidimensional array
// member.
if (ST && ST->hasName()) {
StringRef STName = ST->getName();
if (STName.startswith(kSPIRVTypeName::PrefixAndDelim)) {
SmallVector<std::string, 8> Postfixes;
auto TN = decodeSPIRVTypeName(STName, Postfixes);
if (TN == kSPIRVTypeName::JointMatrixINTEL) {
SPIRVType *TranslatedTy = transSPIRVJointMatrixINTELType(Postfixes);
PointeeTypeMap[TypeKey] = TranslatedTy;
return TranslatedTy;
}
}
}
SPIRVType *ElementType = transType(ET);
// ET, as a recursive type, may contain exactly the same pointer T, so it
// may happen that after translation of ET we already have translated T,
Expand Down Expand Up @@ -698,56 +683,6 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
return TranslatedTy;
}

// Representation in LLVM IR before the translator is a pointer to an opaque
// structure:
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
// structure type. But in the near future, we will reuse Composite
// instructions to do, for example, matrix initialization directly on AMX
// register by OpCompositeConstruct. And we can't claim, that the Result type
// of OpCompositeConstruct instruction is always the joint matrix type, it's
// simply not true.
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
SmallVector<std::string, 8> Postfixes) {
Type *ElemTy = nullptr;
StringRef Ty{Postfixes[0]};
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
.Case("char", 8)
.Case("short", 16)
.Case("int", 32)
.Case("long", 64)
.Default(0);
if (NumBits)
ElemTy = IntegerType::get(M->getContext(), NumBits);
else if (Ty == "half")
ElemTy = Type::getHalfTy(M->getContext());
else if (Ty == "float")
ElemTy = Type::getFloatTy(M->getContext());
else if (Ty == "double")
ElemTy = Type::getDoubleTy(M->getContext());
else if (Ty == "bfloat16")
ElemTy = Type::getInt16Ty(M->getContext());
else
llvm_unreachable("Unexpected type for matrix!");

auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
unsigned long long N = 0;
if (consumeUnsignedInteger(Postfix, 10, N)) {
BM->getErrorLog().checkError(
false, SPIRVEC_InvalidLlvmModule,
"TypeJointMatrixINTEL expects integer parameters");
return 0;
}
return getUInt32(M, N);
};
std::vector<SPIRVValue *> Args;
for (size_t I = 1; I != Postfixes.size(); ++I)
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
}

SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
unsigned AddrSpace) {
std::pair<StringRef, unsigned> Key = {STName, AddrSpace};
Expand Down Expand Up @@ -804,9 +739,7 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVOpaqueType(StringRef STName,
return SaveType(BM->addQueueType());
else if (TN == kSPIRVTypeName::PipeStorage)
return SaveType(BM->addPipeStorageType());
else if (TN == kSPIRVTypeName::JointMatrixINTEL) {
return SaveType(transSPIRVJointMatrixINTELType(Postfixes));
} else
else
return SaveType(
BM->addOpaqueGenericType(SPIRVOpaqueTypeOpCodeMap::map(TN)));
}
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
std::end(Table));

// TODO: To remove this when we make a switch to new version
if (OpCode == internal::OpTypeJointMatrixINTELv2)
OpCode = internal::OpTypeJointMatrixINTEL;

OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode);
if (Loc != OpToFactoryMap.end())
return Loc->second();
Expand Down
10 changes: 10 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,20 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
{internal::CapabilityJointMatrixINTEL});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixPrefetchINTEL,
{CapabilityCooperativeMatrixKHR});
ADD_VEC_INIT(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
{CapabilityCooperativeMatrixKHR});
}

template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {
Expand Down
84 changes: 77 additions & 7 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1991,6 +1991,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
case OpTypeArray:
case OpTypeStruct:
case internal::OpTypeJointMatrixINTEL:
case internal::OpTypeJointMatrixINTELv2:
case OpTypeCooperativeMatrixKHR:
break;
default:
Expand Down Expand Up @@ -3406,10 +3407,17 @@ template <Op OC>
class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
return getVec(internal::CapabilityBfloat16ConversionINTEL,
internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
return getVec(internal::CapabilityBfloat16ConversionINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
return ExtensionID::SPV_INTEL_bfloat16_conversion;
}

Expand Down Expand Up @@ -3438,8 +3446,25 @@ class SPIRVBfloat16ConversionINTELInstBase : public SPIRVUnaryInst<OC> {
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
auto *Module = this->getModule();
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();

// Cooperative matrix type is allowed as input/output of the instruction
// if SPV_INTEL_joint_matrix is enabled
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
SPVErrLog.checkError(
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
SPIRVEC_InvalidInstruction,
InstName + "\nCan be used with "
"cooperative matrices only when SPV_INTEL_joint_matrix is "
"enabled\n");
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
"Input must also be a cooperative matrix");
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
->getCompType();
InCompTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
}
if (OC == internal::OpConvertFToBF16INTEL) {
SPVErrLog.checkError(
ResCompTy->isTypeInt(16), SPIRVEC_InvalidInstruction,
Expand Down Expand Up @@ -3492,10 +3517,10 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
SPIRV##x##INTEL;
_SPIRV_OP(JointMatrixLoad, true, 6, true)
_SPIRV_OP(JointMatrixStore, false, 5, true)
_SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
_SPIRV_OP(JointMatrixMad, true, 6, true)
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
// TODO: move to SPIRVJointMatrixINTELWorkItemInst
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP
Expand Down Expand Up @@ -3529,7 +3554,27 @@ class SPIRVCooperativeMatrixPrefetchINTELInstBase
typedef SPIRVInstTemplate<SPIRVCooperativeMatrixPrefetchINTELInstBase, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(CooperativeMatrixPrefetch, false, 8, true, 5)
_SPIRV_OP(CooperativeMatrixPrefetch, false, 6, true, 3)
#undef _SPIRV_OP

class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase
: public SPIRVInstTemplateBase {
protected:
std::optional<ExtensionID> getRequiredExtension() const override {
return ExtensionID::SPV_INTEL_joint_matrix;
}
SPIRVCapVec getRequiredCapability() const override {
return getVec(
internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL);
}
};

#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate< \
SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase, \
internal::Op##x##INTEL, __VA_ARGS__> \
SPIRV##x##INTEL;
_SPIRV_OP(CooperativeMatrixApplyFunction, true, 5)
#undef _SPIRV_OP

class SPIRVCooperativeMatrixKHRInstBase : public SPIRVInstTemplateBase {
Expand Down Expand Up @@ -3813,10 +3858,17 @@ template <Op OC>
class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
protected:
SPIRVCapVec getRequiredCapability() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
return getVec(internal::CapabilityTensorFloat32RoundingINTEL,
internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
return getVec(internal::CapabilityTensorFloat32RoundingINTEL);
}

std::optional<ExtensionID> getRequiredExtension() const override {
SPIRVType *ResCompTy = this->getType();
if (ResCompTy->isTypeCooperativeMatrixKHR())
this->getModule()->addExtension(ExtensionID::SPV_INTEL_joint_matrix);
return ExtensionID::SPV_INTEL_tensor_float32_conversion;
}

Expand Down Expand Up @@ -3845,7 +3897,25 @@ class SPIRVTensorFloat32RoundingINTELInstBase : public SPIRVUnaryInst<OC> {
}

auto InstName = OpCodeNameMap::map(OC);
SPIRVErrorLog &SPVErrLog = this->getModule()->getErrorLog();
auto *Module = this->getModule();
SPIRVErrorLog &SPVErrLog = Module->getErrorLog();

// Cooperative matrix type is allowed as input/output of the instruction
// if SPV_INTEL_joint_matrix is enabled
if (ResCompTy->isTypeCooperativeMatrixKHR()) {
SPVErrLog.checkError(
Module->isAllowedToUseExtension(ExtensionID::SPV_INTEL_joint_matrix),
SPIRVEC_InvalidInstruction,
InstName + "\nCan be used with "
"cooperative matrices only when SPV_INTEL_joint_matrix is "
"enabled\n");
assert(InCompTy->isTypeCooperativeMatrixKHR() &&
"Input must also be a cooperative matrix");
ResCompTy = static_cast<SPIRVTypeCooperativeMatrixKHR *>(ResCompTy)
->getCompType();
InCompTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(InCompTy)->getCompType();
}

SPVErrLog.checkError(
ResCompTy->isTypeFloat(32), SPIRVEC_InvalidInstruction,
Expand Down
10 changes: 10 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,18 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
add(internal::CapabilityCacheControlsINTEL, "CacheControlsINTEL");
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
"JointMatrixWIInstructionsINTEL");
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
"JointMatrixTF32ComponentTypeINTEL");
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
"JointMatrixBF16ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
"JointMatrixPackedInt2ComponentTypeINTEL");
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
"JointMatrixPackedInt4ComponentTypeINTEL");
add(internal::CapabilityCooperativeMatrixPrefetchINTEL,
"CooperativeMatrixPrefetchINTEL");
add(internal::CapabilityCooperativeMatrixInvocationInstructionsINTEL,
"CooperativeMatrixInvocationInstructionsINTEL");
add(internal::CapabilityCooperativeMatrixCheckedInstructionsINTEL,
"CooperativeMatrixCheckedInstructionsINTEL");
add(internal::CapabilityBindlessImagesINTEL, "BindlessImagesINTEL");
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ inline bool isTypeOpCode(Op OpCode) {
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
OC == internal::OpTypeJointMatrixINTEL ||
OC == internal::OpTypeJointMatrixINTELv2 ||
OC == OpTypeCooperativeMatrixKHR;
}

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
Expand All @@ -24,6 +25,8 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL,
internal::OpCooperativeMatrixConstructCheckedINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixPrefetchINTEL,
internal::OpCooperativeMatrixPrefetchINTEL)
_SPIRV_OP_INTERNAL(CooperativeMatrixApplyFunctionINTEL,
internal::OpCooperativeMatrixApplyFunctionINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
_SPIRV_OP_INTERNAL(ComplexFDivINTEL, internal::ComplexFDivINTEL)
_SPIRV_OP_INTERNAL(MaskedGatherINTEL, internal::OpMaskedGatherINTEL)
Expand Down
Loading
Loading