@@ -614,7 +614,7 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
614614
615615// Representation in LLVM IR before the translator is a pointer to an opaque
616616// structure:
617- // %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
617+ // %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_% scope%_%use%
618618// Here we check the structure name yet again. Another option would be to
619619// check SPIR-V friendly function calls (by their name) and obtain return
620620// or their parameter types, assuming, that the appropriate types are Matrix
@@ -625,6 +625,18 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
625625// simply not true.
626626SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType (
627627 SmallVector<std::string, 8 > Postfixes) {
628+ auto ParseInteger = [this ](StringRef Postfix) -> ConstantInt * {
629+ unsigned long long N = 0 ;
630+ if (consumeUnsignedInteger (Postfix, 10 , N))
631+ BM->getErrorLog ().checkError (
632+ false , SPIRVEC_InvalidLlvmModule,
633+ " TypeJointMatrixINTEL expects integer parameters" );
634+ return getUInt32 (M, N);
635+ };
636+ std::vector<SPIRVValue *> Args;
637+ for (size_t I = 1 ; I != Postfixes.size (); ++I)
638+ Args.emplace_back (transConstant (ParseInteger (Postfixes[I])));
639+
628640 Type *ElemTy = nullptr ;
629641 StringRef Ty{Postfixes[0 ]};
630642 auto NumBits = llvm::StringSwitch<unsigned >(Ty)
@@ -633,32 +645,30 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
633645 .Case (" int" , 32 )
634646 .Case (" long" , 64 )
635647 .Default (0 );
636- if (NumBits)
648+ if (NumBits) {
637649 ElemTy = IntegerType::get (M->getContext (), NumBits);
638- else if (Ty == " half" )
650+ } else if (Ty == " half" ) {
639651 ElemTy = Type::getHalfTy (M->getContext ());
640- else if (Ty == " float" )
652+ } else if (Ty == " float" ) {
641653 ElemTy = Type::getFloatTy (M->getContext ());
642- else if (Ty == " double" )
654+ } else if (Ty == " double" ) {
643655 ElemTy = Type::getDoubleTy (M->getContext ());
644- else if (Ty == " bfloat16" )
656+ } else if (Ty == " bfloat16" ) {
645657 ElemTy = Type::getInt16Ty (M->getContext ());
646- else
658+ // TODO: add BF16 CTI when we do breaking change
659+ // auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
660+ // internal::InternalJointMatrixCTI::Bfloat16)));
661+ // Args.push_back(CTI);
662+ // BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
663+ } else if (Ty == " tf32" ) {
664+ ElemTy = Type::getFloatTy (M->getContext ());
665+ auto *CTI = transConstant (getUInt32 (
666+ M, static_cast <uint64_t >(internal::InternalJointMatrixCTI::TF32)));
667+ Args.push_back (CTI);
668+ BM->addCapability (internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
669+ } else {
647670 llvm_unreachable (" Unexpected type for matrix!" );
648-
649- auto ParseInteger = [this ](StringRef Postfix) -> ConstantInt * {
650- unsigned long long N = 0 ;
651- if (consumeUnsignedInteger (Postfix, 10 , N)) {
652- BM->getErrorLog ().checkError (
653- false , SPIRVEC_InvalidLlvmModule,
654- " TypeJointMatrixINTEL expects integer parameters" );
655- return 0 ;
656- }
657- return getUInt32 (M, N);
658- };
659- std::vector<SPIRVValue *> Args;
660- for (size_t I = 1 ; I != Postfixes.size (); ++I)
661- Args.emplace_back (transConstant (ParseInteger (Postfixes[I])));
671+ }
662672 return BM->addJointMatrixINTELType (transType (ElemTy), Args);
663673}
664674
0 commit comments