Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/SPIRV/OCLUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ SPIRAddressSpace getOCLOpaqueTypeAddrSpace(Op OpCode) {
case OpTypeSampler:
return SPIRV_SAMPLER_T_ADDR_SPACE;
case internal::OpTypeJointMatrixINTEL:
case internal::OpTypeJointMatrixINTELv2:
return SPIRAS_Global;
default:
if (isSubgroupAvcINTELTypeOpCode(OpCode))
Expand Down
9 changes: 5 additions & 4 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,11 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
auto *MT = static_cast<SPIRVTypeJointMatrixINTEL *>(T);
auto R = static_cast<SPIRVConstant *>(MT->getRows())->getZExtIntValue();
auto C = static_cast<SPIRVConstant *>(MT->getColumns())->getZExtIntValue();
auto L = static_cast<SPIRVConstant *>(MT->getLayout())->getZExtIntValue();
auto S = static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue();
SmallVector<unsigned, 5> Params = {(unsigned)R, (unsigned)C, (unsigned)L,
(unsigned)S};
std::vector<unsigned> Params = {(unsigned)R, (unsigned)C};
if (auto *Layout = MT->getLayout())
Params.push_back(static_cast<SPIRVConstant *>(Layout)->getZExtIntValue());
Params.push_back(
static_cast<SPIRVConstant *>(MT->getScope())->getZExtIntValue());
if (auto *Use = MT->getUse())
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
auto *CTI = MT->getComponentTypeInterpretation();
Expand Down
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ SPIRVEntry *SPIRVEntry::create(Op OpCode) {
static const OpToFactoryMapTy OpToFactoryMap(std::begin(Table),
std::end(Table));

// TODO: To remove this when we make a switch to new version
if (OpCode == internal::OpTypeJointMatrixINTELv2)
OpCode = internal::OpTypeJointMatrixINTEL;

OpToFactoryMapTy::const_iterator Loc = OpToFactoryMap.find(OpCode);
if (Loc != OpToFactoryMap.end())
return Loc->second();
Expand Down
9 changes: 5 additions & 4 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1890,6 +1890,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
case OpTypeArray:
case OpTypeStruct:
case internal::OpTypeJointMatrixINTEL:
case internal::OpTypeJointMatrixINTELv2:
break;
default:
assert(false && "Invalid type");
Expand Down Expand Up @@ -3329,10 +3330,10 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
SPIRV##x##INTEL;
_SPIRV_OP(JointMatrixLoad, true, 6, true)
_SPIRV_OP(JointMatrixStore, false, 5, true)
_SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
_SPIRV_OP(JointMatrixMad, true, 6, true)
_SPIRV_OP(JointMatrixSUMad, true, 6, true)
_SPIRV_OP(JointMatrixUSMad, true, 6, true)
_SPIRV_OP(JointMatrixUUMad, true, 6, true)
Comment on lines +3333 to +3336
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maksimsab move of scope parameter happens here (see the second 'true').

Note, that 'scope' according to old spec is mandatory parameter and per new spec is removed at all. So to balance it here making both representations translatable we have to add is as optional for a while.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.

// TODO: move to SPIRVJointMatrixINTELWorkItemInst
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP
Expand Down
3 changes: 2 additions & 1 deletion lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ inline bool isTypeOpCode(Op OpCode) {
return (OpTypeVoid <= OC && OC <= OpTypePipe) || OC == OpTypePipeStorage ||
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
isVCOpCode(OpCode) || OC == internal::OpTypeTokenINTEL ||
OC == internal::OpTypeJointMatrixINTEL;
OC == internal::OpTypeJointMatrixINTEL ||
OC == internal::OpTypeJointMatrixINTELv2;
}

inline bool isSpecConstantOpCode(Op OpCode) {
Expand Down
1 change: 1 addition & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ _SPIRV_OP_INTERNAL(ArithmeticFenceINTEL, internal::OpArithmeticFenceINTEL)
_SPIRV_OP_INTERNAL(ConvertFToBF16INTEL, internal::OpConvertFToBF16INTEL)
_SPIRV_OP_INTERNAL(ConvertBF16ToFINTEL, internal::OpConvertBF16ToFINTEL)
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
Expand Down
16 changes: 12 additions & 4 deletions lib/SPIRV/libSPIRV/SPIRVType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,8 @@ bool SPIRVType::isTypeStruct() const { return OpCode == OpTypeStruct; }
bool SPIRVType::isTypeVector() const { return OpCode == OpTypeVector; }

bool SPIRVType::isTypeJointMatrixINTEL() const {
return OpCode == internal::OpTypeJointMatrixINTEL;
return OpCode == internal::OpTypeJointMatrixINTEL ||
OpCode == internal::OpTypeJointMatrixINTELv2;
}

bool SPIRVType::isTypeVectorBool() const {
Expand Down Expand Up @@ -279,13 +280,20 @@ void SPIRVTypeForwardPointer::decode(std::istream &I) {
}

SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
SPIRVModule *M, SPIRVId TheId, Op OC, SPIRVType *CompType,
std::vector<SPIRVValue *> Args)
: SPIRVType(M, FixedWC + Args.size(), OC, TheId), CompType(CompType),
Args(Args) {}
Args(std::move(Args)) {}

SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL(
SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
std::vector<SPIRVValue *> Args)
: SPIRVType(M, FixedWC + Args.size(), internal::OpTypeJointMatrixINTEL,
TheId),
CompType(CompType), Args(std::move(Args)) {}

SPIRVTypeJointMatrixINTEL::SPIRVTypeJointMatrixINTEL()
: SPIRVType(OC), CompType(nullptr),
: SPIRVType(internal::OpTypeJointMatrixINTEL), CompType(nullptr),
Args({nullptr, nullptr, nullptr, nullptr}) {}

void SPIRVTypeJointMatrixINTEL::encode(spv_ostream &O) const {
Expand Down
35 changes: 29 additions & 6 deletions lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -1060,13 +1060,18 @@ class SPIRVTypeTokenINTEL : public SPIRVType {
};

class SPIRVTypeJointMatrixINTEL : public SPIRVType {
Op OC;
SPIRVType *CompType;
std::vector<SPIRVValue *> Args;

public:
const static Op OC = internal::OpTypeJointMatrixINTEL;
const static SPIRVWord FixedWC = 3;
// Complete constructor
// Complete constructor with non-default OC
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, Op OC,
SPIRVType *CompType,
std::vector<SPIRVValue *> Args);

// Incomplete constructor for default OC
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *CompType,
std::vector<SPIRVValue *> Args);
// Incomplete constructor
Expand All @@ -1085,11 +1090,29 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
SPIRVType *getCompType() const { return CompType; }
SPIRVValue *getRows() const { return Args[0]; }
SPIRVValue *getColumns() const { return Args[1]; }
SPIRVValue *getLayout() const { return Args[2]; }
SPIRVValue *getScope() const { return Args[3]; }
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }

SPIRVValue *getLayout() const {
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
return Args[2];
return nullptr;
}

SPIRVValue *getScope() const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please separate this methods by blank lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied

if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
return Args[3];
return Args[2];
}

SPIRVValue *getUse() const {
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
return Args.size() > 4 ? Args[4] : nullptr;
return Args[3];
}

SPIRVValue *getComponentTypeInterpretation() const {
return Args.size() > 5 ? Args[5] : nullptr;
if (this->getOpCode() == internal::OpTypeJointMatrixINTEL)
return Args.size() > 5 ? Args[5] : nullptr;
return Args.size() > 4 ? Args[4] : nullptr;
}
};

Expand Down
2 changes: 2 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ enum InternalOp {
IOpJointMatrixUSMadINTEL = 6129,
IOpJointMatrixUUMadINTEL = 6130,
IOpArithmeticFenceINTEL = 6145,
IOpTypeJointMatrixINTELv2 = 6184,
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
IOpComplexFDivINTEL = 6416,
Expand Down Expand Up @@ -143,6 +144,7 @@ _SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
_SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, TypeJointMatrixINTELv2)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
_SPIRV_OP(Op, JointMatrixMadINTEL)
Expand Down