Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 0 additions & 115 deletions lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,118 +347,6 @@ Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) {
}
}

void SPIRVRegularizeLLVMBase::adaptStructTypes(StructType *ST) {
if (!ST->hasName())
return;
StringRef STName = ST->getName();
STName.consume_front("struct.");
STName.consume_front("__spv::");
StringRef MangledName = STName.substr(0, STName.find('.'));

// Representation in LLVM IR before the translator is a pointer array wrapped
// in a structure:
// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* }
// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1
// this '+1' for the Layout and Scope is required because both of them can
// be '0', but array size can not be '0'.
// The result should look like SPIR-V friendly LLVM IR:
// %spirv.JointMatrixINTEL._char_2_2_0_3
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
// structure type. But in the near future, we will reuse Composite
// instructions to do, for example, matrix initialization directly on AMX
// register by OpCompositeConstruct. And we can't claim, that the Result type
// of OpCompositeConstruct instruction is always the joint matrix type, it's
// simply not true.
if (MangledName == "__spirv_JointMatrixINTEL" && !ST->isOpaquePointerTy()) {
auto *PtrTy = dyn_cast<PointerType>(ST->getElementType(0));
assert(PtrTy &&
"Expected a pointer to an array to represent joint matrix type");
std::vector<size_t> TypeLayout;
ArrayType *ArrayTy =
dyn_cast<ArrayType>(PtrTy->getNonOpaquePointerElementType());
assert(ArrayTy && "Expected a pointer element type of an array type to "
"represent joint matrix type");
TypeLayout.push_back(ArrayTy->getNumElements());
for (size_t I = 1; I != 4; ++I) {
ArrayTy = dyn_cast<ArrayType>(ArrayTy->getElementType());
assert(ArrayTy &&
"Expected a element type to represent joint matrix type");
TypeLayout.push_back(ArrayTy->getNumElements());
}
// JointMatrixINTEL type can have optional 'Use' parameter, which is encoded
// as another array dimention. In case if it has default 'Unnecessary' (4)
// parameter - ignore it.
if (isa<ArrayType>(ArrayTy->getElementType())) {
ArrayTy = cast<ArrayType>(ArrayTy->getElementType());
uint32_t UseInt = ArrayTy->getNumElements();
assert(UseInt <= 4 && "Use parameter encoded in the array must be < 5 ");
if (UseInt != 4)
TypeLayout.push_back(UseInt);
}

auto *ElemTy = ArrayTy->getElementType();
std::string ElemTyStr;
if (ElemTy->isIntegerTy()) {
auto *IntElemTy = cast<IntegerType>(ElemTy);
switch (IntElemTy->getBitWidth()) {
case 8:
ElemTyStr = "char";
break;
case 16:
ElemTyStr = "short";
break;
case 32:
ElemTyStr = "int";
break;
case 64:
ElemTyStr = "long";
break;
default:
ElemTyStr = "i" + std::to_string(IntElemTy->getBitWidth());
}
}
// Check half type like this as well, but in DPC++ it most likelly will
// be a class
else if (ElemTy->isHalfTy())
ElemTyStr = "half";
else if (ElemTy->isFloatTy())
ElemTyStr = "float";
else if (ElemTy->isDoubleTy())
ElemTyStr = "double";
else {
// Half type is special: in DPC++ we use `class half` instead of `half`
// type natively supported by Clang.
auto *STElemTy = dyn_cast<StructType>(ElemTy);
if (!STElemTy && !STElemTy->hasName())
llvm_unreachable("Unexpected type for matrix!");
if (isSYCLHalfType(ElemTy))
ElemTyStr = "half";
if (isSYCLBfloat16Type(ElemTy))
ElemTyStr = "bfloat16";
if (ElemTyStr.size() == 0)
llvm_unreachable("Unexpected type for matrix!");
}
std::stringstream SPVName;
SPVName << kSPIRVTypeName::PrefixAndDelim
<< kSPIRVTypeName::JointMatrixINTEL << kSPIRVTypeName::Delimiter
<< kSPIRVTypeName::PostfixDelim << ElemTyStr
<< kSPIRVTypeName::PostfixDelim << std::to_string(TypeLayout[0])
<< kSPIRVTypeName::PostfixDelim << std::to_string(TypeLayout[1])
<< kSPIRVTypeName::PostfixDelim << std::to_string(TypeLayout[2] - 1)
<< kSPIRVTypeName::PostfixDelim
<< std::to_string(TypeLayout[3] - 1);
if (TypeLayout.size() == 5)
SPVName << kSPIRVTypeName::PostfixDelim
<< std::to_string(TypeLayout[4] - 1);
// Note, that this structure is not opaque and there is no way to make it
// opaque but to recreate it entirely and replace it everywhere. Lets
// keep the structure as is, dealing with it during SPIR-V generation.
ST->setName(SPVName.str());
}
}

bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
M = &Module;
Ctx = &M->getContext();
Expand Down Expand Up @@ -623,9 +511,6 @@ bool SPIRVRegularizeLLVMBase::regularize() {
}
}

for (StructType *ST : M->getIdentifiedStructTypes())
adaptStructTypes(ST);

if (SPIRVDbgSaveRegularizedModule)
saveLLVMModule(M, RegularizedModuleTmpFile);
return true;
Expand Down
1 change: 0 additions & 1 deletion lib/SPIRV/SPIRVRegularizeLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ class SPIRVRegularizeLLVMBase {
Value *extendBitInstBoolArg(llvm::Instruction *OldInst);

static std::string lowerLLVMIntrinsicName(llvm::IntrinsicInst *II);
void adaptStructTypes(llvm::StructType *ST);
static char ID;

private:
Expand Down
11 changes: 3 additions & 8 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,14 +606,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
return TranslatedTy;
}

// Representation in LLVM IR before the translator is a pointer array wrapped
// in a structure:
// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* }
// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1
// this '+1' for the Layout and Scope is required because both of them can
// be '0', but array size can not be '0'.
// The result should look like SPIR-V friendly LLVM IR:
// %spirv.JointMatrixINTEL._char_2_2_0_3
// Representation in LLVM IR before the translator is a pointer to an opaque
// structure:
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3330,6 +3330,9 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
_SPIRV_OP(JointMatrixLoad, true, 6, true)
_SPIRV_OP(JointMatrixStore, false, 5, true)
_SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ _SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixSUMadINTEL, internal::OpJointMatrixSUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
internal::OpJointMatrixWorkItemLengthINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
Expand Down
6 changes: 6 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ enum InternalOp {
IOpJointMatrixLoadINTEL = 6120,
IOpJointMatrixStoreINTEL = 6121,
IOpJointMatrixMadINTEL = 6122,
IOpJointMatrixSUMadINTEL = 6128,
IOpJointMatrixUSMadINTEL = 6129,
IOpJointMatrixUUMadINTEL = 6130,
IOpArithmeticFenceINTEL = 6145,
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
Expand Down Expand Up @@ -138,6 +141,9 @@ _SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
_SPIRV_OP(Op, JointMatrixMadINTEL)
_SPIRV_OP(Op, JointMatrixSUMadINTEL)
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
Expand Down
Loading