Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,9 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
return TranslatedTy;
}

// Representation in LLVM IR before the translator is a pointer array wrapped
// in a structure:
// %struct.__spirv_JointMatrixINTEL = type { [R x [C x [L x [S x type]]]]* }
// where R = Rows, C = Columnts, L = Layout + 1, S = Scope + 1
// this '+1' for the Layout and Scope is required because both of them can
// be '0', but array size can not be '0'.
// The result should look like SPIR-V friendly LLVM IR:
// %spirv.JointMatrixINTEL._char_2_2_0_3
// Representation in LLVM IR before the translator is a pointer to an opaque
// structure:
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
// Here we check the structure name yet again. Another option would be to
// check SPIR-V friendly function calls (by their name) and obtain return
// or their parameter types, assuming, that the appropriate types are Matrix
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,9 @@ class SPIRVJointMatrixINTELInst : public SPIRVJointMatrixINTELInstBase {
_SPIRV_OP(JointMatrixLoad, true, 6, true)
_SPIRV_OP(JointMatrixStore, false, 5, true)
_SPIRV_OP(JointMatrixMad, true, 7)
_SPIRV_OP(JointMatrixSUMad, true, 7)
_SPIRV_OP(JointMatrixUSMad, true, 7)
_SPIRV_OP(JointMatrixUUMad, true, 7)
_SPIRV_OP(JointMatrixWorkItemLength, true, 4)
#undef _SPIRV_OP

Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ _SPIRV_OP_INTERNAL(TypeJointMatrixINTEL, internal::OpTypeJointMatrixINTEL)
_SPIRV_OP_INTERNAL(JointMatrixLoadINTEL, internal::OpJointMatrixLoadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixStoreINTEL, internal::OpJointMatrixStoreINTEL)
_SPIRV_OP_INTERNAL(JointMatrixMadINTEL, internal::OpJointMatrixMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixSUMadINTEL, internal::OpJointMatrixSUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUSMadINTEL, internal::OpJointMatrixUSMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixUUMadINTEL, internal::OpJointMatrixUUMadINTEL)
_SPIRV_OP_INTERNAL(JointMatrixWorkItemLengthINTEL,
internal::OpJointMatrixWorkItemLengthINTEL)
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)
Expand Down
6 changes: 6 additions & 0 deletions lib/SPIRV/libSPIRV/spirv_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ enum InternalOp {
IOpJointMatrixLoadINTEL = 6120,
IOpJointMatrixStoreINTEL = 6121,
IOpJointMatrixMadINTEL = 6122,
IOpJointMatrixSUMadINTEL = 6128,
IOpJointMatrixUSMadINTEL = 6129,
IOpJointMatrixUUMadINTEL = 6130,
IOpArithmeticFenceINTEL = 6145,
IOpJointMatrixWorkItemLengthINTEL = 6410,
IOpComplexFMulINTEL = 6415,
Expand Down Expand Up @@ -105,6 +108,9 @@ _SPIRV_OP(Op, TypeJointMatrixINTEL)
_SPIRV_OP(Op, JointMatrixLoadINTEL)
_SPIRV_OP(Op, JointMatrixStoreINTEL)
_SPIRV_OP(Op, JointMatrixMadINTEL)
_SPIRV_OP(Op, JointMatrixSUMadINTEL)
_SPIRV_OP(Op, JointMatrixUSMadINTEL)
_SPIRV_OP(Op, JointMatrixUUMadINTEL)
_SPIRV_OP(Op, JointMatrixWorkItemLengthINTEL)
_SPIRV_OP(Capability, HWThreadQueryINTEL)
_SPIRV_OP(BuiltIn, SubDeviceIDINTEL)
Expand Down
22 changes: 21 additions & 1 deletion test/extensions/INTEL/SPV_INTEL_joint_matrix/joint_matrix.ll
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
; CHECK-SPIRV: JointMatrixLoadINTEL [[#ATy]] [[#A:]] [[#Aptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
; CHECK-SPIRV: JointMatrixLoadINTEL [[#BTy]] [[#B:]] [[#Bptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
; CHECK-SPIRV: JointMatrixMadINTEL [[#CTy]] [[#CMad]] [[#A]] [[#B]] [[#C]] [[#Three]]
; CHECK-SPIRV: JointMatrixSUMadINTEL [[#CTy]] [[#UnusedMad1:]] [[#A]] [[#B]] [[#C]] [[#Three]]
; CHECK-SPIRV: JointMatrixUSMadINTEL [[#CTy]] [[#UnusedMad2:]] [[#A]] [[#B]] [[#C]] [[#Three]]
; CHECK-SPIRV: JointMatrixUUMadINTEL [[#CTy]] [[#UnusedMad3:]] [[#A]] [[#B]] [[#C]] [[#Three]]

; CHECK-SPIRV: JointMatrixStoreINTEL [[#Cptr:]] [[#C]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#Cnew:]] [[#FortyTwo]]
; CHECK-SPIRV: Store [[#PtrToZero:]] [[#Zero]]
Expand All @@ -49,7 +53,11 @@
; CHECK-LLVM: [[C:%.*]] = phi %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [ [[CLoaded]], %entry ], [ [[CMad:%.*]], %for.body.i ]
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* @_Z79__spirv_JointMatrixLoadINTEL_RPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
; CHECK-LLVM: [[CMad1:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
; CHECK-LLVM: [[CMad2:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixSUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
; CHECK-LLVM: [[CMad3:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUSMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
; CHECK-LLVM: [[CMad4:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z29__spirv_JointMatrixUUMadINTELPU3AS141__spirv_JointMatrixINTEL__char_2_16_0_3_0PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)

; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
Expand Down Expand Up @@ -115,6 +123,9 @@ for.body.i: ; preds = %for.cond.i
%add.ptr17.i = addrspacecast i8 addrspace(1)* %add.ptr17.i56 to i8 addrspace(4)*
%call18.i = tail call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* @_Z28__spirv_JointMatrixLoadINTELIaLm16ELm2ELN5__spv12MatrixLayoutE3ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT_XT0_EXT1_EXT2_EXT3_EEEPS5_mS1_S3_i(i8 addrspace(4)* %add.ptr17.i, i64 %_arg_1, i32 0, i32 3, i32 0) #3
%call19.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
%call20.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
%call21.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
%call22.i = tail call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)* %call13.i, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)* %call18.i, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* %C.0.i, i32 3) #3
%add.i = add nuw nsw i32 %k.0.i, 16
br label %for.cond.i, !llvm.loop !19

Expand Down Expand Up @@ -142,6 +153,15 @@ declare dso_local spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*
; Function Attrs: convergent
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z27__spirv_JointMatrixMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixSUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUSMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)* @_Z29__spirv_JointMatrixUUMadINTELIasLm2ELm16ELm2ELN5__spv12MatrixLayoutE0ELS1_3ELS1_0ELNS0_5Scope4FlagE3EEPNS0_24__spirv_JointMatrixINTELIT0_XT1_EXT3_EXT6_EXT7_EEEPNS4_IT_XT1_EXT2_EXT4_EXT7_EEEPNS4_IS8_XT2_EXT3_EXT5_EXT7_EEES7_S3_(%spirv.JointMatrixINTEL._char_2_16_0_3_0 addrspace(4)*, %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i32) local_unnamed_addr #1

; Function Attrs: convergent
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)*, %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1

Expand Down